fix: Add async aupdate_build_config to CustomComponent (#5181)
* Add async aupdate_build_config to CustomComponent * Add test of backward compatibility
This commit is contained in:
parent
a302a946f2
commit
4f5d7d93ad
17 changed files with 141 additions and 151 deletions
|
|
@ -581,7 +581,7 @@ async def custom_component(
|
|||
|
||||
built_frontend_node, component_instance = build_custom_component_template(component, user_id=user.id)
|
||||
if raw_code.frontend_node is not None:
|
||||
built_frontend_node = component_instance.post_code_processing(built_frontend_node, raw_code.frontend_node)
|
||||
built_frontend_node = await component_instance.post_code_processing(built_frontend_node, raw_code.frontend_node)
|
||||
|
||||
type_ = get_instance_name(component_instance)
|
||||
return CustomComponentResponse(data=built_frontend_node, type=type_)
|
||||
|
|
@ -630,10 +630,10 @@ async def custom_component_update(
|
|||
for field_name, field_dict in template.items()
|
||||
if isinstance(field_dict, dict) and field_dict.get("load_from_db")
|
||||
]
|
||||
params = update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields)
|
||||
params = await update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields)
|
||||
cc_instance.set_attributes(params)
|
||||
updated_build_config = code_request.get_template()
|
||||
cc_instance.update_build_config(
|
||||
await cc_instance.aupdate_build_config(
|
||||
build_config=updated_build_config,
|
||||
field_value=code_request.field_value,
|
||||
field_name=code_request.field,
|
||||
|
|
|
|||
|
|
@ -136,16 +136,18 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
value.input_types = []
|
||||
return build_config
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict:
|
||||
async def aupdate_build_config(
|
||||
self, build_config: dotdict, field_value: str, field_name: str | None = None
|
||||
) -> dotdict:
|
||||
# Iterate over all providers in the MODEL_PROVIDERS_DICT
|
||||
# Existing logic for updating build_config
|
||||
if field_name == "agent_llm":
|
||||
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
|
||||
if provider_info:
|
||||
component_class = provider_info.get("component_class")
|
||||
if component_class and hasattr(component_class, "update_build_config"):
|
||||
# Call the component class's update_build_config method
|
||||
build_config = component_class.update_build_config(build_config, field_value, field_name)
|
||||
if component_class and hasattr(component_class, "aupdate_build_config"):
|
||||
# Call the component class's aupdate_build_config method
|
||||
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
|
||||
|
||||
provider_configs: dict[str, tuple[dict, list[dict]]] = {
|
||||
provider: (
|
||||
|
|
@ -211,11 +213,11 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
if provider_info:
|
||||
component_class = provider_info.get("component_class")
|
||||
prefix = provider_info.get("prefix")
|
||||
if component_class and hasattr(component_class, "update_build_config"):
|
||||
# Call each component class's update_build_config method
|
||||
if component_class and hasattr(component_class, "aupdate_build_config"):
|
||||
# Call each component class's aupdate_build_config method
|
||||
# remove the prefix from the field_name
|
||||
if isinstance(field_name, str) and isinstance(prefix, str):
|
||||
field_name = field_name.replace(prefix, "")
|
||||
build_config = component_class.update_build_config(build_config, field_value, field_name)
|
||||
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
|
||||
|
||||
return build_config
|
||||
|
|
|
|||
|
|
@ -16,24 +16,25 @@ class LMStudioEmbeddingsComponent(LCEmbeddingsModel):
|
|||
icon = "LMStudio"
|
||||
|
||||
@override
|
||||
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "model":
|
||||
base_url_dict = build_config.get("base_url", {})
|
||||
base_url_load_from_db = base_url_dict.get("load_from_db", False)
|
||||
base_url_value = base_url_dict.get("value")
|
||||
if base_url_load_from_db:
|
||||
base_url_value = self.variables(base_url_value)
|
||||
base_url_value = await self.variables(base_url_value, field_name)
|
||||
elif not base_url_value:
|
||||
base_url_value = "http://localhost:1234/v1"
|
||||
build_config["model"]["options"] = self.get_model(base_url_value)
|
||||
build_config["model"]["options"] = await self.get_model(base_url_value)
|
||||
|
||||
return build_config
|
||||
|
||||
def get_model(self, base_url_value: str) -> list[str]:
|
||||
@staticmethod
|
||||
async def get_model(base_url_value: str) -> list[str]:
|
||||
try:
|
||||
url = urljoin(base_url_value, "/v1/models")
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -20,24 +20,25 @@ class LMStudioModelComponent(LCModelComponent):
|
|||
name = "LMStudioModel"
|
||||
|
||||
@override
|
||||
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "model_name":
|
||||
base_url_dict = build_config.get("base_url", {})
|
||||
base_url_load_from_db = base_url_dict.get("load_from_db", False)
|
||||
base_url_value = base_url_dict.get("value")
|
||||
if base_url_load_from_db:
|
||||
base_url_value = self.variables(base_url_value)
|
||||
base_url_value = await self.variables(base_url_value, field_name)
|
||||
elif not base_url_value:
|
||||
base_url_value = "http://localhost:1234/v1"
|
||||
build_config["model_name"]["options"] = self.get_model(base_url_value)
|
||||
build_config["model_name"]["options"] = await self.get_model(base_url_value)
|
||||
|
||||
return build_config
|
||||
|
||||
def get_model(self, base_url_value: str) -> list[str]:
|
||||
@staticmethod
|
||||
async def get_model(base_url_value: str) -> list[str]:
|
||||
try:
|
||||
url = urljoin(base_url_value, "/v1/models")
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
icon = "Ollama"
|
||||
name = "OllamaModel"
|
||||
|
||||
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "mirostat":
|
||||
if field_value == "Disabled":
|
||||
build_config["mirostat_eta"]["advanced"] = True
|
||||
|
|
@ -40,10 +40,10 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
base_url_load_from_db = base_url_dict.get("load_from_db", False)
|
||||
base_url_value = base_url_dict.get("value")
|
||||
if base_url_load_from_db:
|
||||
base_url_value = self.variables(base_url_value, field_name)
|
||||
base_url_value = await self.variables(base_url_value, field_name)
|
||||
elif not base_url_value:
|
||||
base_url_value = "http://localhost:11434"
|
||||
build_config["model_name"]["options"] = self.get_model(base_url_value)
|
||||
build_config["model_name"]["options"] = await self.get_model(base_url_value)
|
||||
if field_name == "keep_alive_flag":
|
||||
if field_value == "Keep":
|
||||
build_config["keep_alive"]["value"] = "-1"
|
||||
|
|
@ -56,11 +56,12 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
|
||||
return build_config
|
||||
|
||||
def get_model(self, base_url_value: str) -> list[str]:
|
||||
@staticmethod
|
||||
async def get_model(base_url_value: str) -> list[str]:
|
||||
try:
|
||||
url = urljoin(base_url_value, "/api/tags")
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -45,9 +45,9 @@ class PromptComponent(Component):
|
|||
)
|
||||
return frontend_node
|
||||
|
||||
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
"""This function is called after the code validation is done."""
|
||||
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
template = frontend_node["template"]["template"]["value"]
|
||||
# Kept it duplicated for backwards compatibility
|
||||
_ = process_prompt_template(
|
||||
|
|
|
|||
|
|
@ -98,7 +98,9 @@ class PythonCodeStructuredTool(LCToolComponent):
|
|||
]
|
||||
|
||||
@override
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
|
||||
async def aupdate_build_config(
|
||||
self, build_config: dotdict, field_value: Any, field_name: str | None = None
|
||||
) -> dotdict:
|
||||
if field_name is None:
|
||||
return build_config
|
||||
|
||||
|
|
@ -226,22 +228,22 @@ class PythonCodeStructuredTool(LCToolComponent):
|
|||
return_direct=self.return_direct,
|
||||
)
|
||||
|
||||
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
"""This function is called after the code validation is done."""
|
||||
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
frontend_node["template"] = self.update_build_config(
|
||||
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
frontend_node["template"] = await self.aupdate_build_config(
|
||||
frontend_node["template"],
|
||||
frontend_node["template"]["tool_code"]["value"],
|
||||
"tool_code",
|
||||
)
|
||||
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
for key in frontend_node["template"]:
|
||||
if key in self.DEFAULT_KEYS:
|
||||
continue
|
||||
frontend_node["template"] = self.update_build_config(
|
||||
frontend_node["template"] = await self.aupdate_build_config(
|
||||
frontend_node["template"], frontend_node["template"][key]["value"], key
|
||||
)
|
||||
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
return frontend_node
|
||||
|
||||
def _parse_code(self, code: str) -> tuple[list[dict], list[dict]]:
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ if TYPE_CHECKING:
|
|||
from langflow.graph.edge.schema import EdgeData
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.schema import dotdict
|
||||
from langflow.schema.log import LoggableType
|
||||
|
||||
|
||||
|
|
@ -391,14 +390,6 @@ class Component(CustomComponent):
|
|||
self._validate_inputs(params)
|
||||
self._validate_outputs()
|
||||
|
||||
def update_inputs(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
return self.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
def run_and_validate_update_outputs(self, frontend_node: dict, field_name: str, field_value: Any):
|
||||
frontend_node = self.update_outputs(frontend_node, field_name, field_value)
|
||||
if field_name == "tool_mode" or frontend_node.get("tool_mode"):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
|
@ -12,7 +13,7 @@ from pydantic import BaseModel
|
|||
from langflow.custom.custom_component.base_component import BaseComponent
|
||||
from langflow.helpers.flow import list_flows, load_flow, run_flow
|
||||
from langflow.schema import Data
|
||||
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.template.utils import update_frontend_node_with_template_values
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
|
|
@ -230,6 +231,19 @@ class CustomComponent(BaseComponent):
|
|||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
if type(self).aupdate_build_config != CustomComponent.aupdate_build_config:
|
||||
raise NotImplementedError
|
||||
build_config[field_name]["value"] = field_value
|
||||
return build_config
|
||||
|
||||
async def aupdate_build_config(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
if type(self).update_build_config != CustomComponent.update_build_config:
|
||||
return await asyncio.to_thread(self.update_build_config, build_config, field_value, field_name)
|
||||
build_config[field_name]["value"] = field_value
|
||||
return build_config
|
||||
|
||||
|
|
@ -410,8 +424,7 @@ class CustomComponent(BaseComponent):
|
|||
self._template_config = self.build_template_config()
|
||||
return self._template_config
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
async def variables(self, name: str, field: str):
|
||||
"""Returns the variable for the current user with the specified name.
|
||||
|
||||
Raises:
|
||||
|
|
@ -420,18 +433,14 @@ class CustomComponent(BaseComponent):
|
|||
Returns:
|
||||
The variable for the current user with the specified name.
|
||||
"""
|
||||
|
||||
def get_variable(name: str, field: str):
|
||||
if hasattr(self, "_user_id") and not self.user_id:
|
||||
msg = f"User id is not set for {self.__class__.__name__}"
|
||||
raise ValueError(msg)
|
||||
variable_service = get_variable_service() # Get service instance
|
||||
# Retrieve and decrypt the variable by name for the current user
|
||||
with session_scope() as session:
|
||||
user_id = self.user_id or ""
|
||||
return variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
|
||||
|
||||
return get_variable
|
||||
if hasattr(self, "_user_id") and not self.user_id:
|
||||
msg = f"User id is not set for {self.__class__.__name__}"
|
||||
raise ValueError(msg)
|
||||
variable_service = get_variable_service() # Get service instance
|
||||
# Retrieve and decrypt the variable by name for the current user
|
||||
async with async_session_scope() as session:
|
||||
user_id = self.user_id or ""
|
||||
return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
|
||||
|
||||
def list_key_names(self):
|
||||
"""Lists the names of the variables for the current user.
|
||||
|
|
@ -519,7 +528,7 @@ class CustomComponent(BaseComponent):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
"""This function is called after the code validation is done."""
|
||||
return update_frontend_node_with_template_values(
|
||||
frontend_node=new_frontend_node, raw_frontend_node=current_frontend_node
|
||||
|
|
|
|||
|
|
@ -506,42 +506,6 @@ async def abuild_custom_components(components_paths: list[str]):
|
|||
return custom_components_from_file
|
||||
|
||||
|
||||
def update_field_dict(
|
||||
custom_component_instance: "CustomComponent",
|
||||
field_dict: dict,
|
||||
build_config: dict,
|
||||
*,
|
||||
update_field: str | None = None,
|
||||
update_field_value: Any | None = None,
|
||||
call: bool = False,
|
||||
):
|
||||
"""Update the field dictionary by calling options() or value() if they are callable."""
|
||||
if (
|
||||
("real_time_refresh" in field_dict or "refresh_button" in field_dict)
|
||||
and any(
|
||||
(
|
||||
field_dict.get("real_time_refresh", False),
|
||||
field_dict.get("refresh_button", False),
|
||||
)
|
||||
)
|
||||
and call
|
||||
):
|
||||
try:
|
||||
dd_build_config = dotdict(build_config)
|
||||
custom_component_instance.update_build_config(
|
||||
build_config=dd_build_config,
|
||||
field_value=update_field,
|
||||
field_name=update_field_value,
|
||||
)
|
||||
build_config = dd_build_config
|
||||
except Exception as exc:
|
||||
msg = f"Error while running update_build_config: {exc}"
|
||||
logger.exception(msg)
|
||||
raise UpdateBuildConfigError(msg) from exc
|
||||
|
||||
return build_config
|
||||
|
||||
|
||||
def sanitize_field_config(field_config: dict | Input):
|
||||
# If any of the already existing keys are in field_config, remove them
|
||||
field_dict = field_config.to_dict() if isinstance(field_config, Input) else field_config
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ async def get_instance_results(
|
|||
fallback_to_env_vars: bool = False,
|
||||
base_type: str = "component",
|
||||
):
|
||||
custom_params = update_params_with_load_from_db_fields(
|
||||
custom_params = await update_params_with_load_from_db_fields(
|
||||
custom_component, custom_params, vertex.load_from_db_fields, fallback_to_env_vars=fallback_to_env_vars
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
|
|
@ -103,7 +103,7 @@ def convert_kwargs(params):
|
|||
return params
|
||||
|
||||
|
||||
def update_params_with_load_from_db_fields(
|
||||
async def update_params_with_load_from_db_fields(
|
||||
custom_component: CustomComponent,
|
||||
params,
|
||||
load_from_db_fields,
|
||||
|
|
@ -115,7 +115,7 @@ def update_params_with_load_from_db_fields(
|
|||
continue
|
||||
|
||||
try:
|
||||
key = custom_component.variables(params[field], field)
|
||||
key = await custom_component.variables(params[field], field)
|
||||
except ValueError as e:
|
||||
if any(reason in str(e) for reason in ["User id is not set", "variable not found."]):
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_variable(self, user_id: UUID | str, name: str, field: str, session: Session) -> str:
|
||||
def get_variable_sync(self, user_id: UUID | str, name: str, field: str, session: Session) -> str:
|
||||
"""Get a variable value.
|
||||
|
||||
Args:
|
||||
|
|
@ -36,6 +36,20 @@ class VariableService(Service):
|
|||
The value of the variable.
|
||||
"""
|
||||
|
||||
async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str:
|
||||
"""Async get a variable value.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
name: The name of the variable.
|
||||
field: The field of the variable.
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
The value of the variable.
|
||||
"""
|
||||
return await session.run_sync(lambda session_: self.get_variable_sync(user_id, name, field, session_))
|
||||
|
||||
@abc.abstractmethod
|
||||
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
"""List all variables.
|
||||
|
|
@ -48,7 +62,6 @@ class VariableService(Service):
|
|||
A list of variable names.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
|
||||
"""List all variables.
|
||||
|
||||
|
|
@ -59,6 +72,7 @@ class VariableService(Service):
|
|||
Returns:
|
||||
A list of variable names.
|
||||
"""
|
||||
return await session.run_sync(lambda session_: self.list_variables_sync(user_id, session_))
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable:
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class KubernetesSecretService(VariableService, Service):
|
|||
raise ValueError(msg)
|
||||
|
||||
@override
|
||||
def get_variable(
|
||||
def get_variable_sync(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
|
|
@ -114,14 +114,6 @@ class KubernetesSecretService(VariableService, Service):
|
|||
names.append(key)
|
||||
return names
|
||||
|
||||
@override
|
||||
async def list_variables(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
session: AsyncSession,
|
||||
) -> list[str | None]:
|
||||
return await asyncio.to_thread(self.list_variables_sync, user_id, session.sync_session)
|
||||
|
||||
def _update_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
except Exception as e: # noqa: BLE001
|
||||
logger.exception(f"Error processing {var_name} variable: {e!s}")
|
||||
|
||||
def get_variable(
|
||||
def get_variable_sync(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ def component():
|
|||
return ChatOllamaComponent()
|
||||
|
||||
|
||||
@patch("httpx.Client.get")
|
||||
def test_get_model_success(mock_get, component):
|
||||
@patch("httpx.AsyncClient.get")
|
||||
async def test_get_model_success(mock_get, component):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
|
@ -20,7 +20,7 @@ def test_get_model_success(mock_get, component):
|
|||
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
model_names = component.get_model(base_url)
|
||||
model_names = await component.get_model(base_url)
|
||||
|
||||
expected_url = urljoin(base_url, "/api/tags")
|
||||
|
||||
|
|
@ -29,8 +29,8 @@ def test_get_model_success(mock_get, component):
|
|||
assert model_names == ["model1", "model2"]
|
||||
|
||||
|
||||
@patch("httpx.Client.get")
|
||||
def test_get_model_failure(mock_get, component):
|
||||
@patch("httpx.AsyncClient.get")
|
||||
async def test_get_model_failure(mock_get, component):
|
||||
# Mock the response for the HTTP GET request to raise an exception
|
||||
mock_get.side_effect = Exception("HTTP request failed")
|
||||
|
||||
|
|
@ -38,10 +38,10 @@ def test_get_model_failure(mock_get, component):
|
|||
|
||||
# Assert that the ValueError is raised when an exception occurs
|
||||
with pytest.raises(ValueError, match="Could not retrieve models"):
|
||||
component.get_model(url)
|
||||
await component.get_model(url)
|
||||
|
||||
|
||||
def test_update_build_config_mirostat_disabled(component):
|
||||
async def test_update_build_config_mirostat_disabled(component):
|
||||
build_config = {
|
||||
"mirostat_eta": {"advanced": False, "value": 0.1},
|
||||
"mirostat_tau": {"advanced": False, "value": 5},
|
||||
|
|
@ -49,7 +49,7 @@ def test_update_build_config_mirostat_disabled(component):
|
|||
field_value = "Disabled"
|
||||
field_name = "mirostat"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["mirostat_eta"]["advanced"] is True
|
||||
assert updated_config["mirostat_tau"]["advanced"] is True
|
||||
|
|
@ -57,7 +57,7 @@ def test_update_build_config_mirostat_disabled(component):
|
|||
assert updated_config["mirostat_tau"]["value"] is None
|
||||
|
||||
|
||||
def test_update_build_config_mirostat_enabled(component):
|
||||
async def test_update_build_config_mirostat_enabled(component):
|
||||
build_config = {
|
||||
"mirostat_eta": {"advanced": False, "value": None},
|
||||
"mirostat_tau": {"advanced": False, "value": None},
|
||||
|
|
@ -65,7 +65,7 @@ def test_update_build_config_mirostat_enabled(component):
|
|||
field_value = "Mirostat 2.0"
|
||||
field_name = "mirostat"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["mirostat_eta"]["advanced"] is False
|
||||
assert updated_config["mirostat_tau"]["advanced"] is False
|
||||
|
|
@ -73,8 +73,8 @@ def test_update_build_config_mirostat_enabled(component):
|
|||
assert updated_config["mirostat_tau"]["value"] == 10
|
||||
|
||||
|
||||
@patch("httpx.Client.get")
|
||||
def test_update_build_config_model_name(mock_get, component):
|
||||
@patch("httpx.AsyncClient.get")
|
||||
async def test_update_build_config_model_name(mock_get, component):
|
||||
# Mock the response for the HTTP GET request
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
|
|
@ -88,22 +88,22 @@ def test_update_build_config_model_name(mock_get, component):
|
|||
field_value = None
|
||||
field_name = "model_name"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["model_name"]["options"] == ["model1", "model2"]
|
||||
|
||||
|
||||
def test_update_build_config_keep_alive(component):
|
||||
async def test_update_build_config_keep_alive(component):
|
||||
build_config = {"keep_alive": {"value": None, "advanced": False}}
|
||||
field_value = "Keep"
|
||||
field_name = "keep_alive_flag"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
assert updated_config["keep_alive"]["value"] == "-1"
|
||||
assert updated_config["keep_alive"]["advanced"] is True
|
||||
|
||||
field_value = "Immediately"
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
assert updated_config["keep_alive"]["value"] == "0"
|
||||
assert updated_config["keep_alive"]["advanced"] is True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langflow.components.agents import AgentComponent
|
||||
from langflow.components.crewai import CrewAIAgentComponent, SequentialTaskComponent
|
||||
from langflow.components.custom_component import CustomComponent
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.components.models import OpenAIModelComponent
|
||||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.schema import dotdict
|
||||
from langflow.template import Output
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def test_set_invalid_output():
|
||||
|
|
@ -58,3 +63,21 @@ def test_set_required_inputs_various_components():
|
|||
assert _assert_all_outputs_have_different_required_inputs(chatoutput.outputs)
|
||||
assert _assert_all_outputs_have_different_required_inputs(task.outputs)
|
||||
assert _assert_all_outputs_have_different_required_inputs(agent.outputs)
|
||||
|
||||
|
||||
async def test_update_build_config_backward_compatibility():
|
||||
class TestComponent(CustomComponent):
|
||||
@override
|
||||
def update_build_config(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
build_config["foo"] = "bar"
|
||||
return build_config
|
||||
|
||||
component = TestComponent()
|
||||
build_config = dotdict()
|
||||
build_config = await component.aupdate_build_config(build_config, "", "")
|
||||
assert build_config["foo"] == "bar"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langflow.services.database.models.variable.model import VariableUpdate
|
||||
|
|
@ -9,7 +9,7 @@ from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONME
|
|||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from langflow.services.variable.service import DatabaseVariableService
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import Session, SQLModel
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
|
|
@ -28,16 +28,6 @@ async def session():
|
|||
yield session
|
||||
|
||||
|
||||
def _get_variable(
|
||||
session: Session,
|
||||
service,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
field: str,
|
||||
):
|
||||
return service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
|
||||
async def test_initialize_user_variables__create_and_update(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
field = ""
|
||||
|
|
@ -53,7 +43,7 @@ async def test_initialize_user_variables__create_and_update(service, session: As
|
|||
|
||||
variables = await service.list_variables(user_id, session=session)
|
||||
for name in variables:
|
||||
value = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
value = await service.get_variable(user_id, name, field, session=session)
|
||||
assert value == env_vars[name]
|
||||
|
||||
assert all(i in variables for i in good_vars)
|
||||
|
|
@ -80,7 +70,7 @@ async def test_get_variable(service, session: AsyncSession):
|
|||
field = ""
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
result = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
result = await service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
assert result == value
|
||||
|
||||
|
|
@ -91,7 +81,7 @@ async def test_get_variable__valueerror(service, session: AsyncSession):
|
|||
field = ""
|
||||
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
|
||||
async def test_get_variable__typeerror(service, session: AsyncSession):
|
||||
|
|
@ -103,7 +93,7 @@ async def test_get_variable__typeerror(service, session: AsyncSession):
|
|||
await service.create_variable(user_id, name, value, type_=type_, session=session)
|
||||
|
||||
with pytest.raises(TypeError) as exc:
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "purpose is to prevent the exposure of value" in str(exc.value)
|
||||
|
|
@ -136,9 +126,9 @@ async def test_update_variable(service, session: AsyncSession):
|
|||
field = ""
|
||||
await service.create_variable(user_id, name, old_value, session=session)
|
||||
|
||||
old_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
old_recovered = await service.get_variable(user_id, name, field, session=session)
|
||||
result = await service.update_variable(user_id, name, new_value, session=session)
|
||||
new_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
new_recovered = await service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
assert old_value == old_recovered
|
||||
assert new_value == new_recovered
|
||||
|
|
@ -197,10 +187,10 @@ async def test_delete_variable(service, session: AsyncSession):
|
|||
field = ""
|
||||
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
recovered = await service.get_variable(user_id, name, field, session=session)
|
||||
await service.delete_variable(user_id, name, session=session)
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
assert recovered == value
|
||||
|
||||
|
|
@ -220,10 +210,10 @@ async def test_delete_variable_by_id(service, session: AsyncSession):
|
|||
field = "field"
|
||||
|
||||
saved = await service.create_variable(user_id, name, value, session=session)
|
||||
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
recovered = await service.get_variable(user_id, name, field, session=session)
|
||||
await service.delete_variable_by_id(user_id, saved.id, session=session)
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
assert recovered == value
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue