feat: return variable value if it is a generic variable (#5366)

* fix: update pre-commit configuration for ruff formatting args

* fix: update variable type constant from GENERIC_TYPE to CREDENTIAL_TYPE

* feat: enhance variable service to handle decryption for generic type and update default variable type to CREDENTIAL_TYPE

* feat: add abstract method to retrieve all variables for a user in VariableService

* feat: implement get_all method in KubernetesSecretService to retrieve and decrypt user variables

* refactor: update variable tests to use fixtures for generic and credential types

- Renamed and refactored test fixtures for better clarity and reusability.
- Updated tests to utilize `generic_variable` and `credential_variable` fixtures instead of hardcoded values.
- Enhanced assertions to ensure correct handling of variable types, including encryption for credential variables and decryption for generic variables.
- Improved test structure for creating, reading, updating, and deleting variables, ensuring consistency across test cases.

* refactor: update get_all method signatures in variable services to return VariableRead

- Changed return type of the get_all method in VariableService, KubernetesSecretService, and DatabaseVariableService from list[Variable | None] to list[VariableRead].
- This update enhances type consistency across variable services and aligns with the new VariableRead model for improved data handling.

* fix: update variable type assertion in test_update_variable to CREDENTIAL_TYPE

- Changed the assertion in the `test_update_variable` test to verify that the result type is now `CREDENTIAL_TYPE` instead of `GENERIC_TYPE`.
- This update aligns the test with recent changes in variable type handling, ensuring accurate validation of variable updates.

* fix: update variable type assertion in test_create_variable to CREDENTIAL_TYPE

- Changed the assertion in the `test_create_variable` test to verify that the result type is now `CREDENTIAL_TYPE` instead of `GENERIC_TYPE`.
- This update ensures consistency with recent changes in variable type handling and improves the accuracy of the test validation.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-12-19 14:04:00 -03:00 committed by GitHub
commit dd68a97567
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 143 additions and 58 deletions

View file

@ -20,4 +20,4 @@ repos:
args: [--fix]
- id: ruff-format
types_or: [python, pyi]
args: [--config pyproject.toml]
args: [--config, pyproject.toml]

View file

@ -6,7 +6,7 @@ from sqlalchemy.exc import NoResultFound
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_variable_service
from langflow.services.variable.constants import GENERIC_TYPE
from langflow.services.variable.constants import CREDENTIAL_TYPE
from langflow.services.variable.service import DatabaseVariableService
router = APIRouter(prefix="/variables", tags=["Variables"])
@ -38,7 +38,7 @@ async def create_variable(
name=variable.name,
value=variable.value,
default_fields=variable.default_fields or [],
type_=variable.type or GENERIC_TYPE,
type_=variable.type or CREDENTIAL_TYPE,
session=session,
)
except Exception as e:

View file

@ -4,7 +4,7 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable
from langflow.services.database.models.variable.model import Variable, VariableRead
class VariableService(Service):
@ -108,3 +108,12 @@ class VariableService(Service):
Returns:
The created variable.
"""
@abc.abstractmethod
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
"""Get all variables.
Args:
user_id: The user ID.
session: The database session.
"""

View file

@ -9,7 +9,7 @@ from typing_extensions import override
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead
from langflow.services.variable.base import VariableService
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id
@ -170,3 +170,35 @@ class KubernetesSecretService(VariableService, Service):
default_fields=default_fields,
)
return Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
@override
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
secret_name = encode_user_id(user_id)
variables = await asyncio.to_thread(self.kubernetes_secrets.get_secret, name=secret_name)
if not variables:
return []
variables_read = []
for key, value in variables.items():
name = key
type_ = GENERIC_TYPE
if key.startswith(CREDENTIAL_TYPE + "_"):
name = key[len(CREDENTIAL_TYPE) + 1 :]
type_ = CREDENTIAL_TYPE
decrypted_value = None
if type_ == GENERIC_TYPE:
decrypted_value = value
variable_base = VariableCreate(
name=name,
type=type_,
value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service),
default_fields=[],
)
variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
variable_read = VariableRead.model_validate(variable, from_attributes=True)
variable_read.value = decrypted_value
variables_read.append(variable_read)
return variables_read

View file

@ -9,7 +9,7 @@ from sqlmodel import select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead, VariableUpdate
from langflow.services.variable.base import VariableService
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
@ -79,9 +79,20 @@ class DatabaseVariableService(VariableService, Service):
# we decrypt the value
return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]:
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
stmt = select(Variable).where(Variable.user_id == user_id)
return list((await session.exec(stmt)).all())
variables = list((await session.exec(stmt)).all())
# If the variable is of type 'Generic' we decrypt the value
variables_read = []
for variable in variables:
value = None
if variable.type == GENERIC_TYPE:
value = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
variable_read = VariableRead.model_validate(variable, from_attributes=True)
variable_read.value = value
variables_read.append(variable_read)
return variables_read
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
variables = await self.get_all(user_id=user_id, session=session)
@ -160,7 +171,7 @@ class DatabaseVariableService(VariableService, Service):
value: str,
*,
default_fields: Sequence[str] = (),
type_: str = GENERIC_TYPE,
type_: str = CREDENTIAL_TYPE,
session: AsyncSession,
):
variable_base = VariableCreate(

View file

@ -4,36 +4,47 @@ from uuid import uuid4
import pytest
from fastapi import HTTPException, status
from httpx import AsyncClient
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
@pytest.fixture
def body():
def generic_variable():
return {
"name": "test_variable",
"value": "test_value",
"type": "test_type",
"name": "test_generic_variable",
"value": "test_generic_value",
"type": GENERIC_TYPE,
"default_fields": ["test_field"],
}
@pytest.fixture
def credential_variable():
return {
"name": "test_credential_variable",
"value": "test_credential_value",
"type": CREDENTIAL_TYPE,
"default_fields": ["test_field"],
}
@pytest.mark.usefixtures("active_user")
async def test_create_variable(client: AsyncClient, body, logged_in_headers):
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
async def test_create_variable(client: AsyncClient, generic_variable, logged_in_headers):
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_201_CREATED
assert body["name"] == result["name"]
assert body["type"] == result["type"]
assert body["default_fields"] == result["default_fields"]
assert generic_variable["name"] == result["name"]
assert generic_variable["type"] == result["type"]
assert generic_variable["default_fields"] == result["default_fields"]
assert "id" in result
assert body["value"] != result["value"]
assert generic_variable["value"] != result["value"] # Value should be encrypted
@pytest.mark.usefixtures("active_user")
async def test_create_variable__variable_name_already_exists(client: AsyncClient, body, logged_in_headers):
await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
async def test_create_variable__variable_name_already_exists(client: AsyncClient, generic_variable, logged_in_headers):
await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_400_BAD_REQUEST
@ -41,11 +52,13 @@ async def test_create_variable__variable_name_already_exists(client: AsyncClient
@pytest.mark.usefixtures("active_user")
async def test_create_variable__variable_name_and_value_cannot_be_empty(client: AsyncClient, body, logged_in_headers):
body["name"] = ""
body["value"] = ""
async def test_create_variable__variable_name_and_value_cannot_be_empty(
client: AsyncClient, generic_variable, logged_in_headers
):
generic_variable["name"] = ""
generic_variable["value"] = ""
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_400_BAD_REQUEST
@ -53,10 +66,10 @@ async def test_create_variable__variable_name_and_value_cannot_be_empty(client:
@pytest.mark.usefixtures("active_user")
async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClient, body, logged_in_headers):
body["name"] = ""
async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClient, generic_variable, logged_in_headers):
generic_variable["name"] = ""
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_400_BAD_REQUEST
@ -64,10 +77,12 @@ async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClien
@pytest.mark.usefixtures("active_user")
async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClient, body, logged_in_headers):
body["value"] = ""
async def test_create_variable__variable_value_cannot_be_empty(
client: AsyncClient, generic_variable, logged_in_headers
):
generic_variable["value"] = ""
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_400_BAD_REQUEST
@ -75,13 +90,13 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie
@pytest.mark.usefixtures("active_user")
async def test_create_variable__httpexception(client: AsyncClient, body, logged_in_headers):
async def test_create_variable__httpexception(client: AsyncClient, generic_variable, logged_in_headers):
status_code = 418
generic_message = "I'm a teapot"
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
m.side_effect = HTTPException(status_code=status_code, detail=generic_message)
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_418_IM_A_TEAPOT
@ -89,12 +104,12 @@ async def test_create_variable__httpexception(client: AsyncClient, body, logged_
@pytest.mark.usefixtures("active_user")
async def test_create_variable__exception(client: AsyncClient, body, logged_in_headers):
async def test_create_variable__exception(client: AsyncClient, generic_variable, logged_in_headers):
generic_message = "Generic error message"
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
m.side_effect = Exception(generic_message)
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@ -102,17 +117,33 @@ async def test_create_variable__exception(client: AsyncClient, body, logged_in_h
@pytest.mark.usefixtures("active_user")
async def test_read_variables(client: AsyncClient, body, logged_in_headers):
names = ["test_variable1", "test_variable2", "test_variable3"]
for name in names:
body["name"] = name
await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
async def test_read_variables(client: AsyncClient, generic_variable, credential_variable, logged_in_headers):
# Create a generic variable
create_response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
assert create_response.status_code == status.HTTP_201_CREATED
# Create a credential variable
create_response = await client.post("api/v1/variables/", json=credential_variable, headers=logged_in_headers)
assert create_response.status_code == status.HTTP_201_CREATED
response = await client.get("api/v1/variables/", headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_200_OK
assert all(name in [r["name"] for r in result] for name in names)
# Check both variables exist
assert generic_variable["name"] in [r["name"] for r in result]
assert credential_variable["name"] in [r["name"] for r in result]
# Assert that credentials are not decrypted and generic are decrypted
credential_vars = [r for r in result if r["type"] == CREDENTIAL_TYPE]
generic_vars = [r for r in result if r["type"] == GENERIC_TYPE]
# Credential variables should remain encrypted (value should be different)
assert all(c["value"] != credential_variable["value"] for c in credential_vars)
# Generic variables should be decrypted (value should match original)
assert all(g["value"] == generic_variable["value"] for g in generic_vars)
@pytest.mark.usefixtures("active_user")
@ -140,16 +171,18 @@ async def test_read_variables__(client: AsyncClient, logged_in_headers):
@pytest.mark.usefixtures("active_user")
async def test_update_variable(client: AsyncClient, body, logged_in_headers):
saved = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
async def test_update_variable(client: AsyncClient, generic_variable, logged_in_headers):
saved = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
saved = saved.json()
body["id"] = saved.get("id")
body["name"] = "new_name"
body["value"] = "new_value"
body["type"] = "new_type"
body["default_fields"] = ["new_field"]
generic_variable["id"] = saved.get("id")
generic_variable["name"] = "new_name"
generic_variable["value"] = "new_value"
generic_variable["type"] = GENERIC_TYPE # Ensure we keep it as GENERIC_TYPE
generic_variable["default_fields"] = ["new_field"]
response = await client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers)
response = await client.patch(
f"api/v1/variables/{saved.get('id')}", json=generic_variable, headers=logged_in_headers
)
result = response.json()
assert response.status_code == status.HTTP_200_OK
@ -159,11 +192,11 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers):
@pytest.mark.usefixtures("active_user")
async def test_update_variable__exception(client: AsyncClient, body, logged_in_headers):
async def test_update_variable__exception(client: AsyncClient, generic_variable, logged_in_headers):
wrong_id = uuid4()
body["id"] = str(wrong_id)
generic_variable["id"] = str(wrong_id)
response = await client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers)
response = await client.patch(f"api/v1/variables/{wrong_id}", json=generic_variable, headers=logged_in_headers)
result = response.json()
assert response.status_code == status.HTTP_404_NOT_FOUND
@ -171,8 +204,8 @@ async def test_update_variable__exception(client: AsyncClient, body, logged_in_h
@pytest.mark.usefixtures("active_user")
async def test_delete_variable(client: AsyncClient, body, logged_in_headers):
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
async def test_delete_variable(client: AsyncClient, generic_variable, logged_in_headers):
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
saved = response.json()
response = await client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers)

View file

@ -6,7 +6,7 @@ import pytest
from langflow.services.database.models.variable.model import VariableUpdate
from langflow.services.deps import get_settings_service
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
from langflow.services.variable.constants import CREDENTIAL_TYPE
from langflow.services.variable.service import DatabaseVariableService
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
@ -137,7 +137,7 @@ async def test_update_variable(service, session: AsyncSession):
assert result.value != old_value
assert result.value != new_value
assert result.default_fields == []
assert result.type == GENERIC_TYPE
assert result.type == CREDENTIAL_TYPE
assert isinstance(result.created_at, datetime)
assert isinstance(result.updated_at, datetime)
@ -237,6 +237,6 @@ async def test_create_variable(service, session: AsyncSession):
assert result.name == name
assert result.value != value
assert result.default_fields == []
assert result.type == GENERIC_TYPE
assert result.type == CREDENTIAL_TYPE
assert isinstance(result.created_at, datetime)
assert isinstance(result.updated_at, datetime)