From 0a3921dd2d5f97809cf7375bf3d3184137c50ae9 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 5 Dec 2023 18:16:02 -0300 Subject: [PATCH] Refactor credential retrieval and listing methods --- .../langflow/services/credentials/service.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/backend/langflow/services/credentials/service.py b/src/backend/langflow/services/credentials/service.py index 1d37fc6b9..3b7810419 100644 --- a/src/backend/langflow/services/credentials/service.py +++ b/src/backend/langflow/services/credentials/service.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union from uuid import UUID from fastapi import Depends @@ -6,7 +6,7 @@ from langflow.services.auth import utils as auth_utils from langflow.services.base import Service from langflow.services.database.models.credential.model import Credential from langflow.services.deps import get_session -from sqlmodel import Session +from sqlmodel import Session, select if TYPE_CHECKING: from langflow.services.settings.service import SettingsService @@ -20,13 +20,14 @@ class CredentialService(Service): def get_credential(self, user_id: Union[UUID, str], name: str, session: Session = Depends(get_session)) -> str: # we get the credential from the database - credential = session.query(Credential).filter(Credential.user_id == user_id, Credential.name == name).first() + # credential = session.query(Credential).filter(Credential.user_id == user_id, Credential.name == name).first() + credential = session.exec(select(Credential).where(Credential.user_id == user_id, Credential.name == name)).first() # we decrypt the value - if not credential: + if not credential or not credential.value: raise ValueError(f"{name} credential not found.") decrypted = auth_utils.decrypt_api_key(credential.value, settings_service=self.settings_service) return decrypted - def list_credentials(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Credential]: - credentials = session.query(Credential).filter(Credential.user_id == user_id).all() + def list_credentials(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[str]]: + credentials = session.exec(select(Credential).where(Credential.user_id == user_id)).all() return [credential.name for credential in credentials]