ref: Make create_or_update_starter_projects async (#5165)
Make create_or_update_starter_projects async
This commit is contained in:
parent
63bdcb9d03
commit
fe6ec1690b
8 changed files with 59 additions and 60 deletions
|
|
@ -1,11 +1,10 @@
|
|||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
|
|
@ -14,6 +13,7 @@ from aiofile import async_open
|
|||
from emoji import demojize, purely_emoji
|
||||
from loguru import logger
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.base.constants import FIELD_FORMAT_ATTRIBUTES, NODE_FORMAT_ATTRIBUTES, ORJSON_OPTIONS
|
||||
|
|
@ -31,7 +31,6 @@ from langflow.services.deps import (
|
|||
get_settings_service,
|
||||
get_storage_service,
|
||||
get_variable_service,
|
||||
session_scope,
|
||||
)
|
||||
from langflow.template.field.prompt import DEFAULT_PROMPT_INTUT_TYPES
|
||||
from langflow.utils.util import escape_json_dump
|
||||
|
|
@ -362,13 +361,14 @@ def log_node_changes(node_changes_log) -> None:
|
|||
logger.debug("\n".join(formatted_messages))
|
||||
|
||||
|
||||
def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]:
|
||||
async def load_starter_projects(retries=3, delay=1) -> list[tuple[anyio.Path, dict]]:
|
||||
starter_projects = []
|
||||
folder = Path(__file__).parent / "starter_projects"
|
||||
for file in folder.glob("*.json"):
|
||||
folder = anyio.Path(__file__).parent / "starter_projects"
|
||||
async for file in folder.glob("*.json"):
|
||||
attempt = 0
|
||||
while attempt < retries:
|
||||
content = file.read_text(encoding="utf-8")
|
||||
async with async_open(str(file), "r", encoding="utf-8") as f:
|
||||
content = await f.read()
|
||||
try:
|
||||
project = orjson.loads(content)
|
||||
starter_projects.append((file, project))
|
||||
|
|
@ -379,27 +379,27 @@ def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]:
|
|||
if attempt >= retries:
|
||||
msg = f"Error loading starter project {file}: {e}"
|
||||
raise ValueError(msg) from e
|
||||
time.sleep(delay) # Wait before retrying
|
||||
await asyncio.sleep(delay) # Wait before retrying
|
||||
return starter_projects
|
||||
|
||||
|
||||
def copy_profile_pictures() -> None:
|
||||
async def copy_profile_pictures() -> None:
|
||||
config_dir = get_storage_service().settings_service.settings.config_dir
|
||||
if config_dir is None:
|
||||
msg = "Config dir is not set in the settings"
|
||||
raise ValueError(msg)
|
||||
origin = Path(__file__).parent / "profile_pictures"
|
||||
target = Path(config_dir) / "profile_pictures"
|
||||
origin = anyio.Path(__file__).parent / "profile_pictures"
|
||||
target = anyio.Path(config_dir) / "profile_pictures"
|
||||
|
||||
if not origin.exists():
|
||||
if not await origin.exists():
|
||||
msg = f"The source folder '{origin}' does not exist."
|
||||
raise ValueError(msg)
|
||||
|
||||
if not target.exists():
|
||||
target.mkdir(parents=True)
|
||||
if not await target.exists():
|
||||
await target.mkdir(parents=True)
|
||||
|
||||
try:
|
||||
shutil.copytree(origin, target, dirs_exist_ok=True)
|
||||
await asyncio.to_thread(shutil.copytree, str(origin), str(target), dirs_exist_ok=True)
|
||||
logger.debug(f"Folder copied from '{origin}' to '{target}'")
|
||||
|
||||
except Exception: # noqa: BLE001
|
||||
|
|
@ -434,9 +434,10 @@ def get_project_data(project):
|
|||
)
|
||||
|
||||
|
||||
def update_project_file(project_path: Path, project: dict, updated_project_data) -> None:
|
||||
async def update_project_file(project_path: anyio.Path, project: dict, updated_project_data) -> None:
|
||||
project["data"] = updated_project_data
|
||||
project_path.write_text(orjson.dumps(project, option=ORJSON_OPTIONS).decode(), encoding="utf-8")
|
||||
async with async_open(str(project_path), "w", encoding="utf-8") as f:
|
||||
await f.write(orjson.dumps(project, option=ORJSON_OPTIONS).decode())
|
||||
logger.info(f"Updated starter project {project['name']} file")
|
||||
|
||||
|
||||
|
|
@ -490,31 +491,34 @@ def create_new_project(
|
|||
session.add(db_flow)
|
||||
|
||||
|
||||
def get_all_flows_similar_to_project(session, folder_id):
|
||||
return session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
|
||||
async def get_all_flows_similar_to_project(session, folder_id):
|
||||
stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.id == folder_id)
|
||||
return (await session.exec(stmt)).first().flows
|
||||
|
||||
|
||||
def delete_start_projects(session, folder_id) -> None:
|
||||
flows = session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
|
||||
async def delete_start_projects(session, folder_id) -> None:
|
||||
flows = await get_all_flows_similar_to_project(session, folder_id)
|
||||
for flow in flows:
|
||||
session.delete(flow)
|
||||
session.commit()
|
||||
await session.delete(flow)
|
||||
await session.commit()
|
||||
|
||||
|
||||
def folder_exists(session, folder_name):
|
||||
folder = session.exec(select(Folder).where(Folder.name == folder_name)).first()
|
||||
async def folder_exists(session, folder_name):
|
||||
stmt = select(Folder).where(Folder.name == folder_name)
|
||||
folder = (await session.exec(stmt)).first()
|
||||
return folder is not None
|
||||
|
||||
|
||||
def create_starter_folder(session):
|
||||
if not folder_exists(session, STARTER_FOLDER_NAME):
|
||||
async def create_starter_folder(session):
|
||||
if not await folder_exists(session, STARTER_FOLDER_NAME):
|
||||
new_folder = FolderCreate(name=STARTER_FOLDER_NAME, description=STARTER_FOLDER_DESCRIPTION)
|
||||
db_folder = Folder.model_validate(new_folder, from_attributes=True)
|
||||
session.add(db_folder)
|
||||
session.commit()
|
||||
session.refresh(db_folder)
|
||||
await session.commit()
|
||||
await session.refresh(db_folder)
|
||||
return db_folder
|
||||
return session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first()
|
||||
stmt = select(Folder).where(Folder.name == STARTER_FOLDER_NAME)
|
||||
return (await session.exec(stmt)).first()
|
||||
|
||||
|
||||
def _is_valid_uuid(val):
|
||||
|
|
@ -606,12 +610,12 @@ async def find_existing_flow(session, flow_id, flow_endpoint_name):
|
|||
return None
|
||||
|
||||
|
||||
def create_or_update_starter_projects(all_types_dict: dict) -> None:
|
||||
with session_scope() as session:
|
||||
new_folder = create_starter_folder(session)
|
||||
starter_projects = load_starter_projects()
|
||||
delete_start_projects(session, new_folder.id)
|
||||
copy_profile_pictures()
|
||||
async def create_or_update_starter_projects(all_types_dict: dict) -> None:
|
||||
async with async_session_scope() as session:
|
||||
new_folder = await create_starter_folder(session)
|
||||
starter_projects = await load_starter_projects()
|
||||
await delete_start_projects(session, new_folder.id)
|
||||
await copy_profile_pictures()
|
||||
for project_path, project in starter_projects:
|
||||
(
|
||||
project_name,
|
||||
|
|
@ -632,10 +636,10 @@ def create_or_update_starter_projects(all_types_dict: dict) -> None:
|
|||
project_data = updated_project_data
|
||||
# We also need to update the project data in the file
|
||||
|
||||
update_project_file(project_path, project, updated_project_data)
|
||||
await update_project_file(project_path, project, updated_project_data)
|
||||
if project_name and project_data:
|
||||
for existing_project in get_all_flows_similar_to_project(session, new_folder.id):
|
||||
session.delete(existing_project)
|
||||
for existing_project in await get_all_flows_similar_to_project(session, new_folder.id):
|
||||
await session.delete(existing_project)
|
||||
|
||||
create_new_project(
|
||||
session=session,
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ def get_lifespan(*, fix_migration=False, version=None):
|
|||
setup_llm_caching()
|
||||
await initialize_super_user_if_needed()
|
||||
all_types_dict = await get_and_cache_all_types_dict(get_settings_service())
|
||||
await asyncio.to_thread(create_or_update_starter_projects, all_types_dict)
|
||||
await create_or_update_starter_projects(all_types_dict)
|
||||
telemetry_service.start()
|
||||
await load_flows_from_directory()
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi import status
|
|||
from httpx import AsyncClient
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.initial_setup.setup import load_starter_projects
|
||||
from langflow.load import run_flow_from_json
|
||||
from langflow.load.load import arun_flow_from_json
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
@ -78,9 +78,9 @@ async def test_run_with_inputs_and_outputs(client, starter_project, created_api_
|
|||
|
||||
@pytest.mark.noclient
|
||||
@pytest.mark.api_key_required
|
||||
def test_run_flow_from_json_object():
|
||||
async def test_run_flow_from_json_object():
|
||||
"""Test loading a flow from a json file and applying tweaks."""
|
||||
project = next(project for _, project in load_starter_projects() if "Basic Prompting" in project["name"])
|
||||
results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
|
||||
project = next(project for _, project in await load_starter_projects() if "Basic Prompting" in project["name"])
|
||||
results = await arun_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
|
||||
assert results is not None
|
||||
assert all(isinstance(result, RunOutputs) for result in results)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
|
@ -72,7 +70,7 @@ async def test_create_starter_projects():
|
|||
await initialize_services(fix_migration=False)
|
||||
settings_service = get_settings_service()
|
||||
types_dict = await get_and_cache_all_types_dict(settings_service)
|
||||
await asyncio.to_thread(create_or_update_starter_projects, types_dict)
|
||||
await create_or_update_starter_projects(types_dict)
|
||||
assert "test_performance.db" in settings_service.settings.database_url
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -256,8 +256,8 @@ def test_update_source_handle():
|
|||
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
|
||||
|
||||
|
||||
def test_serialize_graph():
|
||||
starter_projects = load_starter_projects()
|
||||
async def test_serialize_graph():
|
||||
starter_projects = await load_starter_projects()
|
||||
data = starter_projects[0][1]["data"]
|
||||
graph = Graph.from_payload(data)
|
||||
assert isinstance(graph, Graph)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import NamedTuple
|
||||
from uuid import UUID, uuid4
|
||||
|
|
@ -605,7 +604,7 @@ async def test_delete_nonexistent_flow(client: AsyncClient, logged_in_headers):
|
|||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_read_only_starter_projects(client: AsyncClient, logged_in_headers):
|
||||
response = await client.get("api/v1/flows/basic_examples/", headers=logged_in_headers)
|
||||
starter_projects = await asyncio.to_thread(load_starter_projects)
|
||||
starter_projects = await load_starter_projects()
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == len(starter_projects)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
|
@ -18,15 +16,15 @@ from sqlalchemy.orm import selectinload
|
|||
from sqlmodel import select
|
||||
|
||||
|
||||
def test_load_starter_projects():
|
||||
projects = load_starter_projects()
|
||||
async def test_load_starter_projects():
|
||||
projects = await load_starter_projects()
|
||||
assert isinstance(projects, list)
|
||||
assert all(isinstance(project[1], dict) for project in projects)
|
||||
assert all(isinstance(project[0], Path) for project in projects)
|
||||
assert all(isinstance(project[0], anyio.Path) for project in projects)
|
||||
|
||||
|
||||
def test_get_project_data():
|
||||
projects = load_starter_projects()
|
||||
async def test_get_project_data():
|
||||
projects = await load_starter_projects()
|
||||
for _, project in projects:
|
||||
(
|
||||
project_name,
|
||||
|
|
@ -56,7 +54,7 @@ def test_get_project_data():
|
|||
async def test_create_or_update_starter_projects():
|
||||
async with async_session_scope() as session:
|
||||
# Get the number of projects returned by load_starter_projects
|
||||
num_projects = len(await asyncio.to_thread(load_starter_projects))
|
||||
num_projects = len(await load_starter_projects())
|
||||
|
||||
# Get the number of projects in the database
|
||||
stmt = select(Folder).options(selectinload(Folder.flows)).where(Folder.name == STARTER_FOLDER_NAME)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from langflow.load import load_flow_from_json
|
|||
|
||||
async def test_load_flow_from_json_object():
|
||||
"""Test loading a flow from a json file and applying tweaks."""
|
||||
result = await asyncio.to_thread(load_starter_projects)
|
||||
result = await load_starter_projects()
|
||||
project = result[0][1]
|
||||
loaded = await asyncio.to_thread(load_flow_from_json, project)
|
||||
assert loaded is not None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue