fix: Fix async usage in app startup (#4285)

Fix async usage in app startup
This commit is contained in:
Christophe Bornet 2024-10-27 16:00:28 +01:00 committed by GitHub
commit c8bdcf36b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 24 additions and 26 deletions

View file

@ -3,7 +3,6 @@ import json
import shutil
import time
from collections import defaultdict
from collections.abc import Awaitable
from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path
@ -600,12 +599,7 @@ def find_existing_flow(session, flow_id, flow_endpoint_name):
return None
async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]) -> None:
try:
all_types_dict = await get_all_components_coro
except Exception:
logger.exception("Error loading components")
raise
def create_or_update_starter_projects(all_types_dict: dict) -> None:
with session_scope() as session:
new_folder = create_starter_folder(session)
starter_projects = load_starter_projects()

View file

@ -8,7 +8,6 @@ from http import HTTPStatus
from pathlib import Path
from urllib.parse import urlencode
import nest_asyncio
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
@ -87,28 +86,25 @@ class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware):
return response
telemetry_service_tasks = set()
def get_lifespan(*, fix_migration=False, version=None):
def _initialize():
initialize_services(fix_migration=fix_migration)
setup_llm_caching()
initialize_super_user_if_needed()
@asynccontextmanager
async def lifespan(_app: FastAPI):
nest_asyncio.apply()
# Startup message
if version:
rprint(f"[bold green]Starting Langflow v{version}...[/bold green]")
else:
rprint("[bold green]Starting Langflow...[/bold green]")
try:
initialize_services(fix_migration=fix_migration)
setup_llm_caching()
initialize_super_user_if_needed()
task = asyncio.create_task(get_and_cache_all_types_dict(get_settings_service()))
await create_or_update_starter_projects(task)
telemetry_service_task = asyncio.create_task(get_telemetry_service().start())
telemetry_service_tasks.add(telemetry_service_task)
telemetry_service_task.add_done_callback(telemetry_service_tasks.discard)
load_flows_from_directory()
await asyncio.to_thread(_initialize)
all_types_dict = await get_and_cache_all_types_dict(get_settings_service())
await asyncio.to_thread(create_or_update_starter_projects, all_types_dict)
get_telemetry_service().start()
await asyncio.to_thread(load_flows_from_directory)
yield
except Exception as exc:
if "langflow migration --fix" not in str(exc):

View file

@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import contextlib
import os
import platform
from datetime import datetime, timezone
@ -112,7 +111,7 @@ class TelemetryService(Service):
async def log_package_component(self, payload: ComponentPayload) -> None:
await self._queue_event((self.send_telemetry_data, payload, "component"))
async def start(self) -> None:
def start(self) -> None:
if self.running or self.do_not_track:
return
try:
@ -131,6 +130,15 @@ class TelemetryService(Service):
except Exception: # noqa: BLE001
logger.exception("Error flushing logs")
async def _cancel_task(self, task: asyncio.Task, cancel_msg: str) -> None:
task.cancel(cancel_msg)
try:
await task
except asyncio.CancelledError:
current_task = asyncio.current_task()
if current_task and current_task.cancelling() > 0:
raise
async def stop(self) -> None:
if self.do_not_track or self._stopping:
return
@ -140,9 +148,9 @@ class TelemetryService(Service):
await self.flush()
self.running = False
if self.worker_task:
self.worker_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.worker_task
await self._cancel_task(self.worker_task, "Cancel telemetry worker task")
if self.log_package_version_task:
await self._cancel_task(self.log_package_version_task, "Cancel telemetry log package version task")
await self.client.aclose()
except Exception: # noqa: BLE001
logger.exception("Error stopping tracing service")