From 4f5d7d93ada8b570a53732bbdc63bc3822a93e08 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 11 Dec 2024 17:25:15 +0100 Subject: [PATCH] fix: Add async aupdate_build_config to CustomComponent (#5181) * Add async aupdate_build_config to CustomComponent * Add test of backward compatibility --- src/backend/base/langflow/api/v1/endpoints.py | 6 +-- .../base/langflow/components/agents/agent.py | 16 ++++---- .../embeddings/lmstudioembeddings.py | 13 +++--- .../components/models/lmstudiomodel.py | 13 +++--- .../base/langflow/components/models/ollama.py | 13 +++--- .../langflow/components/prompts/prompt.py | 4 +- .../tools/python_code_structured_tool.py | 16 ++++---- .../custom/custom_component/component.py | 9 ---- .../custom_component/custom_component.py | 41 +++++++++++-------- src/backend/base/langflow/custom/utils.py | 36 ---------------- .../langflow/interface/initialize/loading.py | 6 +-- .../base/langflow/services/variable/base.py | 18 +++++++- .../langflow/services/variable/kubernetes.py | 10 +---- .../langflow/services/variable/service.py | 2 +- .../models/test_chatollama_component.py | 32 +++++++-------- .../custom/custom_component/test_component.py | 23 +++++++++++ .../unit/services/variable/test_service.py | 34 ++++++--------- 17 files changed, 141 insertions(+), 151 deletions(-) diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 34504564c..12ff4903d 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -581,7 +581,7 @@ async def custom_component( built_frontend_node, component_instance = build_custom_component_template(component, user_id=user.id) if raw_code.frontend_node is not None: - built_frontend_node = component_instance.post_code_processing(built_frontend_node, raw_code.frontend_node) + built_frontend_node = await component_instance.post_code_processing(built_frontend_node, raw_code.frontend_node) type_ = get_instance_name(component_instance) return CustomComponentResponse(data=built_frontend_node, type=type_) @@ -630,10 +630,10 @@ async def custom_component_update( for field_name, field_dict in template.items() if isinstance(field_dict, dict) and field_dict.get("load_from_db") ] - params = update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields) + params = await update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields) cc_instance.set_attributes(params) updated_build_config = code_request.get_template() - cc_instance.update_build_config( + await cc_instance.aupdate_build_config( build_config=updated_build_config, field_value=code_request.field_value, field_name=code_request.field, diff --git a/src/backend/base/langflow/components/agents/agent.py b/src/backend/base/langflow/components/agents/agent.py index ddd6b4022..2491193b6 100644 --- a/src/backend/base/langflow/components/agents/agent.py +++ b/src/backend/base/langflow/components/agents/agent.py @@ -136,16 +136,18 @@ class AgentComponent(ToolCallingAgentComponent): value.input_types = [] return build_config - def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict: + async def aupdate_build_config( + self, build_config: dotdict, field_value: str, field_name: str | None = None + ) -> dotdict: # Iterate over all providers in the MODEL_PROVIDERS_DICT # Existing logic for updating build_config if field_name == "agent_llm": provider_info = MODEL_PROVIDERS_DICT.get(field_value) if provider_info: component_class = provider_info.get("component_class") - if component_class and hasattr(component_class, "update_build_config"): - # Call the component class's update_build_config method - build_config = component_class.update_build_config(build_config, field_value, field_name) + if component_class and hasattr(component_class, "aupdate_build_config"): + # Call the component class's aupdate_build_config method + build_config = await component_class.aupdate_build_config(build_config, field_value, field_name) provider_configs: dict[str, tuple[dict, list[dict]]] = { provider: ( @@ -211,11 +213,11 @@ class AgentComponent(ToolCallingAgentComponent): if provider_info: component_class = provider_info.get("component_class") prefix = provider_info.get("prefix") - if component_class and hasattr(component_class, "update_build_config"): - # Call each component class's update_build_config method + if component_class and hasattr(component_class, "aupdate_build_config"): + # Call each component class's aupdate_build_config method # remove the prefix from the field_name if isinstance(field_name, str) and isinstance(prefix, str): field_name = field_name.replace(prefix, "") - build_config = component_class.update_build_config(build_config, field_value, field_name) + build_config = await component_class.aupdate_build_config(build_config, field_value, field_name) return build_config diff --git a/src/backend/base/langflow/components/embeddings/lmstudioembeddings.py b/src/backend/base/langflow/components/embeddings/lmstudioembeddings.py index 409871974..dd065a4df 100644 --- a/src/backend/base/langflow/components/embeddings/lmstudioembeddings.py +++ b/src/backend/base/langflow/components/embeddings/lmstudioembeddings.py @@ -16,24 +16,25 @@ class LMStudioEmbeddingsComponent(LCEmbeddingsModel): icon = "LMStudio" @override - def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): if field_name == "model": base_url_dict = build_config.get("base_url", {}) base_url_load_from_db = base_url_dict.get("load_from_db", False) base_url_value = base_url_dict.get("value") if base_url_load_from_db: - base_url_value = self.variables(base_url_value) + base_url_value = await self.variables(base_url_value, field_name) elif not base_url_value: base_url_value = "http://localhost:1234/v1" - build_config["model"]["options"] = self.get_model(base_url_value) + build_config["model"]["options"] = await self.get_model(base_url_value) return build_config - def get_model(self, base_url_value: str) -> list[str]: + @staticmethod + async def get_model(base_url_value: str) -> list[str]: try: url = urljoin(base_url_value, "/v1/models") - with httpx.Client() as client: - response = client.get(url) + async with httpx.AsyncClient() as client: + response = await client.get(url) response.raise_for_status() data = response.json() diff --git a/src/backend/base/langflow/components/models/lmstudiomodel.py b/src/backend/base/langflow/components/models/lmstudiomodel.py index f4f64c5c9..a61b51391 100644 --- a/src/backend/base/langflow/components/models/lmstudiomodel.py +++ b/src/backend/base/langflow/components/models/lmstudiomodel.py @@ -20,24 +20,25 @@ class LMStudioModelComponent(LCModelComponent): name = "LMStudioModel" @override - def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): if field_name == "model_name": base_url_dict = build_config.get("base_url", {}) base_url_load_from_db = base_url_dict.get("load_from_db", False) base_url_value = base_url_dict.get("value") if base_url_load_from_db: - base_url_value = self.variables(base_url_value) + base_url_value = await self.variables(base_url_value, field_name) elif not base_url_value: base_url_value = "http://localhost:1234/v1" - build_config["model_name"]["options"] = self.get_model(base_url_value) + build_config["model_name"]["options"] = await self.get_model(base_url_value) return build_config - def get_model(self, base_url_value: str) -> list[str]: + @staticmethod + async def get_model(base_url_value: str) -> list[str]: try: url = urljoin(base_url_value, "/v1/models") - with httpx.Client() as client: - response = client.get(url) + async with httpx.AsyncClient() as client: + response = await client.get(url) response.raise_for_status() data = response.json() diff --git a/src/backend/base/langflow/components/models/ollama.py b/src/backend/base/langflow/components/models/ollama.py index 12efe8b55..d303af31a 100644 --- a/src/backend/base/langflow/components/models/ollama.py +++ b/src/backend/base/langflow/components/models/ollama.py @@ -16,7 +16,7 @@ class ChatOllamaComponent(LCModelComponent): icon = "Ollama" name = "OllamaModel" - def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): if field_name == "mirostat": if field_value == "Disabled": build_config["mirostat_eta"]["advanced"] = True @@ -40,10 +40,10 @@ class ChatOllamaComponent(LCModelComponent): base_url_load_from_db = base_url_dict.get("load_from_db", False) base_url_value = base_url_dict.get("value") if base_url_load_from_db: - base_url_value = self.variables(base_url_value, field_name) + base_url_value = await self.variables(base_url_value, field_name) elif not base_url_value: base_url_value = "http://localhost:11434" - build_config["model_name"]["options"] = self.get_model(base_url_value) + build_config["model_name"]["options"] = await self.get_model(base_url_value) if field_name == "keep_alive_flag": if field_value == "Keep": build_config["keep_alive"]["value"] = "-1" @@ -56,11 +56,12 @@ class ChatOllamaComponent(LCModelComponent): return build_config - def get_model(self, base_url_value: str) -> list[str]: + @staticmethod + async def get_model(base_url_value: str) -> list[str]: try: url = urljoin(base_url_value, "/api/tags") - with httpx.Client() as client: - response = client.get(url) + async with httpx.AsyncClient() as client: + response = await client.get(url) response.raise_for_status() data = response.json() diff --git a/src/backend/base/langflow/components/prompts/prompt.py b/src/backend/base/langflow/components/prompts/prompt.py index fa0fe3810..af6f5b3f7 100644 --- a/src/backend/base/langflow/components/prompts/prompt.py +++ b/src/backend/base/langflow/components/prompts/prompt.py @@ -45,9 +45,9 @@ class PromptComponent(Component): ) return frontend_node - def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict): + async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict): """This function is called after the code validation is done.""" - frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node) + frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node) template = frontend_node["template"]["template"]["value"] # Kept it duplicated for backwards compatibility _ = process_prompt_template( diff --git a/src/backend/base/langflow/components/tools/python_code_structured_tool.py b/src/backend/base/langflow/components/tools/python_code_structured_tool.py index 0640f2a84..e3af72fed 100644 --- a/src/backend/base/langflow/components/tools/python_code_structured_tool.py +++ b/src/backend/base/langflow/components/tools/python_code_structured_tool.py @@ -98,7 +98,9 @@ class PythonCodeStructuredTool(LCToolComponent): ] @override - def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict: + async def aupdate_build_config( + self, build_config: dotdict, field_value: Any, field_name: str | None = None + ) -> dotdict: if field_name is None: return build_config @@ -226,22 +228,22 @@ class PythonCodeStructuredTool(LCToolComponent): return_direct=self.return_direct, ) - def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict): + async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict): """This function is called after the code validation is done.""" - frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node) - frontend_node["template"] = self.update_build_config( + frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node) + frontend_node["template"] = await self.aupdate_build_config( frontend_node["template"], frontend_node["template"]["tool_code"]["value"], "tool_code", ) - frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node) + frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node) for key in frontend_node["template"]: if key in self.DEFAULT_KEYS: continue - frontend_node["template"] = self.update_build_config( + frontend_node["template"] = await self.aupdate_build_config( frontend_node["template"], frontend_node["template"][key]["value"], key ) - frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node) + frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node) return frontend_node def _parse_code(self, code: str) -> tuple[list[dict], list[dict]]: diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 08a1e8f4a..050a26ead 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -44,7 +44,6 @@ if TYPE_CHECKING: from langflow.graph.edge.schema import EdgeData from langflow.graph.vertex.base import Vertex from langflow.inputs.inputs import InputTypes - from langflow.schema import dotdict from langflow.schema.log import LoggableType @@ -391,14 +390,6 @@ class Component(CustomComponent): self._validate_inputs(params) self._validate_outputs() - def update_inputs( - self, - build_config: dotdict, - field_value: Any, - field_name: str | None = None, - ): - return self.update_build_config(build_config, field_value, field_name) - def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any): frontend_node = self.update_outputs(frontend_node, field_name, field_value) if field_name == "tool_mode" or frontend_node.get("tool_mode"): diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 5a939ebcb..c0fd80931 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from collections.abc import Callable, Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar @@ -12,7 +13,7 @@ from pydantic import BaseModel from langflow.custom.custom_component.base_component import BaseComponent from langflow.helpers.flow import list_flows, load_flow, run_flow from langflow.schema import Data -from langflow.services.deps import get_storage_service, get_variable_service, session_scope +from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service, session_scope from langflow.services.storage.service import StorageService from langflow.template.utils import update_frontend_node_with_template_values from langflow.type_extraction.type_extraction import post_process_type @@ -230,6 +231,19 @@ class CustomComponent(BaseComponent): field_value: Any, field_name: str | None = None, ): + if type(self).aupdate_build_config != CustomComponent.aupdate_build_config: + raise NotImplementedError + build_config[field_name]["value"] = field_value + return build_config + + async def aupdate_build_config( + self, + build_config: dotdict, + field_value: Any, + field_name: str | None = None, + ): + if type(self).update_build_config != CustomComponent.update_build_config: + return await asyncio.to_thread(self.update_build_config, build_config, field_value, field_name) build_config[field_name]["value"] = field_value return build_config @@ -410,8 +424,7 @@ class CustomComponent(BaseComponent): self._template_config = self.build_template_config() return self._template_config - @property - def variables(self): + async def variables(self, name: str, field: str): """Returns the variable for the current user with the specified name. Raises: @@ -420,18 +433,14 @@ class CustomComponent(BaseComponent): Returns: The variable for the current user with the specified name. """ - - def get_variable(name: str, field: str): - if hasattr(self, "_user_id") and not self.user_id: - msg = f"User id is not set for {self.__class__.__name__}" - raise ValueError(msg) - variable_service = get_variable_service() # Get service instance - # Retrieve and decrypt the variable by name for the current user - with session_scope() as session: - user_id = self.user_id or "" - return variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) - - return get_variable + if hasattr(self, "_user_id") and not self.user_id: + msg = f"User id is not set for {self.__class__.__name__}" + raise ValueError(msg) + variable_service = get_variable_service() # Get service instance + # Retrieve and decrypt the variable by name for the current user + async with async_session_scope() as session: + user_id = self.user_id or "" + return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) def list_key_names(self): """Lists the names of the variables for the current user. @@ -519,7 +528,7 @@ class CustomComponent(BaseComponent): """ raise NotImplementedError - def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict): + async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict): """This function is called after the code validation is done.""" return update_frontend_node_with_template_values( frontend_node=new_frontend_node, raw_frontend_node=current_frontend_node diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index 3418057f6..ce18e3e02 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -506,42 +506,6 @@ async def abuild_custom_components(components_paths: list[str]): return custom_components_from_file -def update_field_dict( - custom_component_instance: "CustomComponent", - field_dict: dict, - build_config: dict, - *, - update_field: str | None = None, - update_field_value: Any | None = None, - call: bool = False, -): - """Update the field dictionary by calling options() or value() if they are callable.""" - if ( - ("real_time_refresh" in field_dict or "refresh_button" in field_dict) - and any( - ( - field_dict.get("real_time_refresh", False), - field_dict.get("refresh_button", False), - ) - ) - and call - ): - try: - dd_build_config = dotdict(build_config) - custom_component_instance.update_build_config( - build_config=dd_build_config, - field_value=update_field, - field_name=update_field_value, - ) - build_config = dd_build_config - except Exception as exc: - msg = f"Error while running update_build_config: {exc}" - logger.exception(msg) - raise UpdateBuildConfigError(msg) from exc - - return build_config - - def sanitize_field_config(field_config: dict | Input): # If any of the already existing keys are in field_config, remove them field_dict = field_config.to_dict() if isinstance(field_config, Input) else field_config diff --git a/src/backend/base/langflow/interface/initialize/loading.py b/src/backend/base/langflow/interface/initialize/loading.py index 9bd831112..c9a4cf09b 100644 --- a/src/backend/base/langflow/interface/initialize/loading.py +++ b/src/backend/base/langflow/interface/initialize/loading.py @@ -57,7 +57,7 @@ async def get_instance_results( fallback_to_env_vars: bool = False, base_type: str = "component", ): - custom_params = update_params_with_load_from_db_fields( + custom_params = await update_params_with_load_from_db_fields( custom_component, custom_params, vertex.load_from_db_fields, fallback_to_env_vars=fallback_to_env_vars ) with warnings.catch_warnings(): @@ -103,7 +103,7 @@ def convert_kwargs(params): return params -def update_params_with_load_from_db_fields( +async def update_params_with_load_from_db_fields( custom_component: CustomComponent, params, load_from_db_fields, @@ -115,7 +115,7 @@ def update_params_with_load_from_db_fields( continue try: - key = custom_component.variables(params[field], field) + key = await custom_component.variables(params[field], field) except ValueError as e: if any(reason in str(e) for reason in ["User id is not set", "variable not found."]): raise diff --git a/src/backend/base/langflow/services/variable/base.py b/src/backend/base/langflow/services/variable/base.py index 9efee0616..0e99b2578 100644 --- a/src/backend/base/langflow/services/variable/base.py +++ b/src/backend/base/langflow/services/variable/base.py @@ -23,7 +23,7 @@ class VariableService(Service): """ @abc.abstractmethod - def get_variable(self, user_id: UUID | str, name: str, field: str, session: Session) -> str: + def get_variable_sync(self, user_id: UUID | str, name: str, field: str, session: Session) -> str: """Get a variable value. Args: @@ -36,6 +36,20 @@ class VariableService(Service): The value of the variable. """ + async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str: + """Async get a variable value. + + Args: + user_id: The user ID. + name: The name of the variable. + field: The field of the variable. + session: The database session. + + Returns: + The value of the variable. + """ + return await session.run_sync(lambda session_: self.get_variable_sync(user_id, name, field, session_)) + @abc.abstractmethod def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]: """List all variables. @@ -48,7 +62,6 @@ class VariableService(Service): A list of variable names. """ - @abc.abstractmethod async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]: """List all variables. @@ -59,6 +72,7 @@ class VariableService(Service): Returns: A list of variable names. """ + return await session.run_sync(lambda session_: self.list_variables_sync(user_id, session_)) @abc.abstractmethod async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable: diff --git a/src/backend/base/langflow/services/variable/kubernetes.py b/src/backend/base/langflow/services/variable/kubernetes.py index 00e3ab5bd..5dc8d069a 100644 --- a/src/backend/base/langflow/services/variable/kubernetes.py +++ b/src/backend/base/langflow/services/variable/kubernetes.py @@ -79,7 +79,7 @@ class KubernetesSecretService(VariableService, Service): raise ValueError(msg) @override - def get_variable( + def get_variable_sync( self, user_id: UUID | str, name: str, @@ -114,14 +114,6 @@ class KubernetesSecretService(VariableService, Service): names.append(key) return names - @override - async def list_variables( - self, - user_id: UUID | str, - session: AsyncSession, - ) -> list[str | None]: - return await asyncio.to_thread(self.list_variables_sync, user_id, session.sync_session) - def _update_variable( self, user_id: UUID | str, diff --git a/src/backend/base/langflow/services/variable/service.py b/src/backend/base/langflow/services/variable/service.py index 40a7c69d3..0053ed69d 100644 --- a/src/backend/base/langflow/services/variable/service.py +++ b/src/backend/base/langflow/services/variable/service.py @@ -53,7 +53,7 @@ class DatabaseVariableService(VariableService, Service): except Exception as e: # noqa: BLE001 logger.exception(f"Error processing {var_name} variable: {e!s}") - def get_variable( + def get_variable_sync( self, user_id: UUID | str, name: str, diff --git a/src/backend/tests/unit/components/models/test_chatollama_component.py b/src/backend/tests/unit/components/models/test_chatollama_component.py index 0f5784de9..d22d419ce 100644 --- a/src/backend/tests/unit/components/models/test_chatollama_component.py +++ b/src/backend/tests/unit/components/models/test_chatollama_component.py @@ -11,8 +11,8 @@ def component(): return ChatOllamaComponent() -@patch("httpx.Client.get") -def test_get_model_success(mock_get, component): +@patch("httpx.AsyncClient.get") +async def test_get_model_success(mock_get, component): mock_response = MagicMock() mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]} mock_response.raise_for_status.return_value = None @@ -20,7 +20,7 @@ def test_get_model_success(mock_get, component): base_url = "http://localhost:11434" - model_names = component.get_model(base_url) + model_names = await component.get_model(base_url) expected_url = urljoin(base_url, "/api/tags") @@ -29,8 +29,8 @@ def test_get_model_success(mock_get, component): assert model_names == ["model1", "model2"] -@patch("httpx.Client.get") -def test_get_model_failure(mock_get, component): +@patch("httpx.AsyncClient.get") +async def test_get_model_failure(mock_get, component): # Mock the response for the HTTP GET request to raise an exception mock_get.side_effect = Exception("HTTP request failed") @@ -38,10 +38,10 @@ def test_get_model_failure(mock_get, component): # Assert that the ValueError is raised when an exception occurs with pytest.raises(ValueError, match="Could not retrieve models"): - component.get_model(url) + await component.get_model(url) -def test_update_build_config_mirostat_disabled(component): +async def test_update_build_config_mirostat_disabled(component): build_config = { "mirostat_eta": {"advanced": False, "value": 0.1}, "mirostat_tau": {"advanced": False, "value": 5}, @@ -49,7 +49,7 @@ def test_update_build_config_mirostat_disabled(component): field_value = "Disabled" field_name = "mirostat" - updated_config = component.update_build_config(build_config, field_value, field_name) + updated_config = await component.aupdate_build_config(build_config, field_value, field_name) assert updated_config["mirostat_eta"]["advanced"] is True assert updated_config["mirostat_tau"]["advanced"] is True @@ -57,7 +57,7 @@ def test_update_build_config_mirostat_disabled(component): assert updated_config["mirostat_tau"]["value"] is None -def test_update_build_config_mirostat_enabled(component): +async def test_update_build_config_mirostat_enabled(component): build_config = { "mirostat_eta": {"advanced": False, "value": None}, "mirostat_tau": {"advanced": False, "value": None}, @@ -65,7 +65,7 @@ def test_update_build_config_mirostat_enabled(component): field_value = "Mirostat 2.0" field_name = "mirostat" - updated_config = component.update_build_config(build_config, field_value, field_name) + updated_config = await component.aupdate_build_config(build_config, field_value, field_name) assert updated_config["mirostat_eta"]["advanced"] is False assert updated_config["mirostat_tau"]["advanced"] is False @@ -73,8 +73,8 @@ def test_update_build_config_mirostat_enabled(component): assert updated_config["mirostat_tau"]["value"] == 10 -@patch("httpx.Client.get") -def test_update_build_config_model_name(mock_get, component): +@patch("httpx.AsyncClient.get") +async def test_update_build_config_model_name(mock_get, component): # Mock the response for the HTTP GET request mock_response = MagicMock() mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]} @@ -88,22 +88,22 @@ def test_update_build_config_model_name(mock_get, component): field_value = None field_name = "model_name" - updated_config = component.update_build_config(build_config, field_value, field_name) + updated_config = await component.aupdate_build_config(build_config, field_value, field_name) assert updated_config["model_name"]["options"] == ["model1", "model2"] -def test_update_build_config_keep_alive(component): +async def test_update_build_config_keep_alive(component): build_config = {"keep_alive": {"value": None, "advanced": False}} field_value = "Keep" field_name = "keep_alive_flag" - updated_config = component.update_build_config(build_config, field_value, field_name) + updated_config = await component.aupdate_build_config(build_config, field_value, field_name) assert updated_config["keep_alive"]["value"] == "-1" assert updated_config["keep_alive"]["advanced"] is True field_value = "Immediately" - updated_config = component.update_build_config(build_config, field_value, field_name) + updated_config = await component.aupdate_build_config(build_config, field_value, field_name) assert updated_config["keep_alive"]["value"] == "0" assert updated_config["keep_alive"]["advanced"] is True diff --git a/src/backend/tests/unit/custom/custom_component/test_component.py b/src/backend/tests/unit/custom/custom_component/test_component.py index 8a0fd11ee..d7cdeba0e 100644 --- a/src/backend/tests/unit/custom/custom_component/test_component.py +++ b/src/backend/tests/unit/custom/custom_component/test_component.py @@ -1,10 +1,15 @@ +from typing import Any + import pytest from langflow.components.agents import AgentComponent from langflow.components.crewai import CrewAIAgentComponent, SequentialTaskComponent +from langflow.components.custom_component import CustomComponent from langflow.components.inputs import ChatInput from langflow.components.models import OpenAIModelComponent from langflow.components.outputs import ChatOutput +from langflow.schema import dotdict from langflow.template import Output +from typing_extensions import override def test_set_invalid_output(): @@ -58,3 +63,21 @@ def test_set_required_inputs_various_components(): assert _assert_all_outputs_have_different_required_inputs(chatoutput.outputs) assert _assert_all_outputs_have_different_required_inputs(task.outputs) assert _assert_all_outputs_have_different_required_inputs(agent.outputs) + + +async def test_update_build_config_backward_compatibility(): + class TestComponent(CustomComponent): + @override + def update_build_config( + self, + build_config: dotdict, + field_value: Any, + field_name: str | None = None, + ): + build_config["foo"] = "bar" + return build_config + + component = TestComponent() + build_config = dotdict() + build_config = await component.aupdate_build_config(build_config, "", "") + assert build_config["foo"] == "bar" diff --git a/src/backend/tests/unit/services/variable/test_service.py b/src/backend/tests/unit/services/variable/test_service.py index d02eedeac..1f7b1d86e 100644 --- a/src/backend/tests/unit/services/variable/test_service.py +++ b/src/backend/tests/unit/services/variable/test_service.py @@ -1,6 +1,6 @@ from datetime import datetime from unittest.mock import patch -from uuid import UUID, uuid4 +from uuid import uuid4 import pytest from langflow.services.database.models.variable.model import VariableUpdate @@ -9,7 +9,7 @@ from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONME from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE from langflow.services.variable.service import DatabaseVariableService from sqlalchemy.ext.asyncio import create_async_engine -from sqlmodel import Session, SQLModel +from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -28,16 +28,6 @@ async def session(): yield session -def _get_variable( - session: Session, - service, - user_id: UUID | str, - name: str, - field: str, -): - return service.get_variable(user_id, name, field, session=session) - - async def test_initialize_user_variables__create_and_update(service, session: AsyncSession): user_id = uuid4() field = "" @@ -53,7 +43,7 @@ async def test_initialize_user_variables__create_and_update(service, session: As variables = await service.list_variables(user_id, session=session) for name in variables: - value = await session.run_sync(_get_variable, service, user_id, name, field) + value = await service.get_variable(user_id, name, field, session=session) assert value == env_vars[name] assert all(i in variables for i in good_vars) @@ -80,7 +70,7 @@ async def test_get_variable(service, session: AsyncSession): field = "" await service.create_variable(user_id, name, value, session=session) - result = await session.run_sync(_get_variable, service, user_id, name, field) + result = await service.get_variable(user_id, name, field, session=session) assert result == value @@ -91,7 +81,7 @@ async def test_get_variable__valueerror(service, session: AsyncSession): field = "" with pytest.raises(ValueError, match=f"{name} variable not found."): - await session.run_sync(_get_variable, service, user_id, name, field) + await service.get_variable(user_id, name, field, session=session) async def test_get_variable__typeerror(service, session: AsyncSession): @@ -103,7 +93,7 @@ async def test_get_variable__typeerror(service, session: AsyncSession): await service.create_variable(user_id, name, value, type_=type_, session=session) with pytest.raises(TypeError) as exc: - await session.run_sync(_get_variable, service, user_id, name, field) + await service.get_variable(user_id, name, field, session=session) assert name in str(exc.value) assert "purpose is to prevent the exposure of value" in str(exc.value) @@ -136,9 +126,9 @@ async def test_update_variable(service, session: AsyncSession): field = "" await service.create_variable(user_id, name, old_value, session=session) - old_recovered = await session.run_sync(_get_variable, service, user_id, name, field) + old_recovered = await service.get_variable(user_id, name, field, session=session) result = await service.update_variable(user_id, name, new_value, session=session) - new_recovered = await session.run_sync(_get_variable, service, user_id, name, field) + new_recovered = await service.get_variable(user_id, name, field, session=session) assert old_value == old_recovered assert new_value == new_recovered @@ -197,10 +187,10 @@ async def test_delete_variable(service, session: AsyncSession): field = "" await service.create_variable(user_id, name, value, session=session) - recovered = await session.run_sync(_get_variable, service, user_id, name, field) + recovered = await service.get_variable(user_id, name, field, session=session) await service.delete_variable(user_id, name, session=session) with pytest.raises(ValueError, match=f"{name} variable not found."): - await session.run_sync(_get_variable, service, user_id, name, field) + await service.get_variable(user_id, name, field, session=session) assert recovered == value @@ -220,10 +210,10 @@ async def test_delete_variable_by_id(service, session: AsyncSession): field = "field" saved = await service.create_variable(user_id, name, value, session=session) - recovered = await session.run_sync(_get_variable, service, user_id, name, field) + recovered = await service.get_variable(user_id, name, field, session=session) await service.delete_variable_by_id(user_id, saved.id, session=session) with pytest.raises(ValueError, match=f"{name} variable not found."): - await session.run_sync(_get_variable, service, user_id, name, field) + await service.get_variable(user_id, name, field, session=session) assert recovered == value