diff --git a/src/backend/langflow/initial_setup/setup.py b/src/backend/langflow/initial_setup/setup.py index 3b9ed8d5e..1f24ddab1 100644 --- a/src/backend/langflow/initial_setup/setup.py +++ b/src/backend/langflow/initial_setup/setup.py @@ -1,12 +1,12 @@ -import json from datetime import datetime from pathlib import Path +import orjson from loguru import logger from sqlmodel import select from langflow.services.database.models.flow.model import Flow -from langflow.services.deps import get_session +from langflow.services.deps import session_scope STARTER_FOLDER_NAME = "Starter Projects" @@ -14,60 +14,102 @@ STARTER_FOLDER_NAME = "Starter Projects" # In the folder ./starter_projects we have a few JSON files that represent # starter projects. We want to load these into the database so that users # can use them as a starting point for their own projects. + + def load_starter_projects(): - # Load the starter projects from the JSON files - # using Pathlib's glob method starter_projects = [] folder = Path(__file__).parent / "starter_projects" for file in folder.glob("*.json"): - with open(file, "r") as f: - starter_projects.append(json.load(f)) - logger.info(f"Loaded starter project {file}") + project = orjson.loads(file.read_text()) + starter_projects.append(project) + logger.info(f"Loaded starter project {file}") return starter_projects -# We want to load the starter projects into the database +def get_project_data(project): + project_name = project.get("name") + project_description = project.get("description") + project_is_component = project.get("is_component") + project_updated_at = project.get("updated_at") + updated_at_datetime = datetime.strptime(project_updated_at, "%Y-%m-%dT%H:%M:%S.%f") + project_data = project.get("data") + return ( + project_name, + project_description, + project_is_component, + updated_at_datetime, + project_data, + ) + + +def update_existing_project( + existing_project, + project_name, + project_description, + project_is_component, + updated_at_datetime, + project_data, +): + logger.info(f"Updating starter project {project_name}") + existing_project.data = project_data + existing_project.folder = STARTER_FOLDER_NAME + existing_project.description = project_description + existing_project.is_component = project_is_component + existing_project.updated_at = updated_at_datetime + + +def create_new_project( + session, + project_name, + project_description, + project_is_component, + updated_at_datetime, + project_data, +): + logger.info(f"Creating starter project {project_name}") + new_project = Flow( + name=project_name, + description=project_description, + is_component=project_is_component, + updated_at=updated_at_datetime, + folder=STARTER_FOLDER_NAME, + data=project_data, + ) + session.add(new_project) + + def create_or_update_starter_projects(): - session = next(get_session()) - starter_projects = load_starter_projects() - for project in starter_projects: - # Check if the project already exists in the database - project_name = project.get("name") - project_description = project.get("description") - project_is_component = project.get("is_component") - project_updated_at = project.get("updated_at") - # 2024-03-05T21:59:59.738081 - updated_at_datetime = datetime.strptime( - project_updated_at, "%Y-%m-%dT%H:%M:%S.%f" - ) - project_data = project.get("data") - if project_name and project_data: - existing_project = session.exec( - select(Flow).where( - Flow.name == project_name, Flow.folder == STARTER_FOLDER_NAME - ) - ).first() - if existing_project: - logger.info(f"Updating starter project {project_name}") - existing_project.data = project_data - existing_project.folder = STARTER_FOLDER_NAME - existing_project.description = project_description - existing_project.is_component = project_is_component - existing_project.updated_at = updated_at_datetime - # Now we need to update the project in the database - session.add(existing_project) - else: - logger.info(f"Creating starter project {project_name}") - session.add( - Flow( - name=project_name, - description=project_description, - is_component=project_is_component, - updated_at=updated_at_datetime, - folder=STARTER_FOLDER_NAME, - data=project_data, + with session_scope() as session: + starter_projects = load_starter_projects() + for project in starter_projects: + ( + project_name, + project_description, + project_is_component, + updated_at_datetime, + project_data, + ) = get_project_data(project) + if project_name and project_data: + existing_project = session.exec( + select(Flow).where( + Flow.name == project_name, Flow.folder == STARTER_FOLDER_NAME + ) + ).first() + if existing_project: + update_existing_project( + existing_project, + project_name, + project_description, + project_is_component, + updated_at_datetime, + project_data, + ) + else: + create_new_project( + session, + project_name, + project_description, + project_is_component, + updated_at_datetime, + project_data, ) - ) - session.commit() - session.close() - logger.info("Starter projects loaded into database") diff --git a/src/backend/langflow/services/deps.py b/src/backend/langflow/services/deps.py index 19f3dcbf0..7d2338b04 100644 --- a/src/backend/langflow/services/deps.py +++ b/src/backend/langflow/services/deps.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import TYPE_CHECKING, Generator from langflow.services import ServiceType, service_manager @@ -54,6 +55,19 @@ def get_session() -> Generator["Session", None, None]: yield from db_service.get_session() +@contextmanager +def session_scope(): + session = next(get_session()) + try: + yield session + session.commit() + except: + session.rollback() + raise + finally: + session.close() + + def get_cache_service() -> "BaseCacheService": return service_manager.get(ServiceType.CACHE_SERVICE) # type: ignore