diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index 72728c985..7b042d448 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -32,6 +32,39 @@ from langflow.utils.logger import configure warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) +class RequestCancelledMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + # Let's make a shared queue for the request messages + queue = asyncio.Queue() + + async def message_poller(sentinel, handler_task): + nonlocal queue + while True: + message = await receive() + if message["type"] == "http.disconnect": + handler_task.cancel() + return sentinel # Break the loop + + # Puts the message in the queue + await queue.put(message) + + sentinel = object() + handler_task = asyncio.create_task(self.app(scope, queue.get, send)) + asyncio.create_task(message_poller(sentinel, handler_task)) + + try: + return await handler_task + except asyncio.CancelledError: + logger.debug("Cancelling request due to disconnect") + + class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): try: @@ -96,6 +129,7 @@ def create_app(): allow_headers=["*"], ) app.add_middleware(JavaScriptMIMETypeMiddleware) + app.add_middleware(RequestCancelledMiddleware) @app.middleware("http") async def flatten_query_string_lists(request: Request, call_next):