Discard old approach and implement a simple solution

This commit is contained in:
italojohnny 2024-06-20 12:45:45 -03:00 committed by Gabriel Luiz Freitas Almeida
commit 2a88176a8e
2 changed files with 16 additions and 65 deletions

View file

@ -1,9 +1,8 @@
import uuid
import asyncio
from functools import wraps
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Any, Awaitable, Callable
from typing import TYPE_CHECKING, Optional
from fastapi import HTTPException, Request
from platformdirs import user_cache_dir
@ -328,63 +327,9 @@ def parse_exception(exc):
if hasattr(exc, "body"):
return exc.body["message"]
return str(exc)
return str(exc)
async def disconnect_poller(request: Request, result: Any):
"""
Poll for a disconnect.
If the request disconnects, stop polling and return.
"""
try:
while not await request.is_disconnected():
await asyncio.sleep(0.01)
print("Request disconnected")
return result
except asyncio.CancelledError:
print("Stopping polling loop")
def cancel_on_disconnect(handler: Callable[[Request], Awaitable[Any]]):
"""
Decorator that will check if the client disconnects,
and cancel the task if required.
"""
@wraps(handler)
async def cancel_on_disconnect_decorator(request: Request, *args, **kwargs):
sentinel = object()
# Create two tasks, one to poll the request and check if the
# client disconnected, and another which is the request handler
poller_task = asyncio.ensure_future(disconnect_poller(request, sentinel))
handler_task = asyncio.ensure_future(handler(request, *args, **kwargs))
done, pending = await asyncio.wait([poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED)
# Cancel any outstanding tasks
for t in pending:
t.cancel()
try:
await t
except asyncio.CancelledError:
print(f"{t} was cancelled")
except Exception as exc:
print(f"{t} raised {exc} when being cancelled")
# Return the result if the handler finished first
if handler_task in done:
return await handler_task
# Otherwise, raise an exception
# This is not exactly needed, but it will prevent
# validation errors if your request handler is supposed
# to return something.
print("Raising an HTTP error because I was disconnected!!")
raise HTTPException(503)
return cancel_on_disconnect_decorator
async def check_client_disconnection(request: Request):
while not await request.is_disconnected():
await asyncio.sleep(1)
raise HTTPException(status_code=499, detail="Client disconnected")

View file

@ -4,7 +4,7 @@ import uuid
from functools import partial
from typing import TYPE_CHECKING, Annotated, Optional
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from loguru import logger
@ -15,7 +15,7 @@ from langflow.api.utils import (
format_exception_message,
get_top_level_vertices,
parse_exception,
cancel_on_disconnect
check_client_disconnection,
)
from langflow.api.v1.schemas import (
FlowDataRequest,
@ -57,9 +57,10 @@ async def try_running_celery_task(vertex, user_id):
@router.post("/build/{flow_id}/vertices", response_model=VerticesOrderResponse)
@cancel_on_disconnect
async def retrieve_vertices_order(
flow_id: uuid.UUID,
background_tasks: BackgroundTasks,
request: Request,
data: Optional[Annotated[Optional[FlowDataRequest], Body(embed=True)]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
@ -83,6 +84,7 @@ async def retrieve_vertices_order(
Raises:
HTTPException: If there is an error checking the build status.
"""
background_tasks.add_task(check_client_disconnection, request)
try:
flow_id_str = str(flow_id)
# First, we need to check if the flow_id is in the cache
@ -124,11 +126,11 @@ async def retrieve_vertices_order(
@router.post("/build/{flow_id}/vertices/{vertex_id}")
@cancel_on_disconnect
async def build_vertex(
flow_id: uuid.UUID,
vertex_id: str,
background_tasks: BackgroundTasks,
request: Request,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
chat_service: "ChatService" = Depends(get_chat_service),
@ -151,6 +153,8 @@ async def build_vertex(
HTTPException: If there is an error building the vertex.
"""
background_tasks.add_task(check_client_disconnection, request)
flow_id_str = str(flow_id)
next_runnable_vertices = []
@ -265,10 +269,11 @@ async def build_vertex(
@router.get("/build/{flow_id}/{vertex_id}/stream", response_class=StreamingResponse)
@cancel_on_disconnect
async def build_vertex_stream(
flow_id: uuid.UUID,
vertex_id: str,
background_tasks: BackgroundTasks,
request: Request,
session_id: Optional[str] = None,
chat_service: "ChatService" = Depends(get_chat_service),
session_service: "SessionService" = Depends(get_session_service),
@ -298,6 +303,7 @@ async def build_vertex_stream(
Raises:
HTTPException: If an error occurs while building the vertex.
"""
background_tasks.add_task(check_client_disconnection, request)
try:
flow_id_str = str(flow_id)