add decorator to cancel disconnected requests
https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py
This commit is contained in:
parent
34333e1d50
commit
90f456b2a9
1 changed files with 63 additions and 2 deletions
|
|
@ -1,9 +1,11 @@
|
|||
import uuid
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Any, Awaitable, Callable
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
from platformdirs import user_cache_dir
|
||||
from sqlmodel import Session
|
||||
|
||||
|
|
@ -327,3 +329,62 @@ def parse_exception(exc):
|
|||
return exc.body["message"]
|
||||
return str(exc)
|
||||
return str(exc)
|
||||
|
||||
|
||||
async def disconnect_poller(request: Request, result: Any):
|
||||
"""
|
||||
Poll for a disconnect.
|
||||
If the request disconnects, stop polling and return.
|
||||
"""
|
||||
try:
|
||||
while not await request.is_disconnected():
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
print("Request disconnected")
|
||||
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
print("Stopping polling loop")
|
||||
|
||||
|
||||
def cancel_on_disconnect(handler: Callable[[Request], Awaitable[Any]]):
|
||||
"""
|
||||
Decorator that will check if the client disconnects,
|
||||
and cancel the task if required.
|
||||
"""
|
||||
|
||||
@wraps(handler)
|
||||
async def cancel_on_disconnect_decorator(request: Request, *args, **kwargs):
|
||||
sentinel = object()
|
||||
|
||||
# Create two tasks, one to poll the request and check if the
|
||||
# client disconnected, and another which is the request handler
|
||||
poller_task = asyncio.ensure_future(disconnect_poller(request, sentinel))
|
||||
handler_task = asyncio.ensure_future(handler(request, *args, **kwargs))
|
||||
|
||||
done, pending = await asyncio.wait([poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
# Cancel any outstanding tasks
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
print(f"{t} was cancelled")
|
||||
except Exception as exc:
|
||||
print(f"{t} raised {exc} when being cancelled")
|
||||
|
||||
# Return the result if the handler finished first
|
||||
if handler_task in done:
|
||||
return await handler_task
|
||||
|
||||
# Otherwise, raise an exception
|
||||
# This is not exactly needed, but it will prevent
|
||||
# validation errors if your request handler is supposed
|
||||
# to return something.
|
||||
print("Raising an HTTP error because I was disconnected!!")
|
||||
|
||||
raise HTTPException(503)
|
||||
|
||||
return cancel_on_disconnect_decorator
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue