feat: return variable value if it is a generic variable (#5366)

* fix: update pre-commit configuration for ruff formatting args

* fix: update variable type constant from GENERIC_TYPE to CREDENTIAL_TYPE

* feat: enhance variable service to handle decryption for generic type and update default variable type to CREDENTIAL_TYPE

* feat: add abstract method to retrieve all variables for a user in VariableService

* feat: implement get_all method in KubernetesSecretService to retrieve and decrypt user variables

* refactor: update variable tests to use fixtures for generic and credential types

- Renamed and refactored test fixtures for better clarity and reusability.
- Updated tests to utilize `generic_variable` and `credential_variable` fixtures instead of hardcoded values.
- Enhanced assertions to ensure correct handling of variable types, including encryption for credential variables and decryption for generic variables.
- Improved test structure for creating, reading, updating, and deleting variables, ensuring consistency across test cases.

* refactor: update get_all method signatures in variable services to return VariableRead

- Changed return type of the get_all method in VariableService, KubernetesSecretService, and DatabaseVariableService from list[Variable | None] to list[VariableRead].
- This update enhances type consistency across variable services and aligns with the new VariableRead model for improved data handling.

* fix: update variable type assertion in test_update_variable to CREDENTIAL_TYPE

- Changed the assertion in the `test_update_variable` test to verify that the result type is now `CREDENTIAL_TYPE` instead of `GENERIC_TYPE`.
- This update aligns the test with recent changes in variable type handling, ensuring accurate validation of variable updates.

* fix: update variable type assertion in test_create_variable to CREDENTIAL_TYPE

- Changed the assertion in the `test_create_variable` test to verify that the result type is now `CREDENTIAL_TYPE` instead of `GENERIC_TYPE`.
- This update ensures consistency with recent changes in variable type handling and improves the accuracy of the test validation.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-12-19 14:04:00 -03:00 committed by GitHub
commit dd68a97567
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 143 additions and 58 deletions

View file

@ -6,7 +6,7 @@ from sqlalchemy.exc import NoResultFound
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_variable_service
from langflow.services.variable.constants import GENERIC_TYPE
from langflow.services.variable.constants import CREDENTIAL_TYPE
from langflow.services.variable.service import DatabaseVariableService
router = APIRouter(prefix="/variables", tags=["Variables"])
@ -38,7 +38,7 @@ async def create_variable(
name=variable.name,
value=variable.value,
default_fields=variable.default_fields or [],
type_=variable.type or GENERIC_TYPE,
type_=variable.type or CREDENTIAL_TYPE,
session=session,
)
except Exception as e:

View file

@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable
from langflow.services.database.models.variable.model import Variable, VariableRead
class VariableService(Service):
@ -108,3 +108,12 @@ class VariableService(Service):
Returns:
The created variable.
"""
@abc.abstractmethod
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
"""Get all variables.
Args:
user_id: The user ID.
session: The database session.
"""

View file

@ -9,7 +9,7 @@ from typing_extensions import override
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead
from langflow.services.variable.base import VariableService
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id
@ -170,3 +170,35 @@ class KubernetesSecretService(VariableService, Service):
default_fields=default_fields,
)
return Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
@override
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
secret_name = encode_user_id(user_id)
variables = await asyncio.to_thread(self.kubernetes_secrets.get_secret, name=secret_name)
if not variables:
return []
variables_read = []
for key, value in variables.items():
name = key
type_ = GENERIC_TYPE
if key.startswith(CREDENTIAL_TYPE + "_"):
name = key[len(CREDENTIAL_TYPE) + 1 :]
type_ = CREDENTIAL_TYPE
decrypted_value = None
if type_ == GENERIC_TYPE:
decrypted_value = value
variable_base = VariableCreate(
name=name,
type=type_,
value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service),
default_fields=[],
)
variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
variable_read = VariableRead.model_validate(variable, from_attributes=True)
variable_read.value = decrypted_value
variables_read.append(variable_read)
return variables_read

View file

@ -9,7 +9,7 @@ from sqlmodel import select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead, VariableUpdate
from langflow.services.variable.base import VariableService
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
@ -79,9 +79,20 @@ class DatabaseVariableService(VariableService, Service):
# we decrypt the value
return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]:
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
stmt = select(Variable).where(Variable.user_id == user_id)
return list((await session.exec(stmt)).all())
variables = list((await session.exec(stmt)).all())
# If the variable is of type 'Generic' we decrypt the value
variables_read = []
for variable in variables:
value = None
if variable.type == GENERIC_TYPE:
value = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
variable_read = VariableRead.model_validate(variable, from_attributes=True)
variable_read.value = value
variables_read.append(variable_read)
return variables_read
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
variables = await self.get_all(user_id=user_id, session=session)
@ -160,7 +171,7 @@ class DatabaseVariableService(VariableService, Service):
value: str,
*,
default_fields: Sequence[str] = (),
type_: str = GENERIC_TYPE,
type_: str = CREDENTIAL_TYPE,
session: AsyncSession,
):
variable_base = VariableCreate(