320 lines
9.9 KiB
Python
320 lines
9.9 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
from fastapi import Depends, HTTPException, Security, status
|
|
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
|
|
from jose import JWTError, jwt
|
|
from typing import Annotated, Coroutine, Optional, Union
|
|
from uuid import UUID
|
|
from langflow.services.database.models.api_key.api_key import ApiKey
|
|
from langflow.services.database.models.api_key.crud import check_key
|
|
from langflow.services.database.models.user.user import User
|
|
from langflow.services.database.models.user.crud import (
|
|
get_user_by_id,
|
|
get_user_by_username,
|
|
update_user_last_login_at,
|
|
)
|
|
from langflow.services.deps import get_session, get_settings_service
|
|
from sqlmodel import Session
|
|
from cryptography.fernet import Fernet
|
|
|
|
# Bearer-token scheme used by get_current_user; tokens are issued by the
# login endpoint named here.
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login")

# The same parameter name is used for both the query-string and the header
# variant of the API key.
API_KEY_NAME = "x-api-key"

# auto_error=False: a missing key is reported by api_key_security itself,
# which checks both locations before rejecting the request.
api_key_query = APIKeyQuery(
    name=API_KEY_NAME,
    scheme_name="API key query",
    auto_error=False,
)
api_key_header = APIKeyHeader(
    name=API_KEY_NAME,
    scheme_name="API key header",
    auto_error=False,
)
|
|
|
|
|
|
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
async def api_key_security(
    query_param: str = Security(api_key_query),
    header_param: str = Security(api_key_header),
    db: Session = Depends(get_session),
) -> Optional[User]:
    """Resolve the user behind an API key (or the superuser in auto-login mode).

    Args:
        query_param: API key taken from the query string, if any.
        header_param: API key taken from the request header, if any.
        db: Database session used for the user / key lookup.

    Returns:
        The ``User`` owning the key, or the configured superuser when
        ``AUTO_LOGIN`` is enabled. ``None`` only if the lookup yields an
        object that is neither an ``ApiKey`` nor a ``User``.

    Raises:
        HTTPException: 400 when ``AUTO_LOGIN`` is on but no ``SUPERUSER`` is
            configured; 403 when no key was supplied or the lookup found
            nothing.
    """
    auth_settings = get_settings_service().auth_settings
    found: Optional[Union[ApiKey, User]] = None

    if auth_settings.AUTO_LOGIN:
        # Auto-login ignores any supplied key and always maps to the superuser.
        superuser_name = auth_settings.SUPERUSER
        if not superuser_name:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing first superuser credentials",
            )
        found = get_user_by_username(db, superuser_name)
    else:
        # The query-string key takes precedence over the header key.
        supplied_key = query_param or header_param
        if not supplied_key:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="An API key must be passed as query or header",
            )
        found = check_key(db, supplied_key)

    if not found:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or missing API key",
        )

    if isinstance(found, ApiKey):
        return found.user
    if isinstance(found, User):
        return found
    return None
|
|
|
|
|
|
async def get_current_user(
    token: Annotated[str, Depends(oauth2_login)],
    db: Session = Depends(get_session),
) -> User:
    """Return the active user identified by a bearer access token.

    Args:
        token: Encoded JWT (awaited first if a coroutine was passed instead).
        db: Database session used to load the user.

    Returns:
        The active ``User`` whose id is stored in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 ("Could not validate credentials") when no secret
            key is configured, the token cannot be decoded, it has expired,
            it lacks ``sub``, it carries a truthy ``type`` claim, or the user
            is missing or inactive.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    auth_settings = get_settings_service().auth_settings

    if isinstance(token, Coroutine):
        token = await token

    secret_key = auth_settings.SECRET_KEY
    if secret_key is None:
        raise unauthorized

    try:
        claims = jwt.decode(
            token,
            secret_key,
            algorithms=[auth_settings.ALGORITHM],
        )
    except JWTError as e:
        raise unauthorized from e

    user_id: UUID = claims.get("sub")  # type: ignore
    token_type: str = claims.get("type")  # type: ignore

    expires = claims.get("exp", None)
    if expires:
        # Build an aware datetime so it can be compared with an aware "now";
        # mixing naive and aware datetimes raises TypeError.
        expires_at = datetime.fromtimestamp(expires, timezone.utc)
        if datetime.now(timezone.utc) > expires_at:
            raise unauthorized

    # Any token with a "type" claim is refused here; only untyped tokens
    # are accepted as access tokens.
    if user_id is None or token_type:
        raise unauthorized

    user = get_user_by_id(db, user_id)  # type: ignore
    if user is None or not user.is_active:
        raise unauthorized
    return user
|
|
|
|
|
|
def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    """Return the authenticated user, or raise HTTP 400 if the account is inactive."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
|
|
|
|
|
|
def get_current_active_superuser(
    current_user: Annotated[User, Depends(get_current_user)]
) -> User:
    """Return the authenticated user only if it is an active superuser.

    Raises:
        HTTPException: 401 when the user is inactive; 400 when the user is
            active but not a superuser.
    """
    if not current_user.is_active:
        failure = HTTPException(status_code=401, detail="Inactive user")
    elif not current_user.is_superuser:
        failure = HTTPException(
            status_code=400, detail="The user doesn't have enough privileges"
        )
    else:
        return current_user
    raise failure
