From c8bdcf36b07d7ae07ff3768fd27f12b8ee164ced Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sun, 27 Oct 2024 16:00:28 +0100 Subject: [PATCH] fix: Fix async usage in app startup (#4285) Fix async usage in app startup --- .../base/langflow/initial_setup/setup.py | 8 +------ src/backend/base/langflow/main.py | 24 ++++++++----------- .../langflow/services/telemetry/service.py | 18 ++++++++++---- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index 14deb499f..5afaf44b1 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -3,7 +3,6 @@ import json import shutil import time from collections import defaultdict -from collections.abc import Awaitable from copy import deepcopy from datetime import datetime, timezone from pathlib import Path @@ -600,12 +599,7 @@ def find_existing_flow(session, flow_id, flow_endpoint_name): return None -async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]) -> None: - try: - all_types_dict = await get_all_components_coro - except Exception: - logger.exception("Error loading components") - raise +def create_or_update_starter_projects(all_types_dict: dict) -> None: with session_scope() as session: new_folder = create_starter_folder(session) starter_projects = load_starter_projects() diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index 507ac180d..d3f5ae981 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -8,7 +8,6 @@ from http import HTTPStatus from pathlib import Path from urllib.parse import urlencode -import nest_asyncio from fastapi import FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse @@ -87,28 +86,25 @@ class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware): return response -telemetry_service_tasks = set() - - def get_lifespan(*, fix_migration=False, version=None): + def _initialize(): + initialize_services(fix_migration=fix_migration) + setup_llm_caching() + initialize_super_user_if_needed() + @asynccontextmanager async def lifespan(_app: FastAPI): - nest_asyncio.apply() # Startup message if version: rprint(f"[bold green]Starting Langflow v{version}...[/bold green]") else: rprint("[bold green]Starting Langflow...[/bold green]") try: - initialize_services(fix_migration=fix_migration) - setup_llm_caching() - initialize_super_user_if_needed() - task = asyncio.create_task(get_and_cache_all_types_dict(get_settings_service())) - await create_or_update_starter_projects(task) - telemetry_service_task = asyncio.create_task(get_telemetry_service().start()) - telemetry_service_tasks.add(telemetry_service_task) - telemetry_service_task.add_done_callback(telemetry_service_tasks.discard) - load_flows_from_directory() + await asyncio.to_thread(_initialize) + all_types_dict = await get_and_cache_all_types_dict(get_settings_service()) + await asyncio.to_thread(create_or_update_starter_projects, all_types_dict) + get_telemetry_service().start() + await asyncio.to_thread(load_flows_from_directory) yield except Exception as exc: if "langflow migration --fix" not in str(exc): diff --git a/src/backend/base/langflow/services/telemetry/service.py b/src/backend/base/langflow/services/telemetry/service.py index cb396981f..0c43cee96 100644 --- a/src/backend/base/langflow/services/telemetry/service.py +++ b/src/backend/base/langflow/services/telemetry/service.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import os import platform from datetime import datetime, timezone @@ -112,7 +111,7 @@ class TelemetryService(Service): async def log_package_component(self, payload: ComponentPayload) -> None: await self._queue_event((self.send_telemetry_data, payload, "component")) - async def start(self) -> None: + def start(self) -> None: if self.running or self.do_not_track: return try: @@ -131,6 +130,15 @@ class TelemetryService(Service): except Exception: # noqa: BLE001 logger.exception("Error flushing logs") + async def _cancel_task(self, task: asyncio.Task, cancel_msg: str) -> None: + task.cancel(cancel_msg) + try: + await task + except asyncio.CancelledError: + current_task = asyncio.current_task() + if current_task and current_task.cancelling() > 0: + raise + async def stop(self) -> None: if self.do_not_track or self._stopping: return @@ -140,9 +148,9 @@ class TelemetryService(Service): await self.flush() self.running = False if self.worker_task: - self.worker_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self.worker_task + await self._cancel_task(self.worker_task, "Cancel telemetry worker task") + if self.log_package_version_task: + await self._cancel_task(self.log_package_version_task, "Cancel telemetry log package version task") await self.client.aclose() except Exception: # noqa: BLE001 logger.exception("Error stopping tracing service")