fix: Add async aupdate_build_config to CustomComponent (#5181)

* Add async aupdate_build_config to CustomComponent

* Add test of backward compatibility
This commit is contained in:
Christophe Bornet 2024-12-11 17:25:15 +01:00 committed by GitHub
commit 4f5d7d93ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 141 additions and 151 deletions

View file

@ -581,7 +581,7 @@ async def custom_component(
built_frontend_node, component_instance = build_custom_component_template(component, user_id=user.id)
if raw_code.frontend_node is not None:
built_frontend_node = component_instance.post_code_processing(built_frontend_node, raw_code.frontend_node)
built_frontend_node = await component_instance.post_code_processing(built_frontend_node, raw_code.frontend_node)
type_ = get_instance_name(component_instance)
return CustomComponentResponse(data=built_frontend_node, type=type_)
@ -630,10 +630,10 @@ async def custom_component_update(
for field_name, field_dict in template.items()
if isinstance(field_dict, dict) and field_dict.get("load_from_db")
]
params = update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields)
params = await update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields)
cc_instance.set_attributes(params)
updated_build_config = code_request.get_template()
cc_instance.update_build_config(
await cc_instance.aupdate_build_config(
build_config=updated_build_config,
field_value=code_request.field_value,
field_name=code_request.field,

View file

@ -136,16 +136,18 @@ class AgentComponent(ToolCallingAgentComponent):
value.input_types = []
return build_config
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict:
async def aupdate_build_config(
self, build_config: dotdict, field_value: str, field_name: str | None = None
) -> dotdict:
# Iterate over all providers in the MODEL_PROVIDERS_DICT
# Existing logic for updating build_config
if field_name == "agent_llm":
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
if provider_info:
component_class = provider_info.get("component_class")
if component_class and hasattr(component_class, "update_build_config"):
# Call the component class's update_build_config method
build_config = component_class.update_build_config(build_config, field_value, field_name)
if component_class and hasattr(component_class, "aupdate_build_config"):
# Call the component class's aupdate_build_config method
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
provider_configs: dict[str, tuple[dict, list[dict]]] = {
provider: (
@ -211,11 +213,11 @@ class AgentComponent(ToolCallingAgentComponent):
if provider_info:
component_class = provider_info.get("component_class")
prefix = provider_info.get("prefix")
if component_class and hasattr(component_class, "update_build_config"):
# Call each component class's update_build_config method
if component_class and hasattr(component_class, "aupdate_build_config"):
# Call each component class's aupdate_build_config method
# remove the prefix from the field_name
if isinstance(field_name, str) and isinstance(prefix, str):
field_name = field_name.replace(prefix, "")
build_config = component_class.update_build_config(build_config, field_value, field_name)
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
return build_config

View file

@ -16,24 +16,25 @@ class LMStudioEmbeddingsComponent(LCEmbeddingsModel):
icon = "LMStudio"
@override
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "model":
base_url_dict = build_config.get("base_url", {})
base_url_load_from_db = base_url_dict.get("load_from_db", False)
base_url_value = base_url_dict.get("value")
if base_url_load_from_db:
base_url_value = self.variables(base_url_value)
base_url_value = await self.variables(base_url_value, field_name)
elif not base_url_value:
base_url_value = "http://localhost:1234/v1"
build_config["model"]["options"] = self.get_model(base_url_value)
build_config["model"]["options"] = await self.get_model(base_url_value)
return build_config
def get_model(self, base_url_value: str) -> list[str]:
@staticmethod
async def get_model(base_url_value: str) -> list[str]:
try:
url = urljoin(base_url_value, "/v1/models")
with httpx.Client() as client:
response = client.get(url)
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()

View file

@ -20,24 +20,25 @@ class LMStudioModelComponent(LCModelComponent):
name = "LMStudioModel"
@override
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "model_name":
base_url_dict = build_config.get("base_url", {})
base_url_load_from_db = base_url_dict.get("load_from_db", False)
base_url_value = base_url_dict.get("value")
if base_url_load_from_db:
base_url_value = self.variables(base_url_value)
base_url_value = await self.variables(base_url_value, field_name)
elif not base_url_value:
base_url_value = "http://localhost:1234/v1"
build_config["model_name"]["options"] = self.get_model(base_url_value)
build_config["model_name"]["options"] = await self.get_model(base_url_value)
return build_config
def get_model(self, base_url_value: str) -> list[str]:
@staticmethod
async def get_model(base_url_value: str) -> list[str]:
try:
url = urljoin(base_url_value, "/v1/models")
with httpx.Client() as client:
response = client.get(url)
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()

View file

@ -16,7 +16,7 @@ class ChatOllamaComponent(LCModelComponent):
icon = "Ollama"
name = "OllamaModel"
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "mirostat":
if field_value == "Disabled":
build_config["mirostat_eta"]["advanced"] = True
@ -40,10 +40,10 @@ class ChatOllamaComponent(LCModelComponent):
base_url_load_from_db = base_url_dict.get("load_from_db", False)
base_url_value = base_url_dict.get("value")
if base_url_load_from_db:
base_url_value = self.variables(base_url_value, field_name)
base_url_value = await self.variables(base_url_value, field_name)
elif not base_url_value:
base_url_value = "http://localhost:11434"
build_config["model_name"]["options"] = self.get_model(base_url_value)
build_config["model_name"]["options"] = await self.get_model(base_url_value)
if field_name == "keep_alive_flag":
if field_value == "Keep":
build_config["keep_alive"]["value"] = "-1"
@ -56,11 +56,12 @@ class ChatOllamaComponent(LCModelComponent):
return build_config
def get_model(self, base_url_value: str) -> list[str]:
@staticmethod
async def get_model(base_url_value: str) -> list[str]:
try:
url = urljoin(base_url_value, "/api/tags")
with httpx.Client() as client:
response = client.get(url)
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()

View file

@ -45,9 +45,9 @@ class PromptComponent(Component):
)
return frontend_node
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
"""This function is called after the code validation is done."""
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
template = frontend_node["template"]["template"]["value"]
# Kept it duplicated for backwards compatibility
_ = process_prompt_template(

View file

@ -98,7 +98,9 @@ class PythonCodeStructuredTool(LCToolComponent):
]
@override
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
async def aupdate_build_config(
self, build_config: dotdict, field_value: Any, field_name: str | None = None
) -> dotdict:
if field_name is None:
return build_config
@ -226,22 +228,22 @@ class PythonCodeStructuredTool(LCToolComponent):
return_direct=self.return_direct,
)
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
"""This function is called after the code validation is done."""
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
frontend_node["template"] = self.update_build_config(
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
frontend_node["template"] = await self.aupdate_build_config(
frontend_node["template"],
frontend_node["template"]["tool_code"]["value"],
"tool_code",
)
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
for key in frontend_node["template"]:
if key in self.DEFAULT_KEYS:
continue
frontend_node["template"] = self.update_build_config(
frontend_node["template"] = await self.aupdate_build_config(
frontend_node["template"], frontend_node["template"][key]["value"], key
)
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
return frontend_node
def _parse_code(self, code: str) -> tuple[list[dict], list[dict]]:

View file

@ -44,7 +44,6 @@ if TYPE_CHECKING:
from langflow.graph.edge.schema import EdgeData
from langflow.graph.vertex.base import Vertex
from langflow.inputs.inputs import InputTypes
from langflow.schema import dotdict
from langflow.schema.log import LoggableType
@ -391,14 +390,6 @@ class Component(CustomComponent):
self._validate_inputs(params)
self._validate_outputs()
def update_inputs(
self,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
return self.update_build_config(build_config, field_value, field_name)
def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
frontend_node = self.update_outputs(frontend_node, field_name, field_value)
if field_name == "tool_mode" or frontend_node.get("tool_mode"):

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar
@ -12,7 +13,7 @@ from pydantic import BaseModel
from langflow.custom.custom_component.base_component import BaseComponent
from langflow.helpers.flow import list_flows, load_flow, run_flow
from langflow.schema import Data
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service, session_scope
from langflow.services.storage.service import StorageService
from langflow.template.utils import update_frontend_node_with_template_values
from langflow.type_extraction.type_extraction import post_process_type
@ -230,6 +231,19 @@ class CustomComponent(BaseComponent):
field_value: Any,
field_name: str | None = None,
):
if type(self).aupdate_build_config != CustomComponent.aupdate_build_config:
raise NotImplementedError
build_config[field_name]["value"] = field_value
return build_config
async def aupdate_build_config(
self,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
if type(self).update_build_config != CustomComponent.update_build_config:
return await asyncio.to_thread(self.update_build_config, build_config, field_value, field_name)
build_config[field_name]["value"] = field_value
return build_config
@ -410,8 +424,7 @@ class CustomComponent(BaseComponent):
self._template_config = self.build_template_config()
return self._template_config
@property
def variables(self):
async def variables(self, name: str, field: str):
"""Returns the variable for the current user with the specified name.
Raises:
@ -420,18 +433,14 @@ class CustomComponent(BaseComponent):
Returns:
The variable for the current user with the specified name.
"""
def get_variable(name: str, field: str):
if hasattr(self, "_user_id") and not self.user_id:
msg = f"User id is not set for {self.__class__.__name__}"
raise ValueError(msg)
variable_service = get_variable_service() # Get service instance
# Retrieve and decrypt the variable by name for the current user
with session_scope() as session:
user_id = self.user_id or ""
return variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
return get_variable
if hasattr(self, "_user_id") and not self.user_id:
msg = f"User id is not set for {self.__class__.__name__}"
raise ValueError(msg)
variable_service = get_variable_service() # Get service instance
# Retrieve and decrypt the variable by name for the current user
async with async_session_scope() as session:
user_id = self.user_id or ""
return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
def list_key_names(self):
"""Lists the names of the variables for the current user.
@ -519,7 +528,7 @@ class CustomComponent(BaseComponent):
"""
raise NotImplementedError
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
"""This function is called after the code validation is done."""
return update_frontend_node_with_template_values(
frontend_node=new_frontend_node, raw_frontend_node=current_frontend_node

View file

@ -506,42 +506,6 @@ async def abuild_custom_components(components_paths: list[str]):
return custom_components_from_file
def update_field_dict(
custom_component_instance: "CustomComponent",
field_dict: dict,
build_config: dict,
*,
update_field: str | None = None,
update_field_value: Any | None = None,
call: bool = False,
):
"""Update the field dictionary by calling options() or value() if they are callable."""
if (
("real_time_refresh" in field_dict or "refresh_button" in field_dict)
and any(
(
field_dict.get("real_time_refresh", False),
field_dict.get("refresh_button", False),
)
)
and call
):
try:
dd_build_config = dotdict(build_config)
custom_component_instance.update_build_config(
build_config=dd_build_config,
field_value=update_field,
field_name=update_field_value,
)
build_config = dd_build_config
except Exception as exc:
msg = f"Error while running update_build_config: {exc}"
logger.exception(msg)
raise UpdateBuildConfigError(msg) from exc
return build_config
def sanitize_field_config(field_config: dict | Input):
# If any of the already existing keys are in field_config, remove them
field_dict = field_config.to_dict() if isinstance(field_config, Input) else field_config

View file

@ -57,7 +57,7 @@ async def get_instance_results(
fallback_to_env_vars: bool = False,
base_type: str = "component",
):
custom_params = update_params_with_load_from_db_fields(
custom_params = await update_params_with_load_from_db_fields(
custom_component, custom_params, vertex.load_from_db_fields, fallback_to_env_vars=fallback_to_env_vars
)
with warnings.catch_warnings():
@ -103,7 +103,7 @@ def convert_kwargs(params):
return params
def update_params_with_load_from_db_fields(
async def update_params_with_load_from_db_fields(
custom_component: CustomComponent,
params,
load_from_db_fields,
@ -115,7 +115,7 @@ def update_params_with_load_from_db_fields(
continue
try:
key = custom_component.variables(params[field], field)
key = await custom_component.variables(params[field], field)
except ValueError as e:
if any(reason in str(e) for reason in ["User id is not set", "variable not found."]):
raise

View file

@ -23,7 +23,7 @@ class VariableService(Service):
"""
@abc.abstractmethod
def get_variable(self, user_id: UUID | str, name: str, field: str, session: Session) -> str:
def get_variable_sync(self, user_id: UUID | str, name: str, field: str, session: Session) -> str:
"""Get a variable value.
Args:
@ -36,6 +36,20 @@ class VariableService(Service):
The value of the variable.
"""
async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str:
"""Async get a variable value.
Args:
user_id: The user ID.
name: The name of the variable.
field: The field of the variable.
session: The database session.
Returns:
The value of the variable.
"""
return await session.run_sync(lambda session_: self.get_variable_sync(user_id, name, field, session_))
@abc.abstractmethod
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
"""List all variables.
@ -48,7 +62,6 @@ class VariableService(Service):
A list of variable names.
"""
@abc.abstractmethod
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
"""List all variables.
@ -59,6 +72,7 @@ class VariableService(Service):
Returns:
A list of variable names.
"""
return await session.run_sync(lambda session_: self.list_variables_sync(user_id, session_))
@abc.abstractmethod
async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable:

View file

@ -79,7 +79,7 @@ class KubernetesSecretService(VariableService, Service):
raise ValueError(msg)
@override
def get_variable(
def get_variable_sync(
self,
user_id: UUID | str,
name: str,
@ -114,14 +114,6 @@ class KubernetesSecretService(VariableService, Service):
names.append(key)
return names
@override
async def list_variables(
self,
user_id: UUID | str,
session: AsyncSession,
) -> list[str | None]:
return await asyncio.to_thread(self.list_variables_sync, user_id, session.sync_session)
def _update_variable(
self,
user_id: UUID | str,

View file

@ -53,7 +53,7 @@ class DatabaseVariableService(VariableService, Service):
except Exception as e: # noqa: BLE001
logger.exception(f"Error processing {var_name} variable: {e!s}")
def get_variable(
def get_variable_sync(
self,
user_id: UUID | str,
name: str,

View file

@ -11,8 +11,8 @@ def component():
return ChatOllamaComponent()
@patch("httpx.Client.get")
def test_get_model_success(mock_get, component):
@patch("httpx.AsyncClient.get")
async def test_get_model_success(mock_get, component):
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
@ -20,7 +20,7 @@ def test_get_model_success(mock_get, component):
base_url = "http://localhost:11434"
model_names = component.get_model(base_url)
model_names = await component.get_model(base_url)
expected_url = urljoin(base_url, "/api/tags")
@ -29,8 +29,8 @@ def test_get_model_success(mock_get, component):
assert model_names == ["model1", "model2"]
@patch("httpx.Client.get")
def test_get_model_failure(mock_get, component):
@patch("httpx.AsyncClient.get")
async def test_get_model_failure(mock_get, component):
# Mock the response for the HTTP GET request to raise an exception
mock_get.side_effect = Exception("HTTP request failed")
@ -38,10 +38,10 @@ def test_get_model_failure(mock_get, component):
# Assert that the ValueError is raised when an exception occurs
with pytest.raises(ValueError, match="Could not retrieve models"):
component.get_model(url)
await component.get_model(url)
def test_update_build_config_mirostat_disabled(component):
async def test_update_build_config_mirostat_disabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": 0.1},
"mirostat_tau": {"advanced": False, "value": 5},
@ -49,7 +49,7 @@ def test_update_build_config_mirostat_disabled(component):
field_value = "Disabled"
field_name = "mirostat"
updated_config = component.update_build_config(build_config, field_value, field_name)
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is True
assert updated_config["mirostat_tau"]["advanced"] is True
@ -57,7 +57,7 @@ def test_update_build_config_mirostat_disabled(component):
assert updated_config["mirostat_tau"]["value"] is None
def test_update_build_config_mirostat_enabled(component):
async def test_update_build_config_mirostat_enabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": None},
"mirostat_tau": {"advanced": False, "value": None},
@ -65,7 +65,7 @@ def test_update_build_config_mirostat_enabled(component):
field_value = "Mirostat 2.0"
field_name = "mirostat"
updated_config = component.update_build_config(build_config, field_value, field_name)
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is False
assert updated_config["mirostat_tau"]["advanced"] is False
@ -73,8 +73,8 @@ def test_update_build_config_mirostat_enabled(component):
assert updated_config["mirostat_tau"]["value"] == 10
@patch("httpx.Client.get")
def test_update_build_config_model_name(mock_get, component):
@patch("httpx.AsyncClient.get")
async def test_update_build_config_model_name(mock_get, component):
# Mock the response for the HTTP GET request
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
@ -88,22 +88,22 @@ def test_update_build_config_model_name(mock_get, component):
field_value = None
field_name = "model_name"
updated_config = component.update_build_config(build_config, field_value, field_name)
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
assert updated_config["model_name"]["options"] == ["model1", "model2"]
def test_update_build_config_keep_alive(component):
async def test_update_build_config_keep_alive(component):
build_config = {"keep_alive": {"value": None, "advanced": False}}
field_value = "Keep"
field_name = "keep_alive_flag"
updated_config = component.update_build_config(build_config, field_value, field_name)
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "-1"
assert updated_config["keep_alive"]["advanced"] is True
field_value = "Immediately"
updated_config = component.update_build_config(build_config, field_value, field_name)
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "0"
assert updated_config["keep_alive"]["advanced"] is True

View file

@ -1,10 +1,15 @@
from typing import Any
import pytest
from langflow.components.agents import AgentComponent
from langflow.components.crewai import CrewAIAgentComponent, SequentialTaskComponent
from langflow.components.custom_component import CustomComponent
from langflow.components.inputs import ChatInput
from langflow.components.models import OpenAIModelComponent
from langflow.components.outputs import ChatOutput
from langflow.schema import dotdict
from langflow.template import Output
from typing_extensions import override
def test_set_invalid_output():
@ -58,3 +63,21 @@ def test_set_required_inputs_various_components():
assert _assert_all_outputs_have_different_required_inputs(chatoutput.outputs)
assert _assert_all_outputs_have_different_required_inputs(task.outputs)
assert _assert_all_outputs_have_different_required_inputs(agent.outputs)
async def test_update_build_config_backward_compatibility():
class TestComponent(CustomComponent):
@override
def update_build_config(
self,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
build_config["foo"] = "bar"
return build_config
component = TestComponent()
build_config = dotdict()
build_config = await component.aupdate_build_config(build_config, "", "")
assert build_config["foo"] == "bar"

View file

@ -1,6 +1,6 @@
from datetime import datetime
from unittest.mock import patch
from uuid import UUID, uuid4
from uuid import uuid4
import pytest
from langflow.services.database.models.variable.model import VariableUpdate
@ -9,7 +9,7 @@ from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONME
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
from langflow.services.variable.service import DatabaseVariableService
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Session, SQLModel
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
@ -28,16 +28,6 @@ async def session():
yield session
def _get_variable(
session: Session,
service,
user_id: UUID | str,
name: str,
field: str,
):
return service.get_variable(user_id, name, field, session=session)
async def test_initialize_user_variables__create_and_update(service, session: AsyncSession):
user_id = uuid4()
field = ""
@ -53,7 +43,7 @@ async def test_initialize_user_variables__create_and_update(service, session: As
variables = await service.list_variables(user_id, session=session)
for name in variables:
value = await session.run_sync(_get_variable, service, user_id, name, field)
value = await service.get_variable(user_id, name, field, session=session)
assert value == env_vars[name]
assert all(i in variables for i in good_vars)
@ -80,7 +70,7 @@ async def test_get_variable(service, session: AsyncSession):
field = ""
await service.create_variable(user_id, name, value, session=session)
result = await session.run_sync(_get_variable, service, user_id, name, field)
result = await service.get_variable(user_id, name, field, session=session)
assert result == value
@ -91,7 +81,7 @@ async def test_get_variable__valueerror(service, session: AsyncSession):
field = ""
with pytest.raises(ValueError, match=f"{name} variable not found."):
await session.run_sync(_get_variable, service, user_id, name, field)
await service.get_variable(user_id, name, field, session=session)
async def test_get_variable__typeerror(service, session: AsyncSession):
@ -103,7 +93,7 @@ async def test_get_variable__typeerror(service, session: AsyncSession):
await service.create_variable(user_id, name, value, type_=type_, session=session)
with pytest.raises(TypeError) as exc:
await session.run_sync(_get_variable, service, user_id, name, field)
await service.get_variable(user_id, name, field, session=session)
assert name in str(exc.value)
assert "purpose is to prevent the exposure of value" in str(exc.value)
@ -136,9 +126,9 @@ async def test_update_variable(service, session: AsyncSession):
field = ""
await service.create_variable(user_id, name, old_value, session=session)
old_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
old_recovered = await service.get_variable(user_id, name, field, session=session)
result = await service.update_variable(user_id, name, new_value, session=session)
new_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
new_recovered = await service.get_variable(user_id, name, field, session=session)
assert old_value == old_recovered
assert new_value == new_recovered
@ -197,10 +187,10 @@ async def test_delete_variable(service, session: AsyncSession):
field = ""
await service.create_variable(user_id, name, value, session=session)
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
recovered = await service.get_variable(user_id, name, field, session=session)
await service.delete_variable(user_id, name, session=session)
with pytest.raises(ValueError, match=f"{name} variable not found."):
await session.run_sync(_get_variable, service, user_id, name, field)
await service.get_variable(user_id, name, field, session=session)
assert recovered == value
@ -220,10 +210,10 @@ async def test_delete_variable_by_id(service, session: AsyncSession):
field = "field"
saved = await service.create_variable(user_id, name, value, session=session)
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
recovered = await service.get_variable(user_id, name, field, session=session)
await service.delete_variable_by_id(user_id, saved.id, session=session)
with pytest.raises(ValueError, match=f"{name} variable not found."):
await session.run_sync(_get_variable, service, user_id, name, field)
await service.get_variable(user_id, name, field, session=session)
assert recovered == value