|
|
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against ``hashed_password`` using the configured ``pwd_context``.

    Returns whatever ``pwd_context.verify`` returns for the pair.
    """
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.verify(plain_password, hashed_password)
|
|
|
|
|
|
def get_password_hash(password):
    """Hash ``password`` with the ``pwd_context`` from the auth settings."""
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.hash(password)
|
|
|
|
|
|
def create_token(data: dict, expires_delta: timedelta):
    """Encode ``data`` as a JWT that expires ``expires_delta`` from now.

    Args:
        data: Claims to embed; the mapping itself is not modified. An
            existing ``"exp"`` entry is overridden.
        expires_delta: Lifetime of the token, added to the current UTC time.

    Returns:
        The result of ``jwt.encode`` using the configured secret key and
        algorithm.
    """
    auth_settings = get_settings_service().auth_settings

    # Build a new dict so the caller's mapping stays untouched.
    claims = {**data, "exp": datetime.now(timezone.utc) + expires_delta}

    return jwt.encode(
        claims,
        auth_settings.SECRET_KEY,
        algorithm=auth_settings.ALGORITHM,
    )
|
|
|
|
|
|
def create_super_user(
    username: str,
    password: str,
    db: Session = Depends(get_session),
) -> User:
    """Return the user named ``username``, creating it as a superuser if absent.

    An existing user is returned unchanged (its password and flags are not
    updated). A new user is stored with a hashed password, marked active and
    superuser, then committed and refreshed.
    """
    existing = get_user_by_username(db, username)
    if existing:
        return existing

    created = User(
        username=username,
        password=get_password_hash(password),
        is_superuser=True,
        is_active=True,
        last_login_at=None,
    )
    db.add(created)
    db.commit()
    db.refresh(created)
    return created
|
|
|
|
|
|
def create_user_longterm_token(db: Session = Depends(get_session)) -> dict:
    """Issue a 365-day access token for the configured superuser.

    The superuser is created on the fly if it does not exist yet, and its
    last-login timestamp is updated.

    Returns:
        A dict with ``access_token``, ``refresh_token`` (always ``None``) and
        ``token_type`` (``"bearer"``).

    Raises:
        HTTPException: 400 when ``SUPERUSER`` or ``SUPERUSER_PASSWORD`` is
            not configured.
    """
    auth_settings = get_settings_service().auth_settings
    username = auth_settings.SUPERUSER
    password = auth_settings.SUPERUSER_PASSWORD

    if not (username and password):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing first superuser credentials",
        )

    super_user = create_super_user(db=db, username=username, password=password)

    access_token = create_token(
        data={"sub": str(super_user.id)},
        expires_delta=timedelta(days=365),
    )

    # Update: last_login_at
    update_user_last_login_at(super_user.id, db)

    return {
        "access_token": access_token,
        "refresh_token": None,
        "token_type": "bearer",
    }
|
|
|
|
|
|
def create_user_api_key(user_id: UUID) -> dict:
    """Create a JWT-based API key for ``user_id`` valid for 730 days.

    Returns:
        ``{"api_key": <token>}`` where the token carries the claims
        ``sub`` (the user id as a string) and ``role`` (``"api_key"``).
    """
    claims = {"sub": str(user_id), "role": "api_key"}
    lifetime = timedelta(days=365 * 2)
    return {"api_key": create_token(data=claims, expires_delta=lifetime)}
|
|
|
|
|
|
def get_user_id_from_token(token: str) -> UUID:
    """Extract the user id from a token's ``sub`` claim without verifying it.

    The claims are read with ``jwt.get_unverified_claims``, so the signature
    and expiry are NOT checked here; the result must not be trusted for
    authorization on its own.

    Args:
        token: Encoded JWT.

    Returns:
        The UUID parsed from ``sub``, or the nil UUID (``UUID(int=0)``) when
        the token cannot be parsed, has no ``sub`` claim, or ``sub`` is not a
        valid UUID string.
    """
    try:
        subject = jwt.get_unverified_claims(token)["sub"]
        # The claim is attacker-controlled and may be any JSON value. For
        # non-string input (null, numbers, lists) UUID() raises TypeError or
        # AttributeError rather than ValueError, so guard explicitly.
        if not isinstance(subject, str):
            return UUID(int=0)
        return UUID(subject)
    except (KeyError, JWTError, ValueError):
        return UUID(int=0)
|
|
|
|
|
|
def create_user_tokens(
    user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False
) -> dict:
    """Create an access token and a refresh token for ``user_id``.

    Args:
        user_id: Id stored (as a string) in each token's ``sub`` claim.
        db: Database session, used only when ``update_last_login`` is true.
        update_last_login: When true, also update the user's last-login
            timestamp.

    Returns:
        A dict with ``access_token``, ``refresh_token`` and ``token_type``
        (``"bearer"``). The refresh token additionally carries
        ``"type": "rf"``.
    """
    auth_settings = get_settings_service().auth_settings
    subject = str(user_id)

    access_token = create_token(
        data={"sub": subject},
        expires_delta=timedelta(minutes=auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    refresh_token = create_token(
        data={"sub": subject, "type": "rf"},
        expires_delta=timedelta(minutes=auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES),
    )

    # Update: last_login_at
    if update_last_login:
        update_user_last_login_at(user_id, db)

    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
|
|
|
|
|
|
def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)):
    """Exchange a refresh token for a fresh access/refresh token pair.

    Args:
        refresh_token: Encoded JWT previously issued as a refresh token.
        db: Database session forwarded to ``create_user_tokens``.

    Returns:
        The dict produced by ``create_user_tokens`` for the token's ``sub``.

    Raises:
        HTTPException: 401 ("Invalid refresh token") when the token cannot
            be decoded or lacks a ``sub`` or ``type`` claim.
    """
    auth_settings = get_settings_service().auth_settings

    try:
        claims = jwt.decode(
            refresh_token,
            auth_settings.SECRET_KEY,
            algorithms=[auth_settings.ALGORITHM],
        )
        user_id: UUID = claims.get("sub")  # type: ignore
        token_type: str = claims.get("type")  # type: ignore

        # NOTE(review): any non-None "type" is accepted here, not only "rf".
        if user_id is not None and token_type is not None:
            return create_user_tokens(user_id, db)

        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        )
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from e
|
|
|
|
|
|
def authenticate_user(
    username: str, password: str, db: Session = Depends(get_session)
) -> Optional[User]:
    """Look up ``username`` and check ``password`` against the stored hash.

    Returns:
        The matching ``User`` on success; ``None`` when the user does not
        exist or the password does not verify.

    Raises:
        HTTPException: 400 for an inactive account - "Waiting for approval"
            if it has never logged in, otherwise "Inactive user". This is
            raised before the password is checked.
    """
    account = get_user_by_username(db, username)
    if not account:
        return None

    if not account.is_active:
        reason = "Inactive user" if account.last_login_at else "Waiting for approval"
        raise HTTPException(status_code=400, detail=reason)

    if verify_password(password, account.password):
        return account
    return None
|
|
|
|
|
|
def get_fernet(settings_service=Depends(get_settings_service)):
    """Build a ``Fernet`` instance keyed with the configured ``SECRET_KEY``."""
    # It's important that your secret key is 32 url-safe base64-encoded bytes
    return Fernet(settings_service.auth_settings.SECRET_KEY)
|
|
|
|
|
|
def encrypt_api_key(api_key: str, settings_service=Depends(get_settings_service)):
    """Encrypt ``api_key`` with the Fernet cipher built by ``get_fernet``.

    The string is encoded to bytes before encryption; the value returned is
    whatever ``Fernet.encrypt`` produces. The encryption is reversible via
    ``decrypt_api_key``.
    """
    # Two-way encryption
    cipher = get_fernet(settings_service)
    return cipher.encrypt(api_key.encode())
|
|
|
|
|
|
def decrypt_api_key(
    encrypted_api_key: str, settings_service=Depends(get_settings_service)
):
    """Decrypt a Fernet-encrypted API key and return it as a string.

    Args:
        encrypted_api_key: The encrypted token as text; it is encoded to
            bytes before being passed to ``Fernet.decrypt``.
        settings_service: Settings service supplying the secret key.
    """
    # Two-way decryption
    cipher = get_fernet(settings_service)
    plaintext = cipher.decrypt(encrypted_api_key.encode())
    return plaintext.decode()
|