modify get_current_user to accept api_key as authentication method (#1108)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-04 10:29:31 -03:00 committed by GitHub
commit 650398f2f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 3 deletions

View file

@ -12,7 +12,10 @@ from langflow.api.utils import build_input_keys_response
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_current_active_user, get_current_user
from langflow.services.auth.utils import (
get_current_active_user,
get_current_user_by_jwt,
)
from langflow.services.cache.utils import update_build_status
from loguru import logger
from langflow.services.getters import get_chat_service, get_session, get_cache_service
@ -34,8 +37,8 @@ async def chat(
):
"""Websocket endpoint for chat."""
try:
user = await get_current_user_by_jwt(token, db)
await websocket.accept()
user = await get_current_user(token, db)
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"

View file

@ -15,7 +15,7 @@ from langflow.services.database.models.user.crud import (
from langflow.services.getters import get_session, get_settings_service
from sqlmodel import Session
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login")
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
API_KEY_NAME = "x-api-key"
@ -69,6 +69,30 @@ async def api_key_security(
async def get_current_user(
token: str = Security(oauth2_login),
query_param: str = Security(api_key_query),
header_param: str = Security(api_key_header),
db: Session = Depends(get_session),
) -> User:
if token:
return await get_current_user_by_jwt(token, db)
else:
if not query_param and not header_param:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="An API key must be passed as query or header",
)
user = await api_key_security(query_param, header_param, db)
if user:
return user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)
async def get_current_user_by_jwt(
token: Annotated[str, Depends(oauth2_login)],
db: Session = Depends(get_session),
) -> User: