diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index b3a0e3549..7a356cd5c 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -1,11 +1,10 @@ +import asyncio import copy import json import shutil -import time from collections import defaultdict from copy import deepcopy from datetime import datetime, timezone -from pathlib import Path from uuid import UUID import anyio @@ -14,6 +13,7 @@ from aiofile import async_open from emoji import demojize, purely_emoji from loguru import logger from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import selectinload from sqlmodel import select from langflow.base.constants import FIELD_FORMAT_ATTRIBUTES, NODE_FORMAT_ATTRIBUTES, ORJSON_OPTIONS @@ -31,7 +31,6 @@ from langflow.services.deps import ( get_settings_service, get_storage_service, get_variable_service, - session_scope, ) from langflow.template.field.prompt import DEFAULT_PROMPT_INTUT_TYPES from langflow.utils.util import escape_json_dump @@ -362,13 +361,14 @@ def log_node_changes(node_changes_log) -> None: logger.debug("\n".join(formatted_messages)) -def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]: +async def load_starter_projects(retries=3, delay=1) -> list[tuple[anyio.Path, dict]]: starter_projects = [] - folder = Path(__file__).parent / "starter_projects" - for file in folder.glob("*.json"): + folder = anyio.Path(__file__).parent / "starter_projects" + async for file in folder.glob("*.json"): attempt = 0 while attempt < retries: - content = file.read_text(encoding="utf-8") + async with async_open(str(file), "r", encoding="utf-8") as f: + content = await f.read() try: project = orjson.loads(content) starter_projects.append((file, project)) @@ -379,27 +379,27 @@ def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]: if attempt >= retries: msg = f"Error loading starter project {file}: {e}" raise ValueError(msg) from e - time.sleep(delay) # Wait before retrying + await asyncio.sleep(delay) # Wait before retrying return starter_projects -def copy_profile_pictures() -> None: +async def copy_profile_pictures() -> None: config_dir = get_storage_service().settings_service.settings.config_dir if config_dir is None: msg = "Config dir is not set in the settings" raise ValueError(msg) - origin = Path(__file__).parent / "profile_pictures" - target = Path(config_dir) / "profile_pictures" + origin = anyio.Path(__file__).parent / "profile_pictures" + target = anyio.Path(config_dir) / "profile_pictures" - if not origin.exists(): + if not await origin.exists(): msg = f"The source folder '{origin}' does not exist." raise ValueError(msg) - if not target.exists(): - target.mkdir(parents=True) + if not await target.exists(): + await target.mkdir(parents=True) try: - shutil.copytree(origin, target, dirs_exist_ok=True) + await asyncio.to_thread(shutil.copytree, str(origin), str(target), dirs_exist_ok=True) logger.debug(f"Folder copied from '{origin}' to '{target}'") except Exception: # noqa: BLE001 @@ -434,9 +434,10 @@ def get_project_data(project): ) -def update_project_file(project_path: Path, project: dict, updated_project_data) -> None: +async def update_project_file(project_path: anyio.Path, project: dict, updated_project_data) -> None: project["data"] = updated_project_data - project_path.write_text(orjson.dumps(project, option=ORJSON_OPTIONS).decode(), encoding="utf-8") + async with async_open(str(project_path), "w", encoding="utf-8") as f: + await f.write(orjson.dumps(project, option=ORJSON_OPTIONS).decode()) logger.info(f"Updated starter project {project['name']} file") @@ -490,31 +491,34 @@ def create_new_project( session.add(db_flow) -def get_all_flows_similar_to_project(session, folder_id): - return session.exec(select(Folder).where(Folder.id == folder_id)).first().flows +async def get_all_flows_similar_to_project(session, folder_id): + stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.id == folder_id) + return (await session.exec(stmt)).first().flows -def delete_start_projects(session, folder_id) -> None: - flows = session.exec(select(Folder).where(Folder.id == folder_id)).first().flows +async def delete_start_projects(session, folder_id) -> None: + flows = await get_all_flows_similar_to_project(session, folder_id) for flow in flows: - session.delete(flow) - session.commit() + await session.delete(flow) + await session.commit() -def folder_exists(session, folder_name): - folder = session.exec(select(Folder).where(Folder.name == folder_name)).first() +async def folder_exists(session, folder_name): + stmt = select(Folder).where(Folder.name == folder_name) + folder = (await session.exec(stmt)).first() return folder is not None -def create_starter_folder(session): - if not folder_exists(session, STARTER_FOLDER_NAME): +async def create_starter_folder(session): + if not await folder_exists(session, STARTER_FOLDER_NAME): new_folder = FolderCreate(name=STARTER_FOLDER_NAME, description=STARTER_FOLDER_DESCRIPTION) db_folder = Folder.model_validate(new_folder, from_attributes=True) session.add(db_folder) - session.commit() - session.refresh(db_folder) + await session.commit() + await session.refresh(db_folder) return db_folder - return session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first() + stmt = select(Folder).where(Folder.name == STARTER_FOLDER_NAME) + return (await session.exec(stmt)).first() def _is_valid_uuid(val): @@ -606,12 +610,12 @@ async def find_existing_flow(session, flow_id, flow_endpoint_name): return None -def create_or_update_starter_projects(all_types_dict: dict) -> None: - with session_scope() as session: - new_folder = create_starter_folder(session) - starter_projects = load_starter_projects() - delete_start_projects(session, new_folder.id) - copy_profile_pictures() +async def create_or_update_starter_projects(all_types_dict: dict) -> None: + async with async_session_scope() as session: + new_folder = await create_starter_folder(session) + starter_projects = await load_starter_projects() + await delete_start_projects(session, new_folder.id) + await copy_profile_pictures() for project_path, project in starter_projects: ( project_name, @@ -632,10 +636,10 @@ def create_or_update_starter_projects(all_types_dict: dict) -> None: project_data = updated_project_data # We also need to update the project data in the file - update_project_file(project_path, project, updated_project_data) + await update_project_file(project_path, project, updated_project_data) if project_name and project_data: - for existing_project in get_all_flows_similar_to_project(session, new_folder.id): - session.delete(existing_project) + for existing_project in await get_all_flows_similar_to_project(session, new_folder.id): + await session.delete(existing_project) create_new_project( session=session, diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index b25b4b240..678726b5c 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -105,7 +105,7 @@ def get_lifespan(*, fix_migration=False, version=None): setup_llm_caching() await initialize_super_user_if_needed() all_types_dict = await get_and_cache_all_types_dict(get_settings_service()) - await asyncio.to_thread(create_or_update_starter_projects, all_types_dict) + await create_or_update_starter_projects(all_types_dict) telemetry_service.start() await load_flows_from_directory() yield diff --git a/src/backend/tests/integration/test_misc.py b/src/backend/tests/integration/test_misc.py index 28afbd19b..cf0c10f9d 100644 --- a/src/backend/tests/integration/test_misc.py +++ b/src/backend/tests/integration/test_misc.py @@ -5,7 +5,7 @@ from fastapi import status from httpx import AsyncClient from langflow.graph.schema import RunOutputs from langflow.initial_setup.setup import load_starter_projects -from langflow.load import run_flow_from_json +from langflow.load.load import arun_flow_from_json @pytest.mark.api_key_required @@ -78,9 +78,9 @@ async def test_run_with_inputs_and_outputs(client, starter_project, created_api_ @pytest.mark.noclient @pytest.mark.api_key_required -def test_run_flow_from_json_object(): +async def test_run_flow_from_json_object(): """Test loading a flow from a json file and applying tweaks.""" - project = next(project for _, project in load_starter_projects() if "Basic Prompting" in project["name"]) - results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True) + project = next(project for _, project in await load_starter_projects() if "Basic Prompting" in project["name"]) + results = await arun_flow_from_json(project, input_value="test", fallback_to_env_vars=True) assert results is not None assert all(isinstance(result, RunOutputs) for result in results) diff --git a/src/backend/tests/performance/test_server_init.py b/src/backend/tests/performance/test_server_init.py index 59fdd5962..a97f4f489 100644 --- a/src/backend/tests/performance/test_server_init.py +++ b/src/backend/tests/performance/test_server_init.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from langflow.services.deps import get_settings_service @@ -72,7 +70,7 @@ async def test_create_starter_projects(): await initialize_services(fix_migration=False) settings_service = get_settings_service() types_dict = await get_and_cache_all_types_dict(settings_service) - await asyncio.to_thread(create_or_update_starter_projects, types_dict) + await create_or_update_starter_projects(types_dict) assert "test_performance.db" in settings_service.settings.database_url diff --git a/src/backend/tests/unit/graph/test_graph.py b/src/backend/tests/unit/graph/test_graph.py index af9f009cc..2feff41ed 100644 --- a/src/backend/tests/unit/graph/test_graph.py +++ b/src/backend/tests/unit/graph/test_graph.py @@ -256,8 +256,8 @@ def test_update_source_handle(): assert updated_edge["data"]["sourceHandle"]["id"] == "last_node" -def test_serialize_graph(): - starter_projects = load_starter_projects() +async def test_serialize_graph(): + starter_projects = await load_starter_projects() data = starter_projects[0][1]["data"] graph = Graph.from_payload(data) assert isinstance(graph, Graph) diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index d228d3e05..b8e708396 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -1,4 +1,3 @@ -import asyncio import json from typing import NamedTuple from uuid import UUID, uuid4 @@ -605,7 +604,7 @@ async def test_delete_nonexistent_flow(client: AsyncClient, logged_in_headers): @pytest.mark.usefixtures("active_user") async def test_read_only_starter_projects(client: AsyncClient, logged_in_headers): response = await client.get("api/v1/flows/basic_examples/", headers=logged_in_headers) - starter_projects = await asyncio.to_thread(load_starter_projects) + starter_projects = await load_starter_projects() assert response.status_code == 200 assert len(response.json()) == len(starter_projects) diff --git a/src/backend/tests/unit/test_initial_setup.py b/src/backend/tests/unit/test_initial_setup.py index b68124861..5236f0c7f 100644 --- a/src/backend/tests/unit/test_initial_setup.py +++ b/src/backend/tests/unit/test_initial_setup.py @@ -1,6 +1,4 @@ -import asyncio from datetime import datetime -from pathlib import Path import anyio import pytest @@ -18,15 +16,15 @@ from sqlalchemy.orm import selectinload from sqlmodel import select -def test_load_starter_projects(): - projects = load_starter_projects() +async def test_load_starter_projects(): + projects = await load_starter_projects() assert isinstance(projects, list) assert all(isinstance(project[1], dict) for project in projects) - assert all(isinstance(project[0], Path) for project in projects) + assert all(isinstance(project[0], anyio.Path) for project in projects) -def test_get_project_data(): - projects = load_starter_projects() +async def test_get_project_data(): + projects = await load_starter_projects() for _, project in projects: ( project_name, @@ -56,7 +54,7 @@ def test_get_project_data(): async def test_create_or_update_starter_projects(): async with async_session_scope() as session: # Get the number of projects returned by load_starter_projects - num_projects = len(await asyncio.to_thread(load_starter_projects)) + num_projects = len(await load_starter_projects()) # Get the number of projects in the database stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.name == STARTER_FOLDER_NAME) diff --git a/src/backend/tests/unit/test_loading.py b/src/backend/tests/unit/test_loading.py index 090811f63..6cac5b567 100644 --- a/src/backend/tests/unit/test_loading.py +++ b/src/backend/tests/unit/test_loading.py @@ -22,7 +22,7 @@ from langflow.load import load_flow_from_json async def test_load_flow_from_json_object(): """Test loading a flow from a json file and applying tweaks.""" - result = await asyncio.to_thread(load_starter_projects) + result = await load_starter_projects() project = result[0][1] loaded = await asyncio.to_thread(load_flow_from_json, project) assert loaded is not None