ref: Remove useless fastapi Depends (#4217)
Remove useless fastapi Depends
This commit is contained in:
parent
d48ec86121
commit
99bcaab9d8
5 changed files with 25 additions and 23 deletions
|
|
@ -19,6 +19,7 @@ from langflow.services.database.models.api_key.model import ApiKey
|
|||
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
|
||||
|
||||
|
|
@ -207,7 +208,7 @@ def create_token(data: dict, expires_delta: timedelta):
|
|||
def create_super_user(
|
||||
username: str,
|
||||
password: str,
|
||||
db: Session = Depends(get_session),
|
||||
db: Session,
|
||||
) -> User:
|
||||
super_user = get_user_by_username(db, username)
|
||||
|
||||
|
|
@ -227,7 +228,7 @@ def create_super_user(
|
|||
return super_user
|
||||
|
||||
|
||||
def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID, dict]:
|
||||
def create_user_longterm_token(db: Session) -> tuple[UUID, dict]:
|
||||
settings_service = get_settings_service()
|
||||
|
||||
username = settings_service.auth_settings.SUPERUSER
|
||||
|
|
@ -267,7 +268,7 @@ def get_user_id_from_token(token: str) -> UUID:
|
|||
return UUID(int=0)
|
||||
|
||||
|
||||
def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), *, update_last_login: bool = False) -> dict:
|
||||
def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool = False) -> dict:
|
||||
settings_service = get_settings_service()
|
||||
|
||||
access_token_expires = timedelta(seconds=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
|
||||
|
|
@ -293,7 +294,7 @@ def create_user_tokens(user_id: UUID, db: Session = Depends(get_session), *, upd
|
|||
}
|
||||
|
||||
|
||||
def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)):
|
||||
def create_refresh_token(refresh_token: str, db: Session):
|
||||
settings_service = get_settings_service()
|
||||
|
||||
try:
|
||||
|
|
@ -326,7 +327,7 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session))
|
|||
) from e
|
||||
|
||||
|
||||
def authenticate_user(username: str, password: str, db: Session = Depends(get_session)) -> User | None:
|
||||
def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||
user = get_user_by_username(db, username)
|
||||
|
||||
if not user:
|
||||
|
|
@ -359,20 +360,20 @@ def ensure_valid_key(s: str) -> bytes:
|
|||
return key
|
||||
|
||||
|
||||
def get_fernet(settings_service=Depends(get_settings_service)):
|
||||
def get_fernet(settings_service: SettingsService):
|
||||
secret_key: str = settings_service.auth_settings.SECRET_KEY.get_secret_value()
|
||||
valid_key = ensure_valid_key(secret_key)
|
||||
return Fernet(valid_key)
|
||||
|
||||
|
||||
def encrypt_api_key(api_key: str, settings_service=Depends(get_settings_service)):
|
||||
def encrypt_api_key(api_key: str, settings_service: SettingsService):
|
||||
fernet = get_fernet(settings_service)
|
||||
# Two-way encryption
|
||||
encrypted_key = fernet.encrypt(api_key.encode())
|
||||
return encrypted_key.decode()
|
||||
|
||||
|
||||
def decrypt_api_key(encrypted_api_key: str, settings_service=Depends(get_settings_service)):
|
||||
def decrypt_api_key(encrypted_api_key: str, settings_service: SettingsService):
|
||||
fernet = get_fernet(settings_service)
|
||||
decrypted_key = ""
|
||||
# Two-way decryption
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import HTTPException, status
|
||||
from loguru import logger
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from langflow.services.database.models.user.model import User, UserUpdate
|
||||
from langflow.services.deps import get_session
|
||||
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> User | None:
|
||||
|
|
@ -19,7 +18,7 @@ def get_user_by_id(db: Session, user_id: UUID) -> User | None:
|
|||
return db.exec(select(User).where(User.id == user_id)).first()
|
||||
|
||||
|
||||
def update_user(user_db: User | None, user: UserUpdate, db: Session = Depends(get_session)) -> User:
|
||||
def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User:
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
|
@ -49,7 +48,7 @@ def update_user(user_db: User | None, user: UserUpdate, db: Session = Depends(ge
|
|||
return user_db
|
||||
|
||||
|
||||
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
|
||||
def update_user_last_login_at(user_id: UUID, db: Session):
|
||||
try:
|
||||
user_data = UserUpdate(last_login_at=datetime.now(timezone.utc))
|
||||
user = get_user_by_id(db, user_id)
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ class VariableService(Service):
|
|||
user_id: UUID | str,
|
||||
name: str,
|
||||
value: str,
|
||||
*,
|
||||
default_fields: list[str],
|
||||
_type: str,
|
||||
session: Session,
|
||||
|
|
|
|||
|
|
@ -129,14 +129,16 @@ class KubernetesSecretService(VariableService, Service):
|
|||
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, _session: Session) -> None:
|
||||
self.delete_variable(user_id, _session, str(variable_id))
|
||||
|
||||
@override
|
||||
def create_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
value: str,
|
||||
*,
|
||||
default_fields: list[str],
|
||||
_type: str,
|
||||
_session: Session,
|
||||
session: Session,
|
||||
) -> Variable:
|
||||
secret_name = encode_user_id(user_id)
|
||||
secret_key = name
|
||||
|
|
|
|||
|
|
@ -4,14 +4,12 @@ import os
|
|||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import Depends
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
|
||||
from langflow.services.deps import get_session
|
||||
from langflow.services.variable.base import VariableService
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
|
||||
|
|
@ -26,7 +24,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
def __init__(self, settings_service: SettingsService):
|
||||
self.settings_service = settings_service
|
||||
|
||||
def initialize_user_variables(self, user_id: UUID | str, session: Session = Depends(get_session)):
|
||||
def initialize_user_variables(self, user_id: UUID | str, session: Session):
|
||||
if not self.settings_service.settings.store_environment_variables:
|
||||
logger.info("Skipping environment variable storage.")
|
||||
return
|
||||
|
|
@ -58,7 +56,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
user_id: UUID | str,
|
||||
name: str,
|
||||
field: str,
|
||||
session: Session = Depends(get_session),
|
||||
session: Session,
|
||||
) -> str:
|
||||
# we get the credential from the database
|
||||
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
|
||||
|
|
@ -78,10 +76,10 @@ class DatabaseVariableService(VariableService, Service):
|
|||
# we decrypt the value
|
||||
return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
|
||||
|
||||
def get_all(self, user_id: UUID | str, session: Session = Depends(get_session)) -> list[Variable | None]:
|
||||
def get_all(self, user_id: UUID | str, session: Session) -> list[Variable | None]:
|
||||
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())
|
||||
|
||||
def list_variables(self, user_id: UUID | str, session: Session = Depends(get_session)) -> list[str | None]:
|
||||
def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
variables = self.get_all(user_id=user_id, session=session)
|
||||
return [variable.name for variable in variables if variable]
|
||||
|
||||
|
|
@ -90,7 +88,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
user_id: UUID | str,
|
||||
name: str,
|
||||
value: str,
|
||||
session: Session = Depends(get_session),
|
||||
session: Session,
|
||||
):
|
||||
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
|
||||
if not variable:
|
||||
|
|
@ -108,7 +106,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
user_id: UUID | str,
|
||||
variable_id: UUID | str,
|
||||
variable: VariableUpdate,
|
||||
session: Session = Depends(get_session),
|
||||
session: Session,
|
||||
):
|
||||
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
|
||||
db_variable = session.exec(query).one()
|
||||
|
|
@ -131,7 +129,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
session: Session = Depends(get_session),
|
||||
session: Session,
|
||||
):
|
||||
stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name)
|
||||
variable = session.exec(stmt).first()
|
||||
|
|
@ -154,9 +152,10 @@ class DatabaseVariableService(VariableService, Service):
|
|||
user_id: UUID | str,
|
||||
name: str,
|
||||
value: str,
|
||||
*,
|
||||
default_fields: Sequence[str] = (),
|
||||
_type: str = GENERIC_TYPE,
|
||||
session: Session = Depends(get_session),
|
||||
session: Session,
|
||||
):
|
||||
variable_base = VariableCreate(
|
||||
name=name,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue