feat: centralize global variable management (#3284)
* test: add tests for global variable endpoints * test: add unit tests variable service * fix: anticipate checks to prevent the code from breaking * feat: add a new method to interface * feat: add method to update fields in variable service * feat: replace variable api code * fix: mypy error * fix: mypy error * feat(variable): Allow deleting variables by name or ID in DatabaseVariableService. * refactor(api): Simplify delete method in variable router. --------- Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
665842680e
commit
952ba5eef1
6 changed files with 499 additions and 67 deletions
|
|
@ -1,14 +1,15 @@
|
|||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.models.variable import Variable, VariableCreate, VariableRead, VariableUpdate
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
|
||||
from langflow.services.deps import get_session, get_settings_service, get_variable_service
|
||||
from langflow.services.variable.base import VariableService
|
||||
from langflow.services.variable.service import GENERIC_TYPE, DatabaseVariableService
|
||||
|
||||
router = APIRouter(prefix="/variables", tags=["Variables"])
|
||||
|
||||
|
|
@ -20,36 +21,30 @@ def create_variable(
|
|||
variable: VariableCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
settings_service=Depends(get_settings_service),
|
||||
variable_service: DatabaseVariableService = Depends(get_variable_service),
|
||||
):
|
||||
"""Create a new variable."""
|
||||
try:
|
||||
# check if variable name already exists
|
||||
variable_exists = session.exec(
|
||||
select(Variable).where(
|
||||
Variable.name == variable.name,
|
||||
Variable.user_id == current_user.id,
|
||||
)
|
||||
).first()
|
||||
if variable_exists:
|
||||
if not variable.name and not variable.value:
|
||||
raise HTTPException(status_code=400, detail="Variable name and value cannot be empty")
|
||||
|
||||
if not variable.name:
|
||||
raise HTTPException(status_code=400, detail="Variable name cannot be empty")
|
||||
|
||||
if not variable.value:
|
||||
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
|
||||
|
||||
if variable.name in variable_service.list_variables(user_id=current_user.id, session=session):
|
||||
raise HTTPException(status_code=400, detail="Variable name already exists")
|
||||
|
||||
variable_dict = variable.model_dump()
|
||||
variable_dict["user_id"] = current_user.id
|
||||
|
||||
db_variable = Variable.model_validate(variable_dict)
|
||||
if not db_variable.name and not db_variable.value:
|
||||
raise HTTPException(status_code=400, detail="Variable name and value cannot be empty")
|
||||
elif not db_variable.name:
|
||||
raise HTTPException(status_code=400, detail="Variable name cannot be empty")
|
||||
elif not db_variable.value:
|
||||
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
|
||||
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=settings_service)
|
||||
db_variable.value = encrypted
|
||||
db_variable.user_id = current_user.id
|
||||
session.add(db_variable)
|
||||
session.commit()
|
||||
session.refresh(db_variable)
|
||||
return db_variable
|
||||
return variable_service.create_variable(
|
||||
user_id=current_user.id,
|
||||
name=variable.name,
|
||||
value=variable.value,
|
||||
default_fields=variable.default_fields or [],
|
||||
_type=variable.type or GENERIC_TYPE,
|
||||
session=session,
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
|
|
@ -61,11 +56,12 @@ def read_variables(
|
|||
*,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
variable_service: DatabaseVariableService = Depends(get_variable_service),
|
||||
):
|
||||
"""Read all variables."""
|
||||
try:
|
||||
variables = session.exec(select(Variable).where(Variable.user_id == current_user.id)).all()
|
||||
return variables
|
||||
return variable_service.get_all(user_id=current_user.id, session=session)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
|
@ -77,22 +73,19 @@ def update_variable(
|
|||
variable_id: UUID,
|
||||
variable: VariableUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
variable_service: DatabaseVariableService = Depends(get_variable_service),
|
||||
):
|
||||
"""Update a variable."""
|
||||
try:
|
||||
db_variable = session.exec(
|
||||
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
|
||||
).first()
|
||||
if not db_variable:
|
||||
raise HTTPException(status_code=404, detail="Variable not found")
|
||||
return variable_service.update_variable_fields(
|
||||
user_id=current_user.id,
|
||||
variable_id=variable_id,
|
||||
variable=variable,
|
||||
session=session,
|
||||
)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Variable not found")
|
||||
|
||||
variable_data = variable.model_dump(exclude_unset=True)
|
||||
for key, value in variable_data.items():
|
||||
setattr(db_variable, key, value)
|
||||
db_variable.updated_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
session.refresh(db_variable)
|
||||
return db_variable
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
|
@ -103,15 +96,10 @@ def delete_variable(
|
|||
session: Session = Depends(get_session),
|
||||
variable_id: UUID,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
variable_service: VariableService = Depends(get_variable_service),
|
||||
):
|
||||
"""Delete a variable."""
|
||||
try:
|
||||
db_variable = session.exec(
|
||||
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
|
||||
).first()
|
||||
if not db_variable:
|
||||
raise HTTPException(status_code=404, detail="Variable not found")
|
||||
session.delete(db_variable)
|
||||
session.commit()
|
||||
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_variable(self, user_id: Union[UUID, str], name: str, session: Session) -> Variable:
|
||||
def delete_variable(self, user_id: Union[UUID, str], name: str, session: Session) -> None:
|
||||
"""
|
||||
Delete a variable.
|
||||
|
||||
|
|
@ -82,6 +82,17 @@ class VariableService(Service):
|
|||
The deleted variable.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID, session: Session) -> None:
|
||||
"""
|
||||
Delete a variable by ID.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
variable_id: The ID of the variable.
|
||||
session: The database session.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_variable(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import os
|
|||
from typing import Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate
|
||||
|
|
@ -9,8 +12,6 @@ from langflow.services.settings.service import SettingsService
|
|||
from langflow.services.variable.base import VariableService
|
||||
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id
|
||||
from langflow.services.variable.service import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
class KubernetesSecretService(VariableService, Service):
|
||||
|
|
@ -110,17 +111,16 @@ class KubernetesSecretService(VariableService, Service):
|
|||
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
|
||||
return self.kubernetes_secrets.update_secret(name=secret_name, data={secret_key: value})
|
||||
|
||||
def delete_variable(
|
||||
self,
|
||||
user_id: Union[UUID, str],
|
||||
name: str,
|
||||
_session: Session,
|
||||
):
|
||||
def delete_variable(self, user_id: Union[UUID, str], name: str, _session: Session) -> None:
|
||||
secret_name = encode_user_id(user_id)
|
||||
|
||||
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
|
||||
self.kubernetes_secrets.delete_secret_key(name=secret_name, key=secret_key)
|
||||
return
|
||||
|
||||
def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID | str, _session: Session) -> None:
|
||||
self.delete_variable(user_id, _session, str(variable_id))
|
||||
|
||||
def create_variable(
|
||||
self,
|
||||
user_id: Union[UUID, str],
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -8,7 +9,7 @@ from sqlmodel import Session, select
|
|||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate
|
||||
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
|
||||
from langflow.services.deps import get_session
|
||||
from langflow.services.variable.base import VariableService
|
||||
|
||||
|
|
@ -76,6 +77,9 @@ class DatabaseVariableService(VariableService, Service):
|
|||
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
|
||||
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
|
||||
|
||||
if not variable or not variable.value:
|
||||
raise ValueError(f"{name} variable not found.")
|
||||
|
||||
if variable.type == CREDENTIAL_TYPE and field == "session_id": # type: ignore
|
||||
raise TypeError(
|
||||
f"variable {name} of type 'Credential' cannot be used in a Session ID field "
|
||||
|
|
@ -83,14 +87,15 @@ class DatabaseVariableService(VariableService, Service):
|
|||
)
|
||||
|
||||
# we decrypt the value
|
||||
if not variable or not variable.value:
|
||||
raise ValueError(f"{name} variable not found.")
|
||||
decrypted = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
|
||||
return decrypted
|
||||
|
||||
def get_all(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[Variable]]:
|
||||
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())
|
||||
|
||||
def list_variables(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[str]]:
|
||||
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
|
||||
return [variable.name for variable in variables]
|
||||
variables = self.get_all(user_id=user_id, session=session)
|
||||
return [variable.name for variable in variables if variable]
|
||||
|
||||
def update_variable(
|
||||
self,
|
||||
|
|
@ -109,18 +114,47 @@ class DatabaseVariableService(VariableService, Service):
|
|||
session.refresh(variable)
|
||||
return variable
|
||||
|
||||
def update_variable_fields(
|
||||
self,
|
||||
user_id: Union[UUID, str],
|
||||
variable_id: Union[UUID, str],
|
||||
variable: VariableUpdate,
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
|
||||
db_variable = session.exec(query).one()
|
||||
|
||||
variable_data = variable.model_dump(exclude_unset=True)
|
||||
for key, value in variable_data.items():
|
||||
setattr(db_variable, key, value)
|
||||
db_variable.updated_at = datetime.now(timezone.utc)
|
||||
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=self.settings_service)
|
||||
variable.value = encrypted
|
||||
|
||||
session.add(db_variable)
|
||||
session.commit()
|
||||
session.refresh(db_variable)
|
||||
return db_variable
|
||||
|
||||
def delete_variable(
|
||||
self,
|
||||
user_id: Union[UUID, str],
|
||||
name: str,
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
|
||||
stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name)
|
||||
variable = session.exec(stmt).first()
|
||||
if not variable:
|
||||
raise ValueError(f"{name} variable not found.")
|
||||
session.delete(variable)
|
||||
session.commit()
|
||||
return variable
|
||||
|
||||
def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID, session: Session):
|
||||
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)).first()
|
||||
if not variable:
|
||||
raise ValueError(f"{variable_id} variable not found.")
|
||||
session.delete(variable)
|
||||
session.commit()
|
||||
|
||||
def create_variable(
|
||||
self,
|
||||
|
|
|
|||
180
src/backend/tests/unit/api/v1/test_variable.py
Normal file
180
src/backend/tests/unit/api/v1/test_variable.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
import pytest
|
||||
from uuid import uuid4
|
||||
from unittest import mock
|
||||
|
||||
from fastapi import status, HTTPException
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def body():
|
||||
return {
|
||||
"name": "test_variable",
|
||||
"value": "test_value",
|
||||
"type": "test_type",
|
||||
"default_fields": ["test_field"],
|
||||
}
|
||||
|
||||
|
||||
def test_create_variable(client, body, active_user, logged_in_headers):
|
||||
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_201_CREATED == response.status_code
|
||||
assert body["name"] == result["name"]
|
||||
assert body["type"] == result["type"]
|
||||
assert body["default_fields"] == result["default_fields"]
|
||||
assert "id" in result.keys()
|
||||
assert "value" not in result.keys()
|
||||
|
||||
|
||||
def test_create_variable__variable_name_alread_exists(client, body, active_user, logged_in_headers):
|
||||
client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
|
||||
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert "Variable name already exists" in result["detail"]
|
||||
|
||||
|
||||
def test_create_variable__variable_name_and_value_cannot_be_empty(client, body, active_user, logged_in_headers):
|
||||
body["name"] = ""
|
||||
body["value"] = ""
|
||||
|
||||
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert "Variable name and value cannot be empty" in result["detail"]
|
||||
|
||||
|
||||
def test_create_variable__variable_name_cannot_be_empty(client, body, active_user, logged_in_headers):
|
||||
body["name"] = ""
|
||||
|
||||
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert "Variable name cannot be empty" in result["detail"]
|
||||
|
||||
|
||||
def test_create_variable__variable_value_cannot_be_empty(client, body, active_user, logged_in_headers):
|
||||
body["value"] = ""
|
||||
|
||||
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert "Variable value cannot be empty" in result["detail"]
|
||||
|
||||
|
||||
def test_create_variable__HTTPException(client, body, active_user, logged_in_headers):
|
||||
status_code = 418
|
||||
generic_message = "I'm a teapot"
|
||||
|
||||
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
|
||||
m.side_effect = HTTPException(status_code=status_code, detail=generic_message)
|
||||
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_418_IM_A_TEAPOT == response.status_code
|
||||
assert generic_message in result["detail"]
|
||||
|
||||
|
||||
def test_create_variable__Exception(client, body, active_user, logged_in_headers):
|
||||
generic_message = "Generic error message"
|
||||
|
||||
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
|
||||
m.side_effect = Exception(generic_message)
|
||||
response = client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
|
||||
assert generic_message in result["detail"]
|
||||
|
||||
|
||||
def test_read_variables(client, body, active_user, logged_in_headers):
|
||||
names = ["test_variable1", "test_variable2", "test_variable3"]
|
||||
for name in names:
|
||||
body["name"] = name
|
||||
client.post("api/v1/variables", json=body, headers=logged_in_headers)
|
||||
|
||||
response = client.get("api/v1/variables", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_200_OK == response.status_code
|
||||
assert all(name in [r["name"] for r in result] for name in names)
|
||||
|
||||
|
||||
def test_read_variables__empty(client, active_user, logged_in_headers):
|
||||
all_variables = client.get("api/v1/variables", headers=logged_in_headers).json()
|
||||
for variable in all_variables:
|
||||
client.delete(f"api/v1/variables/{variable.get('id')}", headers=logged_in_headers)
|
||||
|
||||
response = client.get("api/v1/variables", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_200_OK == response.status_code
|
||||
assert [] == result
|
||||
|
||||
|
||||
def test_read_variables__(client, active_user, logged_in_headers): # TODO check if this is correct
|
||||
generic_message = "Generic error message"
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
with mock.patch("sqlmodel.Session.exec") as m:
|
||||
m.side_effect = Exception(generic_message)
|
||||
|
||||
response = client.get("api/v1/variables", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
|
||||
assert generic_message in result["detail"]
|
||||
|
||||
assert generic_message in str(exc.value)
|
||||
|
||||
|
||||
def test_update_variable(client, body, active_user, logged_in_headers):
|
||||
saved = client.post("api/v1/variables", json=body, headers=logged_in_headers).json()
|
||||
body["id"] = saved.get("id")
|
||||
body["name"] = "new_name"
|
||||
body["value"] = "new_value"
|
||||
body["type"] = "new_type"
|
||||
body["default_fields"] = ["new_field"]
|
||||
|
||||
response = client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_200_OK == response.status_code
|
||||
assert saved["id"] == result["id"]
|
||||
assert saved["name"] != result["name"]
|
||||
# assert saved["type"] != result["type"] # TODO check if this is correct
|
||||
assert saved["default_fields"] != result["default_fields"]
|
||||
|
||||
|
||||
def test_update_variable__Exception(client, body, active_user, logged_in_headers):
|
||||
wrong_id = uuid4()
|
||||
body["id"] = str(wrong_id)
|
||||
|
||||
response = client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
# assert status.HTTP_404_NOT_FOUND == response.status_code # TODO check if this is correct
|
||||
assert "Variable not found" in result["detail"]
|
||||
|
||||
|
||||
def test_delete_variable(client, body, active_user, logged_in_headers):
|
||||
saved = client.post("api/v1/variables", json=body, headers=logged_in_headers).json()
|
||||
|
||||
response = client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers)
|
||||
|
||||
assert status.HTTP_204_NO_CONTENT == response.status_code
|
||||
|
||||
|
||||
def test_delete_variable__Exception(client, active_user, logged_in_headers):
|
||||
wrong_id = uuid4()
|
||||
|
||||
response = client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers)
|
||||
|
||||
# assert status.HTTP_404_NOT_FOUND == response.status_code # TODO check if this is correct
|
||||
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
|
||||
219
src/backend/tests/unit/services/variable/test_service.py
Normal file
219
src/backend/tests/unit/services/variable/test_service.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
from langflow.services.database.models.variable.model import VariableUpdate
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.variable.service import GENERIC_TYPE, CREDENTIAL_TYPE, DatabaseVariableService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service():
|
||||
settings_service = get_settings_service()
|
||||
return DatabaseVariableService(settings_service)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def test_initialize_user_variables__donkey(service, session):
|
||||
user_id = uuid4()
|
||||
name = "OPENAI_API_KEY"
|
||||
value = "donkey"
|
||||
service.initialize_user_variables(user_id, session=session)
|
||||
result = service.create_variable(user_id, "OPENAI_API_KEY", "donkey", session=session)
|
||||
new_service = DatabaseVariableService(get_settings_service())
|
||||
new_service.initialize_user_variables(user_id, session=session)
|
||||
|
||||
result = new_service.get_variable(user_id, name, "", session=session)
|
||||
|
||||
assert result != value
|
||||
|
||||
|
||||
def test_initialize_user_variables__not_found_variable(service, session):
|
||||
with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m:
|
||||
m.side_effect = Exception()
|
||||
service.initialize_user_variables(uuid4(), session=session)
|
||||
|
||||
assert True
|
||||
|
||||
|
||||
def test_initialize_user_variables__skipping_environment_variable_storage(service, session):
|
||||
service.settings_service.settings.store_environment_variables = False
|
||||
service.initialize_user_variables(uuid4(), session=session)
|
||||
assert True
|
||||
|
||||
|
||||
def test_get_variable(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = ""
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
result = service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
assert result == value
|
||||
|
||||
|
||||
def test_get_variable__ValueError(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
field = ""
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
service.get_variable(user_id, name, field, session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_get_variable__TypeError(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = "session_id"
|
||||
_type = CREDENTIAL_TYPE
|
||||
service.create_variable(user_id, name, value, _type=_type, session=session)
|
||||
|
||||
with pytest.raises(TypeError) as exc:
|
||||
service.get_variable(user_id, name, field, session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "purpose is to prevent the exposure of value" in str(exc.value)
|
||||
|
||||
|
||||
def test_list_variables(service, session):
|
||||
user_id = uuid4()
|
||||
names = ["name1", "name2", "name3"]
|
||||
value = "value"
|
||||
for name in names:
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
result = service.list_variables(user_id, session=session)
|
||||
|
||||
assert all(name in result for name in names)
|
||||
|
||||
|
||||
def test_list_variables__empty(service, session):
|
||||
result = service.list_variables(uuid4(), session=session)
|
||||
|
||||
assert not result
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
def test_update_variable(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
old_value = "old_value"
|
||||
new_value = "new_value"
|
||||
field = ""
|
||||
service.create_variable(user_id, name, old_value, session=session)
|
||||
|
||||
old_recovered = service.get_variable(user_id, name, field, session=session)
|
||||
result = service.update_variable(user_id, name, new_value, session=session)
|
||||
new_recovered = service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
assert old_value == old_recovered
|
||||
assert new_value == new_recovered
|
||||
assert result.user_id == user_id
|
||||
assert result.name == name
|
||||
assert result.value != old_value
|
||||
assert result.value != new_value
|
||||
assert result.default_fields == []
|
||||
assert result.type == GENERIC_TYPE
|
||||
assert isinstance(result.created_at, datetime)
|
||||
assert isinstance(result.updated_at, datetime)
|
||||
|
||||
|
||||
def test_update_variable__ValueError(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
service.update_variable(user_id, name, value, session=session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_update_variable_fields(service, session):
|
||||
user_id = uuid4()
|
||||
variable = service.create_variable(user_id, "old_name", "old_value", session=session)
|
||||
saved = variable.model_dump()
|
||||
variable = VariableUpdate(**saved)
|
||||
variable.name = "new_name"
|
||||
variable.value = "new_value"
|
||||
variable.default_fields = ["new_field"]
|
||||
|
||||
result = service.update_variable_fields(
|
||||
user_id=user_id,
|
||||
variable_id=saved.get("id"),
|
||||
variable=variable,
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert saved.get("id") == result.id
|
||||
assert saved.get("user_id") == result.user_id
|
||||
assert saved.get("name") != result.name
|
||||
assert saved.get("value") != result.value
|
||||
assert saved.get("default_fields") != result.default_fields
|
||||
assert saved.get("type") == result.type
|
||||
assert saved.get("created_at") == result.created_at
|
||||
assert saved.get("updated_at") != result.updated_at
|
||||
|
||||
|
||||
def test_delete_variable(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = ""
|
||||
|
||||
saved = service.create_variable(user_id, name, value, session=session)
|
||||
recovered = service.get_variable(user_id, name, field, session=session)
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
service.get_variable(user_id, name, field, session)
|
||||
|
||||
assert recovered == value
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_delete_variable__ValueError(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_create_variable(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
||||
result = service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
assert result.user_id == user_id
|
||||
assert result.name == name
|
||||
assert result.value != value
|
||||
assert result.default_fields == []
|
||||
assert result.type == GENERIC_TYPE
|
||||
assert isinstance(result.created_at, datetime)
|
||||
assert isinstance(result.updated_at, datetime)
|
||||
Loading…
Add table
Add a link
Reference in a new issue