Refactor credential retrieval and listing methods
This commit is contained in:
parent
b1a4957e20
commit
0a3921dd2d
1 changed files with 7 additions and 6 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
|
|
@ -6,7 +6,7 @@ from langflow.services.auth import utils as auth_utils
|
|||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.credential.model import Credential
|
||||
from langflow.services.deps import get_session
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, select
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
|
@ -20,13 +20,14 @@ class CredentialService(Service):
|
|||
|
||||
def get_credential(self, user_id: Union[UUID, str], name: str, session: Session = Depends(get_session)) -> str:
|
||||
# we get the credential from the database
|
||||
credential = session.query(Credential).filter(Credential.user_id == user_id, Credential.name == name).first()
|
||||
# credential = session.query(Credential).filter(Credential.user_id == user_id, Credential.name == name).first()
|
||||
credential = session.exec(select(Credential).where(Credential.user_id == user_id, Credential.name == name)).first()
|
||||
# we decrypt the value
|
||||
if not credential:
|
||||
if not credential or not credential.value:
|
||||
raise ValueError(f"{name} credential not found.")
|
||||
decrypted = auth_utils.decrypt_api_key(credential.value, settings_service=self.settings_service)
|
||||
return decrypted
|
||||
|
||||
def list_credentials(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Credential]:
|
||||
credentials = session.query(Credential).filter(Credential.user_id == user_id).all()
|
||||
def list_credentials(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[str]]:
|
||||
credentials = session.exec(select(Credential).where(Credential.user_id == user_id)).all()
|
||||
return [credential.name for credential in credentials]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue