diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 0d93f8d75..fdfef52f8 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -3,6 +3,7 @@ from fastapi import HTTPException from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES from langflow.interface.custom.component import Component from langflow.interface.custom.directory_reader import DirectoryReader +from langflow.services.utils import get_db_manager from langflow.utils import validate @@ -159,7 +160,8 @@ class CustomComponent(Component, extra=Extra.allow): from langflow.processing.process import build_sorted_vertices_with_caching from langflow.processing.process import process_tweaks - with session_getter() as session: + db_manager = get_db_manager() + with session_getter(db_manager) as session: graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None if not graph_data: raise ValueError(f"Flow {flow_id} not found") @@ -169,7 +171,8 @@ class CustomComponent(Component, extra=Extra.allow): def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]: get_session = get_session or session_getter - with get_session() as session: + db_manager = get_db_manager() + with get_session(db_manager) as session: flows = session.query(Flow).all() return flows @@ -182,8 +185,8 @@ class CustomComponent(Component, extra=Extra.allow): get_session: Optional[Callable] = None, ) -> Flow: get_session = get_session or session_getter - - with get_session() as session: + db_manager = get_db_manager() + with get_session(db_manager) as session: if flow_id: flow = session.query(Flow).get(flow_id) elif flow_name: diff --git a/src/backend/langflow/services/database/factory.py b/src/backend/langflow/services/database/factory.py index 187a29fdd..d98414382 100644 --- a/src/backend/langflow/services/database/factory.py +++ b/src/backend/langflow/services/database/factory.py @@ -12,4 +12,6 @@ class DatabaseManagerFactory(ServiceFactory): def create(self, settings_service: "SettingsManager"): # Here you would have logic to create and configure a DatabaseManager + if not settings_service.settings.DATABASE_URL: + raise ValueError("No database URL provided") return DatabaseManager(settings_service.settings.DATABASE_URL) diff --git a/src/backend/langflow/services/settings/manager.py b/src/backend/langflow/services/settings/manager.py index 598efe2d8..a357c4804 100644 --- a/src/backend/langflow/services/settings/manager.py +++ b/src/backend/langflow/services/settings/manager.py @@ -13,7 +13,7 @@ class SettingsManager(Service): self.settings = settings @classmethod - def load_settings_from_yaml(cls, file_path: str) -> Settings: + def load_settings_from_yaml(cls, file_path: str) -> "SettingsManager": # Check if a string is a valid path or a file name if "/" not in file_path: # Get current path diff --git a/src/backend/langflow/services/utils.py b/src/backend/langflow/services/utils.py index 07c67dfbe..049e82c0f 100644 --- a/src/backend/langflow/services/utils.py +++ b/src/backend/langflow/services/utils.py @@ -9,6 +9,10 @@ def get_settings_manager() -> "SettingsManager": return service_manager.get(ServiceType.SETTINGS_MANAGER) +def get_db_manager(): + return service_manager.get(ServiceType.DATABASE_MANAGER) + + def get_session(): db_manager = service_manager.get(ServiceType.DATABASE_MANAGER) yield from db_manager.get_session()