fix lints and also the bug in the get_current_user

This commit is contained in:
Maryam Abdoli 2023-11-03 13:01:28 -04:00
commit ef949838e2
2 changed files with 11 additions and 13 deletions

View file

@ -2,7 +2,6 @@ from fastapi import (
APIRouter,
Depends,
HTTPException,
Query,
WebSocket,
WebSocketException,
status,
@ -11,8 +10,9 @@ from fastapi.responses import StreamingResponse
from langflow.api.utils import build_input_keys_response
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
from langflow.services.database.models.user.user import User
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_current_active_user, get_current_user
from langflow.services.auth.utils import get_current_active_user
from langflow.services.cache.utils import update_build_status
from loguru import logger
from langflow.services.getters import get_chat_service, get_session, get_cache_service
@ -28,14 +28,13 @@ router = APIRouter(tags=["Chat"])
async def chat(
client_id: str,
websocket: WebSocket,
token: str = Query(...),
db: Session = Depends(get_session),
chat_service: "ChatService" = Depends(get_chat_service),
user: User = Depends(get_current_active_user),
):
"""Websocket endpoint for chat."""
try:
await websocket.accept()
user = await get_current_user(token, db)
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"

View file

@ -15,7 +15,7 @@ from langflow.services.database.models.user.crud import (
from langflow.services.getters import get_session, get_settings_service
from sqlmodel import Session
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login")
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
API_KEY_NAME = "x-api-key"
@ -74,23 +74,22 @@ async def get_current_user(
header_param: str = Security(api_key_header),
db: Session = Depends(get_session),
) -> User:
try:
if token:
return await get_current_user_by_jwt(token, db)
except HTTPException as exc:
else:
if not query_param and not header_param:
raise exc
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="An API key must be passed as query or header",
)
user = await api_key_security(query_param, header_param, db)
if user:
return user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Internal server error: {exc}",
)
async def get_current_user_by_jwt(