feat: centralize global variable management (#3284)

* test: add tests for global variable endpoints

* test: add unit tests variable service

* fix: anticipate checks to prevent the code from breaking

* feat: add a new method to interface

* feat: add method to update fields in variable service

* feat: replace variable api code

* fix: mypy error

* fix: mypy error

* feat(variable): Allow deleting variables by name or ID in DatabaseVariableService.

* refactor(api): Simplify delete method in variable router.

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Ítalo Johnny 2024-08-13 10:32:57 -03:00 committed by GitHub
commit 952ba5eef1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 499 additions and 67 deletions

View file

@ -1,14 +1,15 @@
from datetime import datetime, timezone
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session
from langflow.services.auth import utils as auth_utils
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.user.model import User
from langflow.services.database.models.variable import Variable, VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service, get_variable_service
from langflow.services.variable.base import VariableService
from langflow.services.variable.service import GENERIC_TYPE, DatabaseVariableService
router = APIRouter(prefix="/variables", tags=["Variables"])
@ -20,36 +21,30 @@ def create_variable(
variable: VariableCreate,
current_user: User = Depends(get_current_active_user),
settings_service=Depends(get_settings_service),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Create a new variable."""
try:
# check if variable name already exists
variable_exists = session.exec(
select(Variable).where(
Variable.name == variable.name,
Variable.user_id == current_user.id,
)
).first()
if variable_exists:
if not variable.name and not variable.value:
raise HTTPException(status_code=400, detail="Variable name and value cannot be empty")
if not variable.name:
raise HTTPException(status_code=400, detail="Variable name cannot be empty")
if not variable.value:
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
if variable.name in variable_service.list_variables(user_id=current_user.id, session=session):
raise HTTPException(status_code=400, detail="Variable name already exists")
variable_dict = variable.model_dump()
variable_dict["user_id"] = current_user.id
db_variable = Variable.model_validate(variable_dict)
if not db_variable.name and not db_variable.value:
raise HTTPException(status_code=400, detail="Variable name and value cannot be empty")
elif not db_variable.name:
raise HTTPException(status_code=400, detail="Variable name cannot be empty")
elif not db_variable.value:
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=settings_service)
db_variable.value = encrypted
db_variable.user_id = current_user.id
session.add(db_variable)
session.commit()
session.refresh(db_variable)
return db_variable
return variable_service.create_variable(
user_id=current_user.id,
name=variable.name,
value=variable.value,
default_fields=variable.default_fields or [],
_type=variable.type or GENERIC_TYPE,
session=session,
)
except Exception as e:
if isinstance(e, HTTPException):
raise e
@ -61,11 +56,12 @@ def read_variables(
*,
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Read all variables."""
try:
variables = session.exec(select(Variable).where(Variable.user_id == current_user.id)).all()
return variables
return variable_service.get_all(user_id=current_user.id, session=session)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@ -77,22 +73,19 @@ def update_variable(
variable_id: UUID,
variable: VariableUpdate,
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Update a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")
return variable_service.update_variable_fields(
user_id=current_user.id,
variable_id=variable_id,
variable=variable,
session=session,
)
except NoResultFound:
raise HTTPException(status_code=404, detail="Variable not found")
variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.now(timezone.utc)
session.commit()
session.refresh(db_variable)
return db_variable
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@ -103,15 +96,10 @@ def delete_variable(
session: Session = Depends(get_session),
variable_id: UUID,
current_user: User = Depends(get_current_active_user),
variable_service: VariableService = Depends(get_variable_service),
):
"""Delete a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")
session.delete(db_variable)
session.commit()
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

View file

@ -69,7 +69,7 @@ class VariableService(Service):
"""
@abc.abstractmethod
def delete_variable(self, user_id: Union[UUID, str], name: str, session: Session) -> Variable:
def delete_variable(self, user_id: Union[UUID, str], name: str, session: Session) -> None:
"""
Delete a variable.
@ -82,6 +82,17 @@ class VariableService(Service):
The deleted variable.
"""
@abc.abstractmethod
def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID, session: Session) -> None:
"""
Delete a variable by ID.
Args:
user_id: The user ID.
variable_id: The ID of the variable.
session: The database session.
"""
@abc.abstractmethod
def create_variable(
self,

View file

@ -2,6 +2,9 @@ import os
from typing import Optional, Tuple, Union
from uuid import UUID
from loguru import logger
from sqlmodel import Session
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
@ -9,8 +12,6 @@ from langflow.services.settings.service import SettingsService
from langflow.services.variable.base import VariableService
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id
from langflow.services.variable.service import CREDENTIAL_TYPE, GENERIC_TYPE
from loguru import logger
from sqlmodel import Session
class KubernetesSecretService(VariableService, Service):
@ -110,17 +111,16 @@ class KubernetesSecretService(VariableService, Service):
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
return self.kubernetes_secrets.update_secret(name=secret_name, data={secret_key: value})
def delete_variable(
self,
user_id: Union[UUID, str],
name: str,
_session: Session,
):
def delete_variable(self, user_id: Union[UUID, str], name: str, _session: Session) -> None:
secret_name = encode_user_id(user_id)
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
self.kubernetes_secrets.delete_secret_key(name=secret_name, key=secret_key)
return
def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID | str, _session: Session) -> None:
self.delete_variable(user_id, _session, str(variable_id))
def create_variable(
self,
user_id: Union[UUID, str],

View file

@ -1,4 +1,5 @@
import os
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional, Union
from uuid import UUID
@ -8,7 +9,7 @@ from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
from langflow.services.deps import get_session
from langflow.services.variable.base import VariableService
@ -76,6 +77,9 @@ class DatabaseVariableService(VariableService, Service):
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
if not variable or not variable.value:
raise ValueError(f"{name} variable not found.")
if variable.type == CREDENTIAL_TYPE and field == "session_id": # type: ignore
raise TypeError(
f"variable {name} of type 'Credential' cannot be used in a Session ID field "
@ -83,14 +87,15 @@ class DatabaseVariableService(VariableService, Service):
)
# we decrypt the value
if not variable or not variable.value:
raise ValueError(f"{name} variable not found.")
decrypted = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
return decrypted
def get_all(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[Variable]]:
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())
def list_variables(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[str]]:
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
return [variable.name for variable in variables]
variables = self.get_all(user_id=user_id, session=session)
return [variable.name for variable in variables if variable]
def update_variable(
self,
@ -109,18 +114,47 @@ class DatabaseVariableService(VariableService, Service):
session.refresh(variable)
return variable
def update_variable_fields(
self,
user_id: Union[UUID, str],
variable_id: Union[UUID, str],
variable: VariableUpdate,
session: Session = Depends(get_session),
):
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
db_variable = session.exec(query).one()
variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.now(timezone.utc)
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=self.settings_service)
variable.value = encrypted
session.add(db_variable)
session.commit()
session.refresh(db_variable)
return db_variable
def delete_variable(
self,
user_id: Union[UUID, str],
name: str,
session: Session = Depends(get_session),
):
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name)
variable = session.exec(stmt).first()
if not variable:
raise ValueError(f"{name} variable not found.")
session.delete(variable)
session.commit()
return variable
def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID, session: Session):
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)).first()
if not variable:
raise ValueError(f"{variable_id} variable not found.")
session.delete(variable)
session.commit()
def create_variable(
self,

View file

@ -0,0 +1,180 @@
import pytest
from uuid import uuid4
from unittest import mock
from fastapi import status, HTTPException
@pytest.fixture
def body():
return {
"name": "test_variable",
"value": "test_value",
"type": "test_type",
"default_fields": ["test_field"],
}
def test_create_variable(client, body, active_user, logged_in_headers):
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_201_CREATED == response.status_code
assert body["name"] == result["name"]
assert body["type"] == result["type"]
assert body["default_fields"] == result["default_fields"]
assert "id" in result.keys()
assert "value" not in result.keys()
def test_create_variable__variable_name_alread_exists(client, body, active_user, logged_in_headers):
client.post("api/v1/variables", json=body, headers=logged_in_headers)
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert "Variable name already exists" in result["detail"]
def test_create_variable__variable_name_and_value_cannot_be_empty(client, body, active_user, logged_in_headers):
body["name"] = ""
body["value"] = ""
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert "Variable name and value cannot be empty" in result["detail"]
def test_create_variable__variable_name_cannot_be_empty(client, body, active_user, logged_in_headers):
body["name"] = ""
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert "Variable name cannot be empty" in result["detail"]
def test_create_variable__variable_value_cannot_be_empty(client, body, active_user, logged_in_headers):
body["value"] = ""
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert "Variable value cannot be empty" in result["detail"]
def test_create_variable__HTTPException(client, body, active_user, logged_in_headers):
status_code = 418
generic_message = "I'm a teapot"
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
m.side_effect = HTTPException(status_code=status_code, detail=generic_message)
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_418_IM_A_TEAPOT == response.status_code
assert generic_message in result["detail"]
def test_create_variable__Exception(client, body, active_user, logged_in_headers):
generic_message = "Generic error message"
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
m.side_effect = Exception(generic_message)
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
assert generic_message in result["detail"]
def test_read_variables(client, body, active_user, logged_in_headers):
names = ["test_variable1", "test_variable2", "test_variable3"]
for name in names:
body["name"] = name
client.post("api/v1/variables", json=body, headers=logged_in_headers)
response = client.get("api/v1/variables", headers=logged_in_headers)
result = response.json()
assert status.HTTP_200_OK == response.status_code
assert all(name in [r["name"] for r in result] for name in names)
def test_read_variables__empty(client, active_user, logged_in_headers):
all_variables = client.get("api/v1/variables", headers=logged_in_headers).json()
for variable in all_variables:
client.delete(f"api/v1/variables/{variable.get('id')}", headers=logged_in_headers)
response = client.get("api/v1/variables", headers=logged_in_headers)
result = response.json()
assert status.HTTP_200_OK == response.status_code
assert [] == result
def test_read_variables__(client, active_user, logged_in_headers): # TODO check if this is correct
generic_message = "Generic error message"
with pytest.raises(Exception) as exc:
with mock.patch("sqlmodel.Session.exec") as m:
m.side_effect = Exception(generic_message)
response = client.get("api/v1/variables", headers=logged_in_headers)
result = response.json()
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
assert generic_message in result["detail"]
assert generic_message in str(exc.value)
def test_update_variable(client, body, active_user, logged_in_headers):
saved = client.post("api/v1/variables", json=body, headers=logged_in_headers).json()
body["id"] = saved.get("id")
body["name"] = "new_name"
body["value"] = "new_value"
body["type"] = "new_type"
body["default_fields"] = ["new_field"]
response = client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_200_OK == response.status_code
assert saved["id"] == result["id"]
assert saved["name"] != result["name"]
# assert saved["type"] != result["type"] # TODO check if this is correct
assert saved["default_fields"] != result["default_fields"]
def test_update_variable__Exception(client, body, active_user, logged_in_headers):
wrong_id = uuid4()
body["id"] = str(wrong_id)
response = client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers)
result = response.json()
# assert status.HTTP_404_NOT_FOUND == response.status_code # TODO check if this is correct
assert "Variable not found" in result["detail"]
def test_delete_variable(client, body, active_user, logged_in_headers):
saved = client.post("api/v1/variables", json=body, headers=logged_in_headers).json()
response = client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers)
assert status.HTTP_204_NO_CONTENT == response.status_code
def test_delete_variable__Exception(client, active_user, logged_in_headers):
wrong_id = uuid4()
response = client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers)
# assert status.HTTP_404_NOT_FOUND == response.status_code # TODO check if this is correct
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code

