diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7eafd54eb..2f3b2b3a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,4 +20,4 @@ repos: args: [--fix] - id: ruff-format types_or: [python, pyi] - args: [--config pyproject.toml] + args: [--config, pyproject.toml] diff --git a/src/backend/base/langflow/api/v1/variable.py b/src/backend/base/langflow/api/v1/variable.py index 5051c788e..c0f05bfe1 100644 --- a/src/backend/base/langflow/api/v1/variable.py +++ b/src/backend/base/langflow/api/v1/variable.py @@ -6,7 +6,7 @@ from sqlalchemy.exc import NoResultFound from langflow.api.utils import CurrentActiveUser, DbSession from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate from langflow.services.deps import get_variable_service -from langflow.services.variable.constants import GENERIC_TYPE +from langflow.services.variable.constants import CREDENTIAL_TYPE from langflow.services.variable.service import DatabaseVariableService router = APIRouter(prefix="/variables", tags=["Variables"]) @@ -38,7 +38,7 @@ async def create_variable( name=variable.name, value=variable.value, default_fields=variable.default_fields or [], - type_=variable.type or GENERIC_TYPE, + type_=variable.type or CREDENTIAL_TYPE, session=session, ) except Exception as e: diff --git a/src/backend/base/langflow/services/variable/base.py b/src/backend/base/langflow/services/variable/base.py index 8c9bce10c..924bc4e4a 100644 --- a/src/backend/base/langflow/services/variable/base.py +++ b/src/backend/base/langflow/services/variable/base.py @@ -4,7 +4,7 @@ from uuid import UUID from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.base import Service -from langflow.services.database.models.variable.model import Variable +from langflow.services.database.models.variable.model import Variable, VariableRead class VariableService(Service): @@ -108,3 +108,12 @@ class VariableService(Service): Returns: The created variable. """ + + @abc.abstractmethod + async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]: + """Get all variables. + + Args: + user_id: The user ID. + session: The database session. + """ diff --git a/src/backend/base/langflow/services/variable/kubernetes.py b/src/backend/base/langflow/services/variable/kubernetes.py index 8efb67d52..d39ad6c0e 100644 --- a/src/backend/base/langflow/services/variable/kubernetes.py +++ b/src/backend/base/langflow/services/variable/kubernetes.py @@ -9,7 +9,7 @@ from typing_extensions import override from langflow.services.auth import utils as auth_utils from langflow.services.base import Service -from langflow.services.database.models.variable.model import Variable, VariableCreate +from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead from langflow.services.variable.base import VariableService from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id @@ -170,3 +170,35 @@ class KubernetesSecretService(VariableService, Service): default_fields=default_fields, ) return Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id}) + + @override + async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]: + secret_name = encode_user_id(user_id) + variables = await asyncio.to_thread(self.kubernetes_secrets.get_secret, name=secret_name) + if not variables: + return [] + + variables_read = [] + for key, value in variables.items(): + name = key + type_ = GENERIC_TYPE + if key.startswith(CREDENTIAL_TYPE + "_"): + name = key[len(CREDENTIAL_TYPE) + 1 :] + type_ = CREDENTIAL_TYPE + + decrypted_value = None + if type_ == GENERIC_TYPE: + decrypted_value = value + + variable_base = VariableCreate( + name=name, + type=type_, + value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service), + default_fields=[], + ) + variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id}) + variable_read = VariableRead.model_validate(variable, from_attributes=True) + variable_read.value = decrypted_value + variables_read.append(variable_read) + + return variables_read diff --git a/src/backend/base/langflow/services/variable/service.py b/src/backend/base/langflow/services/variable/service.py index 1a4f6000c..f7b7e92ec 100644 --- a/src/backend/base/langflow/services/variable/service.py +++ b/src/backend/base/langflow/services/variable/service.py @@ -9,7 +9,7 @@ from sqlmodel import select from langflow.services.auth import utils as auth_utils from langflow.services.base import Service -from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate +from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableRead, VariableUpdate from langflow.services.variable.base import VariableService from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE @@ -79,9 +79,20 @@ class DatabaseVariableService(VariableService, Service): # we decrypt the value return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service) - async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]: + async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[VariableRead]: stmt = select(Variable).where(Variable.user_id == user_id) - return list((await session.exec(stmt)).all()) + variables = list((await session.exec(stmt)).all()) + # If the variable is of type 'Generic' we decrypt the value + variables_read = [] + for variable in variables: + value = None + if variable.type == GENERIC_TYPE: + value = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service) + + variable_read = VariableRead.model_validate(variable, from_attributes=True) + variable_read.value = value + variables_read.append(variable_read) + return variables_read async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]: variables = await self.get_all(user_id=user_id, session=session) @@ -160,7 +171,7 @@ class DatabaseVariableService(VariableService, Service): value: str, *, default_fields: Sequence[str] = (), - type_: str = GENERIC_TYPE, + type_: str = CREDENTIAL_TYPE, session: AsyncSession, ): variable_base = VariableCreate( diff --git a/src/backend/tests/unit/api/v1/test_variable.py b/src/backend/tests/unit/api/v1/test_variable.py index 6c1fe08db..8d8f81fde 100644 --- a/src/backend/tests/unit/api/v1/test_variable.py +++ b/src/backend/tests/unit/api/v1/test_variable.py @@ -4,36 +4,47 @@ from uuid import uuid4 import pytest from fastapi import HTTPException, status from httpx import AsyncClient +from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE @pytest.fixture -def body(): +def generic_variable(): return { - "name": "test_variable", - "value": "test_value", - "type": "test_type", + "name": "test_generic_variable", + "value": "test_generic_value", + "type": GENERIC_TYPE, + "default_fields": ["test_field"], + } + + +@pytest.fixture +def credential_variable(): + return { + "name": "test_credential_variable", + "value": "test_credential_value", + "type": CREDENTIAL_TYPE, "default_fields": ["test_field"], } @pytest.mark.usefixtures("active_user") -async def test_create_variable(client: AsyncClient, body, logged_in_headers): - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) +async def test_create_variable(client: AsyncClient, generic_variable, logged_in_headers): + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_201_CREATED - assert body["name"] == result["name"] - assert body["type"] == result["type"] - assert body["default_fields"] == result["default_fields"] + assert generic_variable["name"] == result["name"] + assert generic_variable["type"] == result["type"] + assert generic_variable["default_fields"] == result["default_fields"] assert "id" in result - assert body["value"] != result["value"] + assert generic_variable["value"] != result["value"] # Value should be encrypted @pytest.mark.usefixtures("active_user") -async def test_create_variable__variable_name_already_exists(client: AsyncClient, body, logged_in_headers): - await client.post("api/v1/variables/", json=body, headers=logged_in_headers) +async def test_create_variable__variable_name_already_exists(client: AsyncClient, generic_variable, logged_in_headers): + await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -41,11 +52,13 @@ async def test_create_variable__variable_name_already_exists(client: AsyncClient @pytest.mark.usefixtures("active_user") -async def test_create_variable__variable_name_and_value_cannot_be_empty(client: AsyncClient, body, logged_in_headers): - body["name"] = "" - body["value"] = "" +async def test_create_variable__variable_name_and_value_cannot_be_empty( + client: AsyncClient, generic_variable, logged_in_headers +): + generic_variable["name"] = "" + generic_variable["value"] = "" - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -53,10 +66,10 @@ async def test_create_variable__variable_name_and_value_cannot_be_empty(client: @pytest.mark.usefixtures("active_user") -async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClient, body, logged_in_headers): - body["name"] = "" +async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClient, generic_variable, logged_in_headers): + generic_variable["name"] = "" - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -64,10 +77,12 @@ async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClien @pytest.mark.usefixtures("active_user") -async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClient, body, logged_in_headers): - body["value"] = "" +async def test_create_variable__variable_value_cannot_be_empty( + client: AsyncClient, generic_variable, logged_in_headers +): + generic_variable["value"] = "" - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -75,13 +90,13 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie @pytest.mark.usefixtures("active_user") -async def test_create_variable__httpexception(client: AsyncClient, body, logged_in_headers): +async def test_create_variable__httpexception(client: AsyncClient, generic_variable, logged_in_headers): status_code = 418 generic_message = "I'm a teapot" with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m: m.side_effect = HTTPException(status_code=status_code, detail=generic_message) - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_418_IM_A_TEAPOT @@ -89,12 +104,12 @@ async def test_create_variable__httpexception(client: AsyncClient, body, logged_ @pytest.mark.usefixtures("active_user") -async def test_create_variable__exception(client: AsyncClient, body, logged_in_headers): +async def test_create_variable__exception(client: AsyncClient, generic_variable, logged_in_headers): generic_message = "Generic error message" with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m: m.side_effect = Exception(generic_message) - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -102,17 +117,33 @@ async def test_create_variable__exception(client: AsyncClient, body, logged_in_h @pytest.mark.usefixtures("active_user") -async def test_read_variables(client: AsyncClient, body, logged_in_headers): - names = ["test_variable1", "test_variable2", "test_variable3"] - for name in names: - body["name"] = name - await client.post("api/v1/variables/", json=body, headers=logged_in_headers) +async def test_read_variables(client: AsyncClient, generic_variable, credential_variable, logged_in_headers): + # Create a generic variable + create_response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) + assert create_response.status_code == status.HTTP_201_CREATED + + # Create a credential variable + create_response = await client.post("api/v1/variables/", json=credential_variable, headers=logged_in_headers) + assert create_response.status_code == status.HTTP_201_CREATED response = await client.get("api/v1/variables/", headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_200_OK - assert all(name in [r["name"] for r in result] for name in names) + + # Check both variables exist + assert generic_variable["name"] in [r["name"] for r in result] + assert credential_variable["name"] in [r["name"] for r in result] + + # Assert that credentials are not decrypted and generic are decrypted + credential_vars = [r for r in result if r["type"] == CREDENTIAL_TYPE] + generic_vars = [r for r in result if r["type"] == GENERIC_TYPE] + + # Credential variables should remain encrypted (value should be different) + assert all(c["value"] != credential_variable["value"] for c in credential_vars) + + # Generic variables should be decrypted (value should match original) + assert all(g["value"] == generic_variable["value"] for g in generic_vars) @pytest.mark.usefixtures("active_user") @@ -140,16 +171,18 @@ async def test_read_variables__(client: AsyncClient, logged_in_headers): @pytest.mark.usefixtures("active_user") -async def test_update_variable(client: AsyncClient, body, logged_in_headers): - saved = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) +async def test_update_variable(client: AsyncClient, generic_variable, logged_in_headers): + saved = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) saved = saved.json() - body["id"] = saved.get("id") - body["name"] = "new_name" - body["value"] = "new_value" - body["type"] = "new_type" - body["default_fields"] = ["new_field"] + generic_variable["id"] = saved.get("id") + generic_variable["name"] = "new_name" + generic_variable["value"] = "new_value" + generic_variable["type"] = GENERIC_TYPE # Ensure we keep it as GENERIC_TYPE + generic_variable["default_fields"] = ["new_field"] - response = await client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers) + response = await client.patch( + f"api/v1/variables/{saved.get('id')}", json=generic_variable, headers=logged_in_headers + ) result = response.json() assert response.status_code == status.HTTP_200_OK @@ -159,11 +192,11 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers): @pytest.mark.usefixtures("active_user") -async def test_update_variable__exception(client: AsyncClient, body, logged_in_headers): +async def test_update_variable__exception(client: AsyncClient, generic_variable, logged_in_headers): wrong_id = uuid4() - body["id"] = str(wrong_id) + generic_variable["id"] = str(wrong_id) - response = await client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers) + response = await client.patch(f"api/v1/variables/{wrong_id}", json=generic_variable, headers=logged_in_headers) result = response.json() assert response.status_code == status.HTTP_404_NOT_FOUND @@ -171,8 +204,8 @@ async def test_update_variable__exception(client: AsyncClient, body, logged_in_h @pytest.mark.usefixtures("active_user") -async def test_delete_variable(client: AsyncClient, body, logged_in_headers): - response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) +async def test_delete_variable(client: AsyncClient, generic_variable, logged_in_headers): + response = await client.post("api/v1/variables/", json=generic_variable, headers=logged_in_headers) saved = response.json() response = await client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers) diff --git a/src/backend/tests/unit/services/variable/test_service.py b/src/backend/tests/unit/services/variable/test_service.py index 1f7b1d86e..fd7a66ba1 100644 --- a/src/backend/tests/unit/services/variable/test_service.py +++ b/src/backend/tests/unit/services/variable/test_service.py @@ -6,7 +6,7 @@ import pytest from langflow.services.database.models.variable.model import VariableUpdate from langflow.services.deps import get_settings_service from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT -from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE +from langflow.services.variable.constants import CREDENTIAL_TYPE from langflow.services.variable.service import DatabaseVariableService from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel @@ -137,7 +137,7 @@ async def test_update_variable(service, session: AsyncSession): assert result.value != old_value assert result.value != new_value assert result.default_fields == [] - assert result.type == GENERIC_TYPE + assert result.type == CREDENTIAL_TYPE assert isinstance(result.created_at, datetime) assert isinstance(result.updated_at, datetime) @@ -237,6 +237,6 @@ async def test_create_variable(service, session: AsyncSession): assert result.name == name assert result.value != value assert result.default_fields == [] - assert result.type == GENERIC_TYPE + assert result.type == CREDENTIAL_TYPE assert isinstance(result.created_at, datetime) assert isinstance(result.updated_at, datetime)