feat: return variable value if it is a generic variable (#5366)
* fix: update pre-commit configuration for ruff formatting args * fix: update variable type constant from GENERIC_TYPE to CREDENTIAL_TYPE * feat: enhance variable service to handle decryption for generic type and update default variable type to CREDENTIAL_TYPE * feat: add abstract method to retrieve all variables for a user in VariableService * feat: implement get_all method in KubernetesSecretService to retrieve and decrypt user variables * refactor: update variable tests to use fixtures for generic and credential types - Renamed and refactored test fixtures for better clarity and reusability. - Updated tests to utilize `generic_variable` and `credential_variable` fixtures instead of hardcoded values. - Enhanced assertions to ensure correct handling of variable types, including encryption for credential variables and decryption for generic variables. - Improved test structure for creating, reading, updating, and deleting variables, ensuring consistency across test cases. * refactor: update get_all method signatures in variable services to return VariableRead - Changed return type of the get_all method in VariableService, KubernetesSecretService, and DatabaseVariableService from list[Variable | None] to list[VariableRead]. - This update enhances type consistency across variable services and aligns with the new VariableRead model for improved data handling. * fix: update variable type assertion in test_update_variable to CREDENTIAL_TYPE - Changed the assertion in the `test_update_variable` test to verify that the result type is now `CREDENTIAL_TYPE` instead of `GENERIC_TYPE`. - This update aligns the test with recent changes in variable type handling, ensuring accurate validation of variable updates. * fix: update variable type assertion in test_create_variable to CREDENTIAL_TYPE - Changed the assertion in the `test_create_variable` test to verify that the result type is now `CREDENTIAL_TYPE` instead of `GENERIC_TYPE`. - This update ensures consistency with recent changes in variable type handling and improves the accuracy of the test validation.
This commit is contained in:
parent
b7f5a7fe9b
commit
dd68a97567
7 changed files with 143 additions and 58 deletions
|
|
@ -20,4 +20,4 @@ repos:
|
|||
args: [--fix]
|
||||
- id: ruff-format
|
||||
types_or: [python, pyi]
|
||||
args: [--config pyproject.toml]
|
||||
args: [--config, pyproject.toml]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.exc import NoResultFound
|
|||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
|
||||
from langflow.services.deps import get_variable_service
|
||||
from langflow.services.variable.constants import GENERIC_TYPE
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE
|
||||
from langflow.services.variable.service import DatabaseVariableService
|
||||
|
||||
router = APIRouter(prefix="/variables", tags=["Variables"])
|
||||
|
|
@ -38,7 +38,7 @@ async def create_variable(
|
|||
name=variable.name,
|
||||
value=variable.value,
|
||||
default_fields=variable.default_fields or [],
|
||||
type_=variable.type or GENERIC_TYPE,
|
||||
type_=variable.type or CREDENTIAL_TYPE,
|
||||
session=session,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from uuid import UUID
|
|||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.variable.model import Variable
|
||||
from langflow.services.database.models.variable.model import Variable, VariableRead
|
||||
|
||||
|
||||
class VariableService(Service):
|
||||
|
|
@ -108,3 +108,12 @@ class VariableService(Service):
|
|||
Returns:
|
||||
The created variable.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
|
||||
"""Get all variables.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
session: The database session.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing_extensions import override
|
|||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead
|
||||
from langflow.services.variable.base import VariableService
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id
|
||||
|
|
@ -170,3 +170,35 @@ class KubernetesSecretService(VariableService, Service):
|
|||
default_fields=default_fields,
|
||||
)
|
||||
return Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
|
||||
|
||||
@override
|
||||
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
|
||||
secret_name = encode_user_id(user_id)
|
||||
variables = await asyncio.to_thread(self.kubernetes_secrets.get_secret, name=secret_name)
|
||||
if not variables:
|
||||
return []
|
||||
|
||||
variables_read = []
|
||||
for key, value in variables.items():
|
||||
name = key
|
||||
type_ = GENERIC_TYPE
|
||||
if key.startswith(CREDENTIAL_TYPE + "_"):
|
||||
name = key[len(CREDENTIAL_TYPE) + 1 :]
|
||||
type_ = CREDENTIAL_TYPE
|
||||
|
||||
decrypted_value = None
|
||||
if type_ == GENERIC_TYPE:
|
||||
decrypted_value = value
|
||||
|
||||
variable_base = VariableCreate(
|
||||
name=name,
|
||||
type=type_,
|
||||
value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service),
|
||||
default_fields=[],
|
||||
)
|
||||
variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
|
||||
variable_read = VariableRead.model_validate(variable, from_attributes=True)
|
||||
variable_read.value = decrypted_value
|
||||
variables_read.append(variable_read)
|
||||
|
||||
return variables_read
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlmodel import select
|
|||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead, VariableUpdate
|
||||
from langflow.services.variable.base import VariableService
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
|
||||
|
|
@ -79,9 +79,20 @@ class DatabaseVariableService(VariableService, Service):
|
|||
# we decrypt the value
|
||||
return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
|
||||
|
||||
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]:
|
||||
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]:
|
||||
stmt = select(Variable).where(Variable.user_id == user_id)
|
||||
return list((await session.exec(stmt)).all())
|
||||
variables = list((await session.exec(stmt)).all())
|
||||
# If the variable is of type 'Generic' we decrypt the value
|
||||
variables_read = []
|
||||
for variable in variables:
|
||||
value = None
|
||||
if variable.type == GENERIC_TYPE:
|
||||
value = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
|
||||
|
||||
variable_read = VariableRead.model_validate(variable, from_attributes=True)
|
||||
variable_read.value = value
|
||||
variables_read.append(variable_read)
|
||||
return variables_read
|
||||
|
||||
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
|
||||
variables = await self.get_all(user_id=user_id, session=session)
|
||||
|
|
@ -160,7 +171,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
value: str,
|
||||
*,
|
||||
default_fields: Sequence[str] = (),
|
||||
type_: str = GENERIC_TYPE,
|
||||
type_: str = CREDENTIAL_TYPE,
|
||||
session: AsyncSession,
|
||||
):
|
||||
variable_base = VariableCreate(
|
||||
|
|
|
|||
|
|
@ -4,36 +4,47 @@ from uuid import uuid4
|
|||
import pytest
|
||||
from fastapi import HTTPException, status
|
||||
from httpx import AsyncClient
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def body():
|
||||
def generic_variable():
|
||||
return {
|
||||
"name": "test_variable",
|
||||
"value": "test_value",
|
||||
"type": "test_type",
|
||||
"name": "test_generic_variable",
|
||||
"value": "test_generic_value",
|
||||
"type": GENERIC_TYPE,
|
||||
"default_fields": ["test_field"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def credential_variable():
|
||||
return {
|
||||
"name": "test_credential_variable",
|
||||
"value": "test_credential_value",
|
||||
"type": CREDENTIAL_TYPE,
|
||||
"default_fields": ["test_field"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable(client: AsyncClient, body, logged_in_headers):
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
async def test_create_variable(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert body["name"] == result["name"]
|
||||
assert body["type"] == result["type"]
|
||||
assert body["default_fields"] == result["default_fields"]
|
||||
assert generic_variable["name"] == result["name"]
|
||||
assert generic_variable["type"] == result["type"]
|
||||
assert generic_variable["default_fields"] == result["default_fields"]
|
||||
assert "id" in result
|
||||
assert body["value"] != result["value"]
|
||||
assert generic_variable["value"] != result["value"] # Value should be encrypted
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__variable_name_already_exists(client: AsyncClient, body, logged_in_headers):
|
||||
await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
async def test_create_variable__variable_name_already_exists(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
|
@ -41,11 +52,13 @@ async def test_create_variable__variable_name_already_exists(client: AsyncClient
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__variable_name_and_value_cannot_be_empty(client: AsyncClient, body, logged_in_headers):
|
||||
body["name"] = ""
|
||||
body["value"] = ""
|
||||
async def test_create_variable__variable_name_and_value_cannot_be_empty(
|
||||
client: AsyncClient, generic_variable, logged_in_headers
|
||||
):
|
||||
generic_variable["name"] = ""
|
||||
generic_variable["value"] = ""
|
||||
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
|
@ -53,10 +66,10 @@ async def test_create_variable__variable_name_and_value_cannot_be_empty(client:
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClient, body, logged_in_headers):
|
||||
body["name"] = ""
|
||||
async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
generic_variable["name"] = ""
|
||||
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
|
@ -64,10 +77,12 @@ async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClien
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClient, body, logged_in_headers):
|
||||
body["value"] = ""
|
||||
async def test_create_variable__variable_value_cannot_be_empty(
|
||||
client: AsyncClient, generic_variable, logged_in_headers
|
||||
):
|
||||
generic_variable["value"] = ""
|
||||
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
|
@ -75,13 +90,13 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__httpexception(client: AsyncClient, body, logged_in_headers):
|
||||
async def test_create_variable__httpexception(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
status_code = 418
|
||||
generic_message = "I'm a teapot"
|
||||
|
||||
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
|
||||
m.side_effect = HTTPException(status_code=status_code, detail=generic_message)
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_418_IM_A_TEAPOT
|
||||
|
|
@ -89,12 +104,12 @@ async def test_create_variable__httpexception(client: AsyncClient, body, logged_
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__exception(client: AsyncClient, body, logged_in_headers):
|
||||
async def test_create_variable__exception(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
generic_message = "Generic error message"
|
||||
|
||||
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
|
||||
m.side_effect = Exception(generic_message)
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
|
@ -102,17 +117,33 @@ async def test_create_variable__exception(client: AsyncClient, body, logged_in_h
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_read_variables(client: AsyncClient, body, logged_in_headers):
|
||||
names = ["test_variable1", "test_variable2", "test_variable3"]
|
||||
for name in names:
|
||||
body["name"] = name
|
||||
await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
async def test_read_variables(client: AsyncClient, generic_variable, credential_variable, logged_in_headers):
|
||||
# Create a generic variable
|
||||
create_response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
assert create_response.status_code == status.HTTP_201_CREATED
|
||||
|
||||
# Create a credential variable
|
||||
create_response = await client.post("api/v1/variables/", json=credential_variable, headers=logged_in_headers)
|
||||
assert create_response.status_code == status.HTTP_201_CREATED
|
||||
|
||||
response = await client.get("api/v1/variables/", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert all(name in [r["name"] for r in result] for name in names)
|
||||
|
||||
# Check both variables exist
|
||||
assert generic_variable["name"] in [r["name"] for r in result]
|
||||
assert credential_variable["name"] in [r["name"] for r in result]
|
||||
|
||||
# Assert that credentials are not decrypted and generic are decrypted
|
||||
credential_vars = [r for r in result if r["type"] == CREDENTIAL_TYPE]
|
||||
generic_vars = [r for r in result if r["type"] == GENERIC_TYPE]
|
||||
|
||||
# Credential variables should remain encrypted (value should be different)
|
||||
assert all(c["value"] != credential_variable["value"] for c in credential_vars)
|
||||
|
||||
# Generic variables should be decrypted (value should match original)
|
||||
assert all(g["value"] == generic_variable["value"] for g in generic_vars)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
|
|
@ -140,16 +171,18 @@ async def test_read_variables__(client: AsyncClient, logged_in_headers):
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_update_variable(client: AsyncClient, body, logged_in_headers):
|
||||
saved = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
async def test_update_variable(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
saved = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
saved = saved.json()
|
||||
body["id"] = saved.get("id")
|
||||
body["name"] = "new_name"
|
||||
body["value"] = "new_value"
|
||||
body["type"] = "new_type"
|
||||
body["default_fields"] = ["new_field"]
|
||||
generic_variable["id"] = saved.get("id")
|
||||
generic_variable["name"] = "new_name"
|
||||
generic_variable["value"] = "new_value"
|
||||
generic_variable["type"] = GENERIC_TYPE # Ensure we keep it as GENERIC_TYPE
|
||||
generic_variable["default_fields"] = ["new_field"]
|
||||
|
||||
response = await client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers)
|
||||
response = await client.patch(
|
||||
f"api/v1/variables/{saved.get('id')}", json=generic_variable, headers=logged_in_headers
|
||||
)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
|
@ -159,11 +192,11 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers):
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_update_variable__exception(client: AsyncClient, body, logged_in_headers):
|
||||
async def test_update_variable__exception(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
wrong_id = uuid4()
|
||||
body["id"] = str(wrong_id)
|
||||
generic_variable["id"] = str(wrong_id)
|
||||
|
||||
response = await client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers)
|
||||
response = await client.patch(f"api/v1/variables/{wrong_id}", json=generic_variable, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
|
@ -171,8 +204,8 @@ async def test_update_variable__exception(client: AsyncClient, body, logged_in_h
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_delete_variable(client: AsyncClient, body, logged_in_headers):
|
||||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
async def test_delete_variable(client: AsyncClient, generic_variable, logged_in_headers):
|
||||
response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers)
|
||||
saved = response.json()
|
||||
response = await client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from langflow.services.database.models.variable.model import VariableUpdate
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE
|
||||
from langflow.services.variable.service import DatabaseVariableService
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
|
|
@ -137,7 +137,7 @@ async def test_update_variable(service, session: AsyncSession):
|
|||
assert result.value != old_value
|
||||
assert result.value != new_value
|
||||
assert result.default_fields == []
|
||||
assert result.type == GENERIC_TYPE
|
||||
assert result.type == CREDENTIAL_TYPE
|
||||
assert isinstance(result.created_at, datetime)
|
||||
assert isinstance(result.updated_at, datetime)
|
||||
|
||||
|
|
@ -237,6 +237,6 @@ async def test_create_variable(service, session: AsyncSession):
|
|||
assert result.name == name
|
||||
assert result.value != value
|
||||
assert result.default_fields == []
|
||||
assert result.type == GENERIC_TYPE
|
||||
assert result.type == CREDENTIAL_TYPE
|
||||
assert isinstance(result.created_at, datetime)
|
||||
assert isinstance(result.updated_at, datetime)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue