diff --git a/src/backend/langflow/database/base.py b/src/backend/langflow/database/base.py index 338298a6b..546c341c1 100644 --- a/src/backend/langflow/database/base.py +++ b/src/backend/langflow/database/base.py @@ -1,21 +1,48 @@ from contextlib import contextmanager -from langflow.settings import settings +import os + from sqlmodel import SQLModel, Session, create_engine from langflow.utils.logger import logger -if settings.database_url and settings.database_url.startswith("sqlite"): - connect_args = {"check_same_thread": False} -else: - connect_args = {} -if not settings.database_url: - raise RuntimeError("No database_url provided") -engine = create_engine(settings.database_url, connect_args=connect_args) + +class Engine: + _instance = None + + @classmethod + def get(cls): + logger.debug("Getting database engine") + if cls._instance is None: + cls.create() + return cls._instance + + @classmethod + def create(cls): + logger.debug("Creating database engine") + from langflow.settings import settings + + if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"): + settings.DATABASE_URL = langflow_database_url + logger.debug("Using LANGFLOW_DATABASE_URL") + + if settings.DATABASE_URL and settings.DATABASE_URL.startswith("sqlite"): + connect_args = {"check_same_thread": False} + else: + connect_args = {} + if not settings.DATABASE_URL: + raise RuntimeError("No database_url provided") + cls._instance = create_engine(settings.DATABASE_URL, connect_args=connect_args) + + @classmethod + def update(cls): + logger.debug("Updating database engine") + cls._instance = None + cls.create() def create_db_and_tables(): logger.debug("Creating database and tables") try: - SQLModel.metadata.create_all(engine) + SQLModel.metadata.create_all(Engine.get()) except Exception as exc: logger.error(f"Error creating database and tables: {exc}") raise RuntimeError("Error creating database and tables") from exc @@ -23,7 +50,7 @@ def create_db_and_tables(): # and we need to create the tables again. from sqlalchemy import inspect - inspector = inspect(engine) + inspector = inspect(Engine.get()) if "flow" not in inspector.get_table_names(): logger.error("Something went wrong creating the database and tables.") logger.error("Please check your database settings.") @@ -36,7 +63,7 @@ def create_db_and_tables(): @contextmanager def session_getter(): try: - session = Session(engine) + session = Session(Engine.get()) yield session except Exception as e: print("Session rollback because of exception:", e)