From 90f456b2a951c755faaa52960da8f99faee3e63c Mon Sep 17 00:00:00 2001 From: italojohnny Date: Thu, 20 Jun 2024 11:07:22 -0300 Subject: [PATCH] add decorator to cancel disconnected requests https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py --- src/backend/base/langflow/api/utils.py | 65 +++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/src/backend/base/langflow/api/utils.py b/src/backend/base/langflow/api/utils.py index 99d4d3681..fe412eb68 100644 --- a/src/backend/base/langflow/api/utils.py +++ b/src/backend/base/langflow/api/utils.py @@ -1,9 +1,11 @@ import uuid +import asyncio +from functools import wraps import warnings from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Any, Awaitable, Callable -from fastapi import HTTPException +from fastapi import HTTPException, Request from platformdirs import user_cache_dir from sqlmodel import Session @@ -327,3 +329,62 @@ def parse_exception(exc): return exc.body["message"] return str(exc) return str(exc) + + +async def disconnect_poller(request: Request, result: Any): + """ + Poll for a disconnect. + If the request disconnects, stop polling and return. + """ + try: + while not await request.is_disconnected(): + await asyncio.sleep(0.01) + + print("Request disconnected") + + return result + except asyncio.CancelledError: + print("Stopping polling loop") + + +def cancel_on_disconnect(handler: Callable[[Request], Awaitable[Any]]): + """ + Decorator that will check if the client disconnects, + and cancel the task if required. + """ + + @wraps(handler) + async def cancel_on_disconnect_decorator(request: Request, *args, **kwargs): + sentinel = object() + + # Create two tasks, one to poll the request and check if the + # client disconnected, and another which is the request handler + poller_task = asyncio.ensure_future(disconnect_poller(request, sentinel)) + handler_task = asyncio.ensure_future(handler(request, *args, **kwargs)) + + done, pending = await asyncio.wait([poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED) + + # Cancel any outstanding tasks + for t in pending: + t.cancel() + + try: + await t + except asyncio.CancelledError: + print(f"{t} was cancelled") + except Exception as exc: + print(f"{t} raised {exc} when being cancelled") + + # Return the result if the handler finished first + if handler_task in done: + return await handler_task + + # Otherwise, raise an exception + # This is not exactly needed, but it will prevent + # validation errors if your request handler is supposed + # to return something. + print("Raising an HTTP error because I was disconnected!!") + + raise HTTPException(503) + + return cancel_on_disconnect_decorator