View file

@ -0,0 +1,219 @@
from langflow.services.database.models.variable.model import VariableUpdate
import pytest
from unittest.mock import patch
from uuid import uuid4
from datetime import datetime
from sqlmodel import SQLModel, Session, create_engine
from langflow.services.deps import get_settings_service
from langflow.services.variable.service import GENERIC_TYPE, CREDENTIAL_TYPE, DatabaseVariableService
@pytest.fixture
def client():
pass
@pytest.fixture
def service():
settings_service = get_settings_service()
return DatabaseVariableService(settings_service)
@pytest.fixture
def session():
engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
def test_initialize_user_variables__donkey(service, session):
user_id = uuid4()
name = "OPENAI_API_KEY"
value = "donkey"
service.initialize_user_variables(user_id, session=session)
result = service.create_variable(user_id, "OPENAI_API_KEY", "donkey", session=session)
new_service = DatabaseVariableService(get_settings_service())
new_service.initialize_user_variables(user_id, session=session)
result = new_service.get_variable(user_id, name, "", session=session)
assert result != value
def test_initialize_user_variables__not_found_variable(service, session):
with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m:
m.side_effect = Exception()
service.initialize_user_variables(uuid4(), session=session)
assert True
def test_initialize_user_variables__skipping_environment_variable_storage(service, session):
service.settings_service.settings.store_environment_variables = False
service.initialize_user_variables(uuid4(), session=session)
assert True
def test_get_variable(service, session):
user_id = uuid4()
name = "name"
value = "value"
field = ""
service.create_variable(user_id, name, value, session=session)
result = service.get_variable(user_id, name, field, session=session)
assert result == value
def test_get_variable__ValueError(service, session):
user_id = uuid4()
name = "name"
field = ""
with pytest.raises(ValueError) as exc:
service.get_variable(user_id, name, field, session)
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_get_variable__TypeError(service, session):
user_id = uuid4()
name = "name"
value = "value"
field = "session_id"
_type = CREDENTIAL_TYPE
service.create_variable(user_id, name, value, _type=_type, session=session)
with pytest.raises(TypeError) as exc:
service.get_variable(user_id, name, field, session)
assert name in str(exc.value)
assert "purpose is to prevent the exposure of value" in str(exc.value)
def test_list_variables(service, session):
user_id = uuid4()
names = ["name1", "name2", "name3"]
value = "value"
for name in names:
service.create_variable(user_id, name, value, session=session)
result = service.list_variables(user_id, session=session)
assert all(name in result for name in names)
def test_list_variables__empty(service, session):
result = service.list_variables(uuid4(), session=session)
assert not result
assert isinstance(result, list)
def test_update_variable(service, session):
user_id = uuid4()
name = "name"
old_value = "old_value"
new_value = "new_value"
field = ""
service.create_variable(user_id, name, old_value, session=session)
old_recovered = service.get_variable(user_id, name, field, session=session)
result = service.update_variable(user_id, name, new_value, session=session)
new_recovered = service.get_variable(user_id, name, field, session=session)
assert old_value == old_recovered
assert new_value == new_recovered
assert result.user_id == user_id
assert result.name == name
assert result.value != old_value
assert result.value != new_value
assert result.default_fields == []
assert result.type == GENERIC_TYPE
assert isinstance(result.created_at, datetime)
assert isinstance(result.updated_at, datetime)
def test_update_variable__ValueError(service, session):
user_id = uuid4()
name = "name"
value = "value"
with pytest.raises(ValueError) as exc:
service.update_variable(user_id, name, value, session=session)
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_update_variable_fields(service, session):
user_id = uuid4()
variable = service.create_variable(user_id, "old_name", "old_value", session=session)
saved = variable.model_dump()
variable = VariableUpdate(**saved)
variable.name = "new_name"
variable.value = "new_value"
variable.default_fields = ["new_field"]
result = service.update_variable_fields(
user_id=user_id,
variable_id=saved.get("id"),
variable=variable,
session=session,
)
assert saved.get("id") == result.id
assert saved.get("user_id") == result.user_id
assert saved.get("name") != result.name
assert saved.get("value") != result.value
assert saved.get("default_fields") != result.default_fields
assert saved.get("type") == result.type
assert saved.get("created_at") == result.created_at
assert saved.get("updated_at") != result.updated_at
def test_delete_variable(service, session):
user_id = uuid4()
name = "name"
value = "value"
field = ""
saved = service.create_variable(user_id, name, value, session=session)
recovered = service.get_variable(user_id, name, field, session=session)
service.delete_variable(user_id, name, session=session)
with pytest.raises(ValueError) as exc:
service.get_variable(user_id, name, field, session)
assert recovered == value
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_delete_variable__ValueError(service, session):
user_id = uuid4()
name = "name"
with pytest.raises(ValueError) as exc:
service.delete_variable(user_id, name, session=session)
assert name in str(exc.value)
assert "variable not found" in str(exc.value)
def test_create_variable(service, session):
user_id = uuid4()
name = "name"
value = "value"
result = service.create_variable(user_id, name, value, session=session)
assert result.user_id == user_id
assert result.name == name
assert result.value != value
assert result.default_fields == []
assert result.type == GENERIC_TYPE
assert isinstance(result.created_at, datetime)
assert isinstance(result.updated_at, datetime)