diff --git a/src/backend/base/langflow/api/utils.py b/src/backend/base/langflow/api/utils.py index fe412eb68..9e5d428d9 100644 --- a/src/backend/base/langflow/api/utils.py +++ b/src/backend/base/langflow/api/utils.py @@ -1,9 +1,8 @@ import uuid import asyncio -from functools import wraps import warnings from pathlib import Path -from typing import TYPE_CHECKING, Optional, Any, Awaitable, Callable +from typing import TYPE_CHECKING, Optional from fastapi import HTTPException, Request from platformdirs import user_cache_dir @@ -328,63 +327,9 @@ def parse_exception(exc): if hasattr(exc, "body"): return exc.body["message"] return str(exc) - return str(exc) -async def disconnect_poller(request: Request, result: Any): - """ - Poll for a disconnect. - If the request disconnects, stop polling and return. - """ - try: - while not await request.is_disconnected(): - await asyncio.sleep(0.01) - - print("Request disconnected") - - return result - except asyncio.CancelledError: - print("Stopping polling loop") - - -def cancel_on_disconnect(handler: Callable[[Request], Awaitable[Any]]): - """ - Decorator that will check if the client disconnects, - and cancel the task if required. - """ - - @wraps(handler) - async def cancel_on_disconnect_decorator(request: Request, *args, **kwargs): - sentinel = object() - - # Create two tasks, one to poll the request and check if the - # client disconnected, and another which is the request handler - poller_task = asyncio.ensure_future(disconnect_poller(request, sentinel)) - handler_task = asyncio.ensure_future(handler(request, *args, **kwargs)) - - done, pending = await asyncio.wait([poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED) - - # Cancel any outstanding tasks - for t in pending: - t.cancel() - - try: - await t - except asyncio.CancelledError: - print(f"{t} was cancelled") - except Exception as exc: - print(f"{t} raised {exc} when being cancelled") - - # Return the result if the handler finished first - if handler_task in done: - return await handler_task - - # Otherwise, raise an exception - # This is not exactly needed, but it will prevent - # validation errors if your request handler is supposed - # to return something. - print("Raising an HTTP error because I was disconnected!!") - - raise HTTPException(503) - - return cancel_on_disconnect_decorator +async def check_client_disconnection(request: Request): + while not await request.is_disconnected(): + await asyncio.sleep(1) + raise HTTPException(status_code=499, detail="Client disconnected") diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 7b87a03d8..ba8410f4b 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -4,7 +4,7 @@ import uuid from functools import partial from typing import TYPE_CHECKING, Annotated, Optional -from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request from fastapi.responses import StreamingResponse from loguru import logger @@ -15,7 +15,7 @@ from langflow.api.utils import ( format_exception_message, get_top_level_vertices, parse_exception, - cancel_on_disconnect + check_client_disconnection, ) from langflow.api.v1.schemas import ( FlowDataRequest, @@ -57,9 +57,10 @@ async def try_running_celery_task(vertex, user_id): @router.post("/build/{flow_id}/vertices", response_model=VerticesOrderResponse) -@cancel_on_disconnect async def retrieve_vertices_order( flow_id: uuid.UUID, + background_tasks: BackgroundTasks, + request: Request, data: Optional[Annotated[Optional[FlowDataRequest], Body(embed=True)]] = None, stop_component_id: Optional[str] = None, start_component_id: Optional[str] = None, @@ -83,6 +84,7 @@ async def retrieve_vertices_order( Raises: HTTPException: If there is an error checking the build status. """ + background_tasks.add_task(check_client_disconnection, request) try: flow_id_str = str(flow_id) # First, we need to check if the flow_id is in the cache @@ -124,11 +126,11 @@ async def retrieve_vertices_order( @router.post("/build/{flow_id}/vertices/{vertex_id}") -@cancel_on_disconnect async def build_vertex( flow_id: uuid.UUID, vertex_id: str, background_tasks: BackgroundTasks, + request: Request, inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None, files: Optional[list[str]] = None, chat_service: "ChatService" = Depends(get_chat_service), @@ -151,6 +153,8 @@ async def build_vertex( HTTPException: If there is an error building the vertex. """ + background_tasks.add_task(check_client_disconnection, request) + flow_id_str = str(flow_id) next_runnable_vertices = [] @@ -265,10 +269,11 @@ async def build_vertex( @router.get("/build/{flow_id}/{vertex_id}/stream", response_class=StreamingResponse) -@cancel_on_disconnect async def build_vertex_stream( flow_id: uuid.UUID, vertex_id: str, + background_tasks: BackgroundTasks, + request: Request, session_id: Optional[str] = None, chat_service: "ChatService" = Depends(get_chat_service), session_service: "SessionService" = Depends(get_session_service), @@ -298,6 +303,7 @@ async def build_vertex_stream( Raises: HTTPException: If an error occurs while building the vertex. """ + background_tasks.add_task(check_client_disconnection, request) try: flow_id_str = str(flow_id)