diff --git a/src/backend/base/langflow/services/settings/base.py b/src/backend/base/langflow/services/settings/base.py
index f9b9db365..039939913 100644
--- a/src/backend/base/langflow/services/settings/base.py
+++ b/src/backend/base/langflow/services/settings/base.py
@@ -16,6 +16,7 @@ from pydantic_settings import BaseSettings, EnvSettingsSource, PydanticBaseSetti
 from typing_extensions import override
 
 from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
+from langflow.utils.util_strings import is_valid_database_url
 
 # BASE_COMPONENTS_PATH = str(Path(__file__).parent / "components")
 BASE_COMPONENTS_PATH = str(Path(__file__).parent.parent.parent / "components")
@@ -240,6 +241,21 @@ class Settings(BaseSettings):
     @field_validator("database_url", mode="before")
     @classmethod
     def set_database_url(cls, value, info):
+        """Validate an explicit database URL or derive one when none is given.
+
+        A provided value is returned unchanged once it passes validation; the
+        LANGFLOW_DATABASE_URL lookup and the sqlite fallback only run when no
+        value was provided.
+
+        Raises:
+            ValueError: If the provided URL is invalid, or if no URL is
+                available and config_dir is not set.
+        """
+        # Validate only: the fallback below stays guarded by "if not value" so an explicit URL is never overwritten.
+        if value and not is_valid_database_url(value):
+            msg = f"Invalid database_url provided: '{value}'"
+            raise ValueError(msg)
+
         if not value:
             logger.debug("No database_url provided, trying LANGFLOW_DATABASE_URL env variable")
             if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
diff --git a/src/backend/base/langflow/utils/util_strings.py b/src/backend/base/langflow/utils/util_strings.py
index 51802e85d..13f4f3870 100644
--- a/src/backend/base/langflow/utils/util_strings.py
+++ b/src/backend/base/langflow/utils/util_strings.py
@@ -1,3 +1,5 @@
+from sqlalchemy.engine import make_url
+
 from langflow.utils import constants
 
 
@@ -28,3 +30,27 @@ def truncate_long_strings(data, max_length=None):
                 truncate_long_strings(item, max_length)
 
     return data
+
+
+def is_valid_database_url(url: str) -> bool:
+    """Check whether SQLAlchemy can parse a database URL and resolve its dialect.
+
+    The URL is only parsed and its dialect/driver looked up; any exception
+    raised along the way is treated as "invalid".
+
+    Args:
+        url (str): Database connection URL to validate.
+
+    Returns:
+        bool: True if the URL parses and its dialect and driver name resolve,
+            False otherwise (including non-string input such as None).
+    """
+    try:
+        parsed_url = make_url(url)
+        parsed_url.get_dialect()
+        parsed_url.get_driver_name()
+
+    except Exception:  # noqa: BLE001
+        return False
+
+    return True
diff --git a/src/backend/tests/unit/utils/test_util_strings.py b/src/backend/tests/unit/utils/test_util_strings.py
new file mode 100644
index 000000000..820f14642
--- /dev/null
+++ b/src/backend/tests/unit/utils/test_util_strings.py
@@ -0,0 +1,30 @@
+import pytest
+from langflow.utils import util_strings
+
+
+@pytest.mark.parametrize(
+    ("value", "expected"),
+    [
+        ("sqlite:///test.db", True),
+        ("sqlite:////var/folders/test.db", True),
+        ("sqlite:///:memory:", True),
+        ("sqlite+aiosqlite:////var/folders/test.db", True),
+        ("postgresql://user:pass@localhost/dbname", True),
+        ("postgresql+psycopg2://scott:tiger@localhost:5432/mydatabase", True),
+        ("postgresql+pg8000://dbuser:kx%40jj5%2Fg@pghost10/appdb", True),
+        ("mysql://user:pass@localhost/dbname", True),
+        ("mysql+mysqldb://scott:tiger@localhost/foo", True),
+        ("mysql+pymysql://scott:tiger@localhost/foo", True),
+        ("oracle://scott:tiger@127.0.0.1:1521/?service_name=freepdb1", True),
+        ("oracle+cx_oracle://scott:tiger@tnsalias", True),
+        ("oracle+oracledb://scott:tiger@127.0.0.1:1521/?service_name=freepdb1", True),
+        ("", False),
+        (" invalid ", False),
+        ("not_a_url", False),
+        (None, False),
+        ("invalid://database", False),
+        ("invalid://:@/test", False),
+    ],
+)
+def test_is_valid_database_url(value, expected):
+    assert util_strings.is_valid_database_url(value) == expected