ref: Make create_or_update_starter_projects async (#5165)

Make create_or_update_starter_projects async
This commit is contained in:
Christophe Bornet 2024-12-10 08:51:35 +01:00 committed by GitHub
commit fe6ec1690b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 59 additions and 60 deletions

View file

@@ -1,11 +1,10 @@
import asyncio
import copy
import json
import shutil
import time
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path
from uuid import UUID
import anyio
@@ -14,6 +13,7 @@ from aiofile import async_open
from emoji import demojize, purely_emoji
from loguru import logger
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import selectinload
from sqlmodel import select
from langflow.base.constants import FIELD_FORMAT_ATTRIBUTES, NODE_FORMAT_ATTRIBUTES, ORJSON_OPTIONS
@@ -31,7 +31,6 @@ from langflow.services.deps import (
get_settings_service,
get_storage_service,
get_variable_service,
session_scope,
)
from langflow.template.field.prompt import DEFAULT_PROMPT_INTUT_TYPES
from langflow.utils.util import escape_json_dump
@@ -362,13 +361,14 @@ def log_node_changes(node_changes_log) -> None:
logger.debug("\n".join(formatted_messages))
def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]:
async def load_starter_projects(retries=3, delay=1) -> list[tuple[anyio.Path, dict]]:
starter_projects = []
folder = Path(__file__).parent / "starter_projects"
for file in folder.glob("*.json"):
folder = anyio.Path(__file__).parent / "starter_projects"
async for file in folder.glob("*.json"):
attempt = 0
while attempt < retries:
content = file.read_text(encoding="utf-8")
async with async_open(str(file), "r", encoding="utf-8") as f:
content = await f.read()
try:
project = orjson.loads(content)
starter_projects.append((file, project))
@@ -379,27 +379,27 @@ def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]:
if attempt >= retries:
msg = f"Error loading starter project {file}: {e}"
raise ValueError(msg) from e
time.sleep(delay) # Wait before retrying
await asyncio.sleep(delay) # Wait before retrying
return starter_projects
def copy_profile_pictures() -> None:
async def copy_profile_pictures() -> None:
config_dir = get_storage_service().settings_service.settings.config_dir
if config_dir is None:
msg = "Config dir is not set in the settings"
raise ValueError(msg)
origin = Path(__file__).parent / "profile_pictures"
target = Path(config_dir) / "profile_pictures"
origin = anyio.Path(__file__).parent / "profile_pictures"
target = anyio.Path(config_dir) / "profile_pictures"
if not origin.exists():
if not await origin.exists():
msg = f"The source folder '{origin}' does not exist."
raise ValueError(msg)
if not target.exists():
target.mkdir(parents=True)
if not await target.exists():
await target.mkdir(parents=True)
try:
shutil.copytree(origin, target, dirs_exist_ok=True)
await asyncio.to_thread(shutil.copytree, str(origin), str(target), dirs_exist_ok=True)
logger.debug(f"Folder copied from '{origin}' to '{target}'")
except Exception: # noqa: BLE001
@@ -434,9 +434,10 @@ def get_project_data(project):
)
def update_project_file(project_path: Path, project: dict, updated_project_data) -> None:
async def update_project_file(project_path: anyio.Path, project: dict, updated_project_data) -> None:
project["data"] = updated_project_data
project_path.write_text(orjson.dumps(project, option=ORJSON_OPTIONS).decode(), encoding="utf-8")
async with async_open(str(project_path), "w", encoding="utf-8") as f:
await f.write(orjson.dumps(project, option=ORJSON_OPTIONS).decode())
logger.info(f"Updated starter project {project['name']} file")
@@ -490,31 +491,34 @@ def create_new_project(
session.add(db_flow)
def get_all_flows_similar_to_project(session, folder_id):
return session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
async def get_all_flows_similar_to_project(session, folder_id):
stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.id == folder_id)
return (await session.exec(stmt)).first().flows
def delete_start_projects(session, folder_id) -> None:
flows = session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
async def delete_start_projects(session, folder_id) -> None:
flows = await get_all_flows_similar_to_project(session, folder_id)
for flow in flows:
session.delete(flow)
session.commit()
await session.delete(flow)
await session.commit()
def folder_exists(session, folder_name):
folder = session.exec(select(Folder).where(Folder.name == folder_name)).first()
async def folder_exists(session, folder_name):
stmt = select(Folder).where(Folder.name == folder_name)
folder = (await session.exec(stmt)).first()
return folder is not None
def create_starter_folder(session):
if not folder_exists(session, STARTER_FOLDER_NAME):
async def create_starter_folder(session):
if not await folder_exists(session, STARTER_FOLDER_NAME):
new_folder = FolderCreate(name=STARTER_FOLDER_NAME, description=STARTER_FOLDER_DESCRIPTION)
db_folder = Folder.model_validate(new_folder, from_attributes=True)
session.add(db_folder)
session.commit()
session.refresh(db_folder)
await session.commit()
await session.refresh(db_folder)
return db_folder
return session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first()
stmt = select(Folder).where(Folder.name == STARTER_FOLDER_NAME)
return (await session.exec(stmt)).first()
def _is_valid_uuid(val):
@@ -606,12 +610,12 @@ async def find_existing_flow(session, flow_id, flow_endpoint_name):
return None
def create_or_update_starter_projects(all_types_dict: dict) -> None:
with session_scope() as session:
new_folder = create_starter_folder(session)
starter_projects = load_starter_projects()
delete_start_projects(session, new_folder.id)
copy_profile_pictures()
async def create_or_update_starter_projects(all_types_dict: dict) -> None:
async with async_session_scope() as session:
new_folder = await create_starter_folder(session)
starter_projects = await load_starter_projects()
await delete_start_projects(session, new_folder.id)
await copy_profile_pictures()
for project_path, project in starter_projects:
(
project_name,
@@ -632,10 +636,10 @@ def create_or_update_starter_projects(all_types_dict: dict) -> None:
project_data = updated_project_data
# We also need to update the project data in the file
update_project_file(project_path, project, updated_project_data)
await update_project_file(project_path, project, updated_project_data)
if project_name and project_data:
for existing_project in get_all_flows_similar_to_project(session, new_folder.id):
session.delete(existing_project)
for existing_project in await get_all_flows_similar_to_project(session, new_folder.id):
await session.delete(existing_project)
create_new_project(
session=session,

View file

@@ -105,7 +105,7 @@ def get_lifespan(*, fix_migration=False, version=None):
setup_llm_caching()
await initialize_super_user_if_needed()
all_types_dict = await get_and_cache_all_types_dict(get_settings_service())
await asyncio.to_thread(create_or_update_starter_projects, all_types_dict)
await create_or_update_starter_projects(all_types_dict)
telemetry_service.start()
await load_flows_from_directory()
yield

View file

@@ -5,7 +5,7 @@ from fastapi import status
from httpx import AsyncClient
from langflow.graph.schema import RunOutputs
from langflow.initial_setup.setup import load_starter_projects
from langflow.load import run_flow_from_json
from langflow.load.load import arun_flow_from_json
@pytest.mark.api_key_required
@@ -78,9 +78,9 @@ async def test_run_with_inputs_and_outputs(client, starter_project, created_api_
@pytest.mark.noclient
@pytest.mark.api_key_required
def test_run_flow_from_json_object():
async def test_run_flow_from_json_object():
"""Test loading a flow from a json file and applying tweaks."""
project = next(project for _, project in load_starter_projects() if "Basic Prompting" in project["name"])
results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
project = next(project for _, project in await load_starter_projects() if "Basic Prompting" in project["name"])
results = await arun_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
assert results is not None
assert all(isinstance(result, RunOutputs) for result in results)

View file

@@ -1,5 +1,3 @@
import asyncio
import pytest
from langflow.services.deps import get_settings_service
@@ -72,7 +70,7 @@ async def test_create_starter_projects():
await initialize_services(fix_migration=False)
settings_service = get_settings_service()
types_dict = await get_and_cache_all_types_dict(settings_service)
await asyncio.to_thread(create_or_update_starter_projects, types_dict)
await create_or_update_starter_projects(types_dict)
assert "test_performance.db" in settings_service.settings.database_url

View file

@@ -256,8 +256,8 @@ def test_update_source_handle():
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
def test_serialize_graph():
starter_projects = load_starter_projects()
async def test_serialize_graph():
starter_projects = await load_starter_projects()
data = starter_projects[0][1]["data"]
graph = Graph.from_payload(data)
assert isinstance(graph, Graph)

View file

@@ -1,4 +1,3 @@
import asyncio
import json
from typing import NamedTuple
from uuid import UUID, uuid4
@@ -605,7 +604,7 @@ async def test_delete_nonexistent_flow(client: AsyncClient, logged_in_headers):
@pytest.mark.usefixtures("active_user")
async def test_read_only_starter_projects(client: AsyncClient, logged_in_headers):
response = await client.get("api/v1/flows/basic_examples/", headers=logged_in_headers)
starter_projects = await asyncio.to_thread(load_starter_projects)
starter_projects = await load_starter_projects()
assert response.status_code == 200
assert len(response.json()) == len(starter_projects)

View file

@@ -1,6 +1,4 @@
import asyncio
from datetime import datetime
from pathlib import Path
import anyio
import pytest
@@ -18,15 +16,15 @@ from sqlalchemy.orm import selectinload
from sqlmodel import select
def test_load_starter_projects():
projects = load_starter_projects()
async def test_load_starter_projects():
projects = await load_starter_projects()
assert isinstance(projects, list)
assert all(isinstance(project[1], dict) for project in projects)
assert all(isinstance(project[0], Path) for project in projects)
assert all(isinstance(project[0], anyio.Path) for project in projects)
def test_get_project_data():
projects = load_starter_projects()
async def test_get_project_data():
projects = await load_starter_projects()
for _, project in projects:
(
project_name,
@@ -56,7 +54,7 @@ def test_get_project_data():
async def test_create_or_update_starter_projects():
async with async_session_scope() as session:
# Get the number of projects returned by load_starter_projects
num_projects = len(await asyncio.to_thread(load_starter_projects))
num_projects = len(await load_starter_projects())
# Get the number of projects in the database
stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.name == STARTER_FOLDER_NAME)

View file

@@ -22,7 +22,7 @@ from langflow.load import load_flow_from_json
async def test_load_flow_from_json_object():
"""Test loading a flow from a json file and applying tweaks."""
result = await asyncio.to_thread(load_starter_projects)
result = await load_starter_projects()
project = result[0][1]
loaded = await asyncio.to_thread(load_flow_from_json, project)
assert loaded is not None