fix: validate and test database connection URLs (#5178)

* test: add unit test for database url validation

* feat: add function to validate database urls

* refactor: use new database url validation function

* fix: ruff errors

* refactor: validate database urls using sqlalchemy

* test: add more cases for database url validation
This commit is contained in:
Ítalo Johnny 2024-12-17 14:29:53 -03:00 committed by GitHub
commit 4be6b04d8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 116 additions and 60 deletions

View file

@ -16,6 +16,7 @@ from pydantic_settings import BaseSettings, EnvSettingsSource, PydanticBaseSetti
from typing_extensions import override
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
from langflow.utils.util_strings import is_valid_database_url
# BASE_COMPONENTS_PATH = str(Path(__file__).parent / "components")
BASE_COMPONENTS_PATH = str(Path(__file__).parent.parent.parent / "components")
@ -240,71 +241,74 @@ class Settings(BaseSettings):
@field_validator("database_url", mode="before")
@classmethod
def set_database_url(cls, value, info):
if not value:
logger.debug("No database_url provided, trying LANGFLOW_DATABASE_URL env variable")
if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
value = langflow_database_url
logger.debug("Using LANGFLOW_DATABASE_URL env variable.")
if value and not is_valid_database_url(value):
msg = f"Invalid database_url provided: '{value}'"
raise ValueError(msg)
logger.debug("No database_url provided, trying LANGFLOW_DATABASE_URL env variable")
if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
value = langflow_database_url
logger.debug("Using LANGFLOW_DATABASE_URL env variable.")
else:
logger.debug("No database_url env variable, using sqlite database")
# Originally, we used sqlite:///./langflow.db
# so we need to migrate to the new format
# if there is a database in that location
if not info.data["config_dir"]:
msg = "config_dir not set, please set it or provide a database_url"
raise ValueError(msg)
from langflow.utils.version import get_version_info
from langflow.utils.version import is_pre_release as langflow_is_pre_release
version = get_version_info()["version"]
is_pre_release = langflow_is_pre_release(version)
if info.data["save_db_in_config_dir"]:
database_dir = info.data["config_dir"]
logger.debug(f"Saving database to config_dir: {database_dir}")
else:
logger.debug("No database_url env variable, using sqlite database")
# Originally, we used sqlite:///./langflow.db
# so we need to migrate to the new format
# if there is a database in that location
if not info.data["config_dir"]:
msg = "config_dir not set, please set it or provide a database_url"
raise ValueError(msg)
database_dir = Path(__file__).parent.parent.parent.resolve()
logger.debug(f"Saving database to langflow directory: {database_dir}")
from langflow.utils.version import get_version_info
from langflow.utils.version import is_pre_release as langflow_is_pre_release
version = get_version_info()["version"]
is_pre_release = langflow_is_pre_release(version)
if info.data["save_db_in_config_dir"]:
database_dir = info.data["config_dir"]
logger.debug(f"Saving database to config_dir: {database_dir}")
pre_db_file_name = "langflow-pre.db"
db_file_name = "langflow.db"
new_pre_path = f"{database_dir}/{pre_db_file_name}"
new_path = f"{database_dir}/{db_file_name}"
final_path = None
if is_pre_release:
if Path(new_pre_path).exists():
final_path = new_pre_path
elif Path(new_path).exists() and info.data["save_db_in_config_dir"]:
# We need to copy the current db to the new location
logger.debug("Copying existing database to new location")
copy2(new_path, new_pre_path)
logger.debug(f"Copied existing database to {new_pre_path}")
elif Path(f"./{db_file_name}").exists() and info.data["save_db_in_config_dir"]:
logger.debug("Copying existing database to new location")
copy2(f"./{db_file_name}", new_pre_path)
logger.debug(f"Copied existing database to {new_pre_path}")
else:
database_dir = Path(__file__).parent.parent.parent.resolve()
logger.debug(f"Saving database to langflow directory: {database_dir}")
logger.debug(f"Creating new database at {new_pre_path}")
final_path = new_pre_path
elif Path(new_path).exists():
logger.debug(f"Database already exists at {new_path}, using it")
final_path = new_path
elif Path(f"./{db_file_name}").exists():
try:
logger.debug("Copying existing database to new location")
copy2(f"./{db_file_name}", new_path)
logger.debug(f"Copied existing database to {new_path}")
except Exception: # noqa: BLE001
logger.exception("Failed to copy database, using default path")
new_path = f"./{db_file_name}"
else:
final_path = new_path
pre_db_file_name = "langflow-pre.db"
db_file_name = "langflow.db"
new_pre_path = f"{database_dir}/{pre_db_file_name}"
new_path = f"{database_dir}/{db_file_name}"
final_path = None
if is_pre_release:
if Path(new_pre_path).exists():
final_path = new_pre_path
elif Path(new_path).exists() and info.data["save_db_in_config_dir"]:
# We need to copy the current db to the new location
logger.debug("Copying existing database to new location")
copy2(new_path, new_pre_path)
logger.debug(f"Copied existing database to {new_pre_path}")
elif Path(f"./{db_file_name}").exists() and info.data["save_db_in_config_dir"]:
logger.debug("Copying existing database to new location")
copy2(f"./{db_file_name}", new_pre_path)
logger.debug(f"Copied existing database to {new_pre_path}")
else:
logger.debug(f"Creating new database at {new_pre_path}")
final_path = new_pre_path
elif Path(new_path).exists():
logger.debug(f"Database already exists at {new_path}, using it")
final_path = new_path
elif Path(f"./{db_file_name}").exists():
try:
logger.debug("Copying existing database to new location")
copy2(f"./{db_file_name}", new_path)
logger.debug(f"Copied existing database to {new_path}")
except Exception: # noqa: BLE001
logger.exception("Failed to copy database, using default path")
new_path = f"./{db_file_name}"
else:
final_path = new_path
if final_path is None:
final_path = new_pre_path if is_pre_release else new_path
if final_path is None:
final_path = new_pre_path if is_pre_release else new_path
value = f"sqlite:///{final_path}"
value = f"sqlite:///{final_path}"
return value

View file

@ -1,3 +1,5 @@
from sqlalchemy.engine import make_url
from langflow.utils import constants
@ -28,3 +30,23 @@ def truncate_long_strings(data, max_length=None):
truncate_long_strings(item, max_length)
return data
def is_valid_database_url(url: str) -> bool:
"""Validate database connection URLs compatible with SQLAlchemy.
Args:
url (str): Database connection URL to validate
Returns:
bool: True if URL is valid, False otherwise
"""
try:
parsed_url = make_url(url)
parsed_url.get_dialect()
parsed_url.get_driver_name()
except Exception: # noqa: BLE001
return False
return True

View file

@ -0,0 +1,30 @@
import pytest
from langflow.utils import util_strings
@pytest.mark.parametrize(
("value", "expected"),
[
("sqlite:///test.db", True),
("sqlite:////var/folders/test.db", True),
("sqlite:///:memory:", True),
("sqlite+aiosqlite:////var/folders/test.db", True),
("postgresql://user:pass@localhost/dbname", True),
("postgresql+psycopg2://scott:tiger@localhost:5432/mydatabase", True),
("postgresql+pg8000://dbuser:kx%40jj5%2Fg@pghost10/appdb", True),
("mysql://user:pass@localhost/dbname", True),
("mysql+mysqldb://scott:tiger@localhost/foo", True),
("mysql+pymysql://scott:tiger@localhost/foo", True),
("oracle://scott:tiger@127.0.0.1:1521/?service_name=freepdb1", True),
("oracle+cx_oracle://scott:tiger@tnsalias", True),
("oracle+oracledb://scott:tiger@127.0.0.1:1521/?service_name=freepdb1", True),
("", False),
(" invalid ", False),
("not_a_url", False),
(None, False),
("invalid://database", False),
("invalid://:@/test", False),
],
)
def test_is_valid_database_url(value, expected):
assert util_strings.is_valid_database_url(value) == expected