ref: Remove useless fastapi Depends (#4217)

Remove useless fastapi Depends
This commit is contained in:
Christophe Bornet 2024-10-21 22:53:17 +02:00 committed by GitHub
commit 99bcaab9d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 25 additions and 23 deletions

View file

@ -19,6 +19,7 @@ from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_session, get_settings_service
from langflow.services.settings.service import SettingsService
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
@ -207,7 +208,7 @@ def create_token(data: dict, expires_delta: timedelta):
def create_super_user(
username: str,
password: str,
db: Session = Depends(get_session),
db: Session,
) -> User:
super_user = get_user_by_username(db, username)
@ -227,7 +228,7 @@ def create_super_user(
return super_user
def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID, dict]:
def create_user_longterm_token(db: Session) -> tuple[UUID, dict]:
settings_service = get_settings_service()
username = settings_service.auth_settings.SUPERUSER
@ -267,7 +268,7 @@ def get_user_id_from_token(token: str) -> UUID:
return UUID(int=0)
def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), *, update_last_login: bool = False) -> dict:
def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool = False) -> dict:
settings_service = get_settings_service()
access_token_expires = timedelta(seconds=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
@ -293,7 +294,7 @@ def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), *, upd
}
def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)):
def create_refresh_token(refresh_token: str, db: Session):
settings_service = get_settings_service()
try:
@ -326,7 +327,7 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session))
) from e
def authenticate_user(username: str, password: str, db: Session = Depends(get_session)) -> User | None:
def authenticate_user(username: str, password: str, db: Session) -> User | None:
user = get_user_by_username(db, username)
if not user:
@ -359,20 +360,20 @@ def ensure_valid_key(s: str) -> bytes:
return key
def get_fernet(settings_service=Depends(get_settings_service)):
def get_fernet(settings_service: SettingsService):
secret_key: str = settings_service.auth_settings.SECRET_KEY.get_secret_value()
valid_key = ensure_valid_key(secret_key)
return Fernet(valid_key)
def encrypt_api_key(api_key: str, settings_service=Depends(get_settings_service)):
def encrypt_api_key(api_key: str, settings_service: SettingsService):
fernet = get_fernet(settings_service)
# Two-way encryption
encrypted_key = fernet.encrypt(api_key.encode())
return encrypted_key.decode()
def decrypt_api_key(encrypted_api_key: str, settings_service=Depends(get_settings_service)):
def decrypt_api_key(encrypted_api_key: str, settings_service: SettingsService):
fernet = get_fernet(settings_service)
decrypted_key = ""
# Two-way decryption

View file

@ -1,14 +1,13 @@
from datetime import datetime, timezone
from uuid import UUID
from fastapi import Depends, HTTPException, status
from fastapi import HTTPException, status
from loguru import logger
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, select
from langflow.services.database.models.user.model import User, UserUpdate
from langflow.services.deps import get_session
def get_user_by_username(db: Session, username: str) -> User | None:
@ -19,7 +18,7 @@ def get_user_by_id(db: Session, user_id: UUID) -> User | None:
return db.exec(select(User).where(User.id == user_id)).first()
def update_user(user_db: User | None, user: UserUpdate, db: Session = Depends(get_session)) -> User:
def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User:
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
@ -49,7 +48,7 @@ def update_user(user_db: User | None, user: UserUpdate, db: Session = Depends(ge
return user_db
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
def update_user_last_login_at(user_id: UUID, db: Session):
try:
user_data = UserUpdate(last_login_at=datetime.now(timezone.utc))
user = get_user_by_id(db, user_id)

View file

@ -90,6 +90,7 @@ class VariableService(Service):
user_id: UUID | str,
name: str,
value: str,
*,
default_fields: list[str],
_type: str,
session: Session,

View file

@ -129,14 +129,16 @@ class KubernetesSecretService(VariableService, Service):
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, _session: Session) -> None:
self.delete_variable(user_id, _session, str(variable_id))
@override
def create_variable(
self,
user_id: UUID | str,
name: str,
value: str,
*,
default_fields: list[str],
_type: str,
_session: Session,
session: Session,
) -> Variable:
secret_name = encode_user_id(user_id)
secret_key = name

View file

@ -4,14 +4,12 @@ import os
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from fastapi import Depends
from loguru import logger
from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
from langflow.services.deps import get_session
from langflow.services.variable.base import VariableService
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
@ -26,7 +24,7 @@ class DatabaseVariableService(VariableService, Service):
def __init__(self, settings_service: SettingsService):
self.settings_service = settings_service
def initialize_user_variables(self, user_id: UUID | str, session: Session = Depends(get_session)):
def initialize_user_variables(self, user_id: UUID | str, session: Session):
if not self.settings_service.settings.store_environment_variables:
logger.info("Skipping environment variable storage.")
return
@ -58,7 +56,7 @@ class DatabaseVariableService(VariableService, Service):
user_id: UUID | str,
name: str,
field: str,
session: Session = Depends(get_session),
session: Session,
) -> str:
# we get the credential from the database
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
@ -78,10 +76,10 @@ class DatabaseVariableService(VariableService, Service):
# we decrypt the value
return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
def get_all(self, user_id: UUID | str, session: Session = Depends(get_session)) -> list[Variable | None]:
def get_all(self, user_id: UUID | str, session: Session) -> list[Variable | None]:
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())
def list_variables(self, user_id: UUID | str, session: Session = Depends(get_session)) -> list[str | None]:
def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]:
variables = self.get_all(user_id=user_id, session=session)
return [variable.name for variable in variables if variable]
@ -90,7 +88,7 @@ class DatabaseVariableService(VariableService, Service):
user_id: UUID | str,
name: str,
value: str,
session: Session = Depends(get_session),
session: Session,
):
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
if not variable:
@ -108,7 +106,7 @@ class DatabaseVariableService(VariableService, Service):
user_id: UUID | str,
variable_id: UUID | str,
variable: VariableUpdate,
session: Session = Depends(get_session),
session: Session,
):
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
db_variable = session.exec(query).one()
@ -131,7 +129,7 @@ class DatabaseVariableService(VariableService, Service):
self,
user_id: UUID | str,
name: str,
session: Session = Depends(get_session),
session: Session,
):
stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name)
variable = session.exec(stmt).first()
@ -154,9 +152,10 @@ class DatabaseVariableService(VariableService, Service):
user_id: UUID | str,
name: str,
value: str,
*,
default_fields: Sequence[str] = (),
_type: str = GENERIC_TYPE,
session: Session = Depends(get_session),
session: Session,
):
variable_base = VariableCreate(
name=name,