modify get_current_user to accept api_key as authentication method (#1108)
This commit is contained in:
commit
650398f2f4
2 changed files with 30 additions and 3 deletions
|
|
@ -12,7 +12,10 @@ from langflow.api.utils import build_input_keys_response
|
|||
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_current_active_user, get_current_user
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_user,
|
||||
get_current_user_by_jwt,
|
||||
)
|
||||
from langflow.services.cache.utils import update_build_status
|
||||
from loguru import logger
|
||||
from langflow.services.getters import get_chat_service, get_session, get_cache_service
|
||||
|
|
@ -34,8 +37,8 @@ async def chat(
|
|||
):
|
||||
"""Websocket endpoint for chat."""
|
||||
try:
|
||||
user = await get_current_user_by_jwt(token, db)
|
||||
await websocket.accept()
|
||||
user = await get_current_user(token, db)
|
||||
if not user:
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from langflow.services.database.models.user.crud import (
|
|||
from langflow.services.getters import get_session, get_settings_service
|
||||
from sqlmodel import Session
|
||||
|
||||
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login")
|
||||
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
|
||||
|
||||
API_KEY_NAME = "x-api-key"
|
||||
|
||||
|
|
@ -69,6 +69,30 @@ async def api_key_security(
|
|||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Security(oauth2_login),
|
||||
query_param: str = Security(api_key_query),
|
||||
header_param: str = Security(api_key_header),
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
if token:
|
||||
return await get_current_user_by_jwt(token, db)
|
||||
else:
|
||||
if not query_param and not header_param:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="An API key must be passed as query or header",
|
||||
)
|
||||
user = await api_key_security(query_param, header_param, db)
|
||||
if user:
|
||||
return user
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user_by_jwt(
|
||||
token: Annotated[str, Depends(oauth2_login)],
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue