ref: Add and use update_component_build_config utility (#5226)
* Add and use update_component_build_config utility * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5a4aef0f82
commit
384ac5e80e
21 changed files with 78 additions and 50 deletions
|
|
@ -33,7 +33,7 @@ from langflow.api.v1.schemas import (
|
|||
UploadFileResponse,
|
||||
)
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.custom.utils import build_custom_component_template, get_instance_name
|
||||
from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config
|
||||
from langflow.exceptions.api import APIException, InvalidChatInputError
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.schema import RunOutputs
|
||||
|
|
@ -633,7 +633,8 @@ async def custom_component_update(
|
|||
params = await update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields)
|
||||
cc_instance.set_attributes(params)
|
||||
updated_build_config = code_request.get_template()
|
||||
await cc_instance.aupdate_build_config(
|
||||
await update_component_build_config(
|
||||
cc_instance,
|
||||
build_config=updated_build_config,
|
||||
field_value=code_request.field_value,
|
||||
field_name=code_request.field,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from langflow.components.helpers.memory import MemoryComponent
|
|||
from langflow.components.langchain_utilities.tool_calling import (
|
||||
ToolCallingAgentComponent,
|
||||
)
|
||||
from langflow.custom.utils import update_component_build_config
|
||||
from langflow.io import BoolInput, DropdownInput, MultilineInput, Output
|
||||
from langflow.schema.dotdict import dotdict
|
||||
from langflow.schema.message import Message
|
||||
|
|
@ -136,7 +137,7 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
value.input_types = []
|
||||
return build_config
|
||||
|
||||
async def aupdate_build_config(
|
||||
async def update_build_config(
|
||||
self, build_config: dotdict, field_value: str, field_name: str | None = None
|
||||
) -> dotdict:
|
||||
# Iterate over all providers in the MODEL_PROVIDERS_DICT
|
||||
|
|
@ -145,9 +146,11 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
|
||||
if provider_info:
|
||||
component_class = provider_info.get("component_class")
|
||||
if component_class and hasattr(component_class, "aupdate_build_config"):
|
||||
# Call the component class's aupdate_build_config method
|
||||
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
|
||||
if component_class and hasattr(component_class, "update_build_config"):
|
||||
# Call the component class's update_build_config method
|
||||
build_config = await update_component_build_config(
|
||||
component_class, build_config, field_value, field_name
|
||||
)
|
||||
|
||||
provider_configs: dict[str, tuple[dict, list[dict]]] = {
|
||||
provider: (
|
||||
|
|
@ -213,11 +216,13 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
if provider_info:
|
||||
component_class = provider_info.get("component_class")
|
||||
prefix = provider_info.get("prefix")
|
||||
if component_class and hasattr(component_class, "aupdate_build_config"):
|
||||
# Call each component class's aupdate_build_config method
|
||||
if component_class and hasattr(component_class, "update_build_config"):
|
||||
# Call each component class's update_build_config method
|
||||
# remove the prefix from the field_name
|
||||
if isinstance(field_name, str) and isinstance(prefix, str):
|
||||
field_name = field_name.replace(prefix, "")
|
||||
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
|
||||
build_config = await update_component_build_config(
|
||||
component_class, build_config, field_value, field_name
|
||||
)
|
||||
|
||||
return build_config
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class SubFlowComponent(CustomComponent):
|
|||
return flow_data
|
||||
return None
|
||||
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
logger.debug(f"Updating build config with field value {field_value} and field name {field_name}")
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = await self.get_flow_names()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class LMStudioEmbeddingsComponent(LCEmbeddingsModel):
|
|||
icon = "LMStudio"
|
||||
|
||||
@override
|
||||
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "model":
|
||||
base_url_dict = build_config.get("base_url", {})
|
||||
base_url_load_from_db = base_url_dict.get("load_from_db", False)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class FlowToolComponent(LCToolComponent):
|
|||
return None
|
||||
|
||||
@override
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = self.get_flow_names()
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class RunFlowComponent(Component):
|
|||
return [flow_data.data["name"] for flow_data in flow_data]
|
||||
|
||||
@override
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = await self.get_flow_names()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class SubFlowComponent(Component):
|
|||
return flow_data
|
||||
return None
|
||||
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = await self.get_flow_names()
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class LMStudioModelComponent(LCModelComponent):
|
|||
name = "LMStudioModel"
|
||||
|
||||
@override
|
||||
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "model_name":
|
||||
base_url_dict = build_config.get("base_url", {})
|
||||
base_url_load_from_db = base_url_dict.get("load_from_db", False)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
icon = "Ollama"
|
||||
name = "OllamaModel"
|
||||
|
||||
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "mirostat":
|
||||
if field_value == "Disabled":
|
||||
build_config["mirostat_eta"]["advanced"] = True
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ class PythonCodeStructuredTool(LCToolComponent):
|
|||
]
|
||||
|
||||
@override
|
||||
async def aupdate_build_config(
|
||||
async def update_build_config(
|
||||
self, build_config: dotdict, field_value: Any, field_name: str | None = None
|
||||
) -> dotdict:
|
||||
if field_name is None:
|
||||
|
|
@ -231,7 +231,7 @@ class PythonCodeStructuredTool(LCToolComponent):
|
|||
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
|
||||
"""This function is called after the code validation is done."""
|
||||
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
frontend_node["template"] = await self.aupdate_build_config(
|
||||
frontend_node["template"] = await self.update_build_config(
|
||||
frontend_node["template"],
|
||||
frontend_node["template"]["tool_code"]["value"],
|
||||
"tool_code",
|
||||
|
|
@ -240,7 +240,7 @@ class PythonCodeStructuredTool(LCToolComponent):
|
|||
for key in frontend_node["template"]:
|
||||
if key in self.DEFAULT_KEYS:
|
||||
continue
|
||||
frontend_node["template"] = await self.aupdate_build_config(
|
||||
frontend_node["template"] = await self.update_build_config(
|
||||
frontend_node["template"], frontend_node["template"][key]["value"], key
|
||||
)
|
||||
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
|
|
@ -233,19 +232,10 @@ class CustomComponent(BaseComponent):
|
|||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
if type(self).aupdate_build_config != CustomComponent.aupdate_build_config:
|
||||
raise NotImplementedError
|
||||
build_config[field_name]["value"] = field_value
|
||||
return build_config
|
||||
"""Updates the build configuration for the custom component.
|
||||
|
||||
async def aupdate_build_config(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
if type(self).update_build_config != CustomComponent.update_build_config:
|
||||
return await asyncio.to_thread(self.update_build_config, build_config, field_value, field_name)
|
||||
Do not call directly as implementation can be a coroutine.
|
||||
"""
|
||||
build_config[field_name]["value"] = field_value
|
||||
return build_config
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import ast
|
||||
import asyncio
|
||||
import contextlib
|
||||
import inspect
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
|
@ -547,3 +549,14 @@ def get_instance_name(instance):
|
|||
if hasattr(instance, "name") and instance.name:
|
||||
name = instance.name
|
||||
return name
|
||||
|
||||
|
||||
async def update_component_build_config(
|
||||
component: CustomComponent,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
if inspect.iscoroutinefunction(component.update_build_config):
|
||||
return await component.update_build_config(build_config, field_value, field_name)
|
||||
return await asyncio.to_thread(component.update_build_config, build_config, field_value, field_name)
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -49,7 +49,7 @@ async def test_update_build_config_mirostat_disabled(component):
|
|||
field_value = "Disabled"
|
||||
field_name = "mirostat"
|
||||
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["mirostat_eta"]["advanced"] is True
|
||||
assert updated_config["mirostat_tau"]["advanced"] is True
|
||||
|
|
@ -65,7 +65,7 @@ async def test_update_build_config_mirostat_enabled(component):
|
|||
field_value = "Mirostat 2.0"
|
||||
field_name = "mirostat"
|
||||
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["mirostat_eta"]["advanced"] is False
|
||||
assert updated_config["mirostat_tau"]["advanced"] is False
|
||||
|
|
@ -88,7 +88,7 @@ async def test_update_build_config_model_name(mock_get, component):
|
|||
field_value = None
|
||||
field_name = "model_name"
|
||||
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["model_name"]["options"] == ["model1", "model2"]
|
||||
|
||||
|
|
@ -98,12 +98,12 @@ async def test_update_build_config_keep_alive(component):
|
|||
field_value = "Keep"
|
||||
field_name = "keep_alive_flag"
|
||||
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.update_build_config(build_config, field_value, field_name)
|
||||
assert updated_config["keep_alive"]["value"] == "-1"
|
||||
assert updated_config["keep_alive"]["advanced"] is True
|
||||
|
||||
field_value = "Immediately"
|
||||
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
|
||||
updated_config = await component.update_build_config(build_config, field_value, field_name)
|
||||
assert updated_config["keep_alive"]["value"] == "0"
|
||||
assert updated_config["keep_alive"]["advanced"] is True
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langflow.components.custom_component import CustomComponent
|
|||
from langflow.components.inputs import ChatInput
|
||||
from langflow.components.models import OpenAIModelComponent
|
||||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.custom.utils import update_component_build_config
|
||||
from langflow.schema import dotdict
|
||||
from langflow.template import Output
|
||||
from typing_extensions import override
|
||||
|
|
@ -65,7 +66,7 @@ def test_set_required_inputs_various_components():
|
|||
assert _assert_all_outputs_have_different_required_inputs(agent.outputs)
|
||||
|
||||
|
||||
async def test_update_build_config_backward_compatibility():
|
||||
async def test_update_component_build_config_sync():
|
||||
class TestComponent(CustomComponent):
|
||||
@override
|
||||
def update_build_config(
|
||||
|
|
@ -79,5 +80,23 @@ async def test_update_build_config_backward_compatibility():
|
|||
|
||||
component = TestComponent()
|
||||
build_config = dotdict()
|
||||
build_config = await component.aupdate_build_config(build_config, "", "")
|
||||
build_config = await update_component_build_config(component, build_config, "", "")
|
||||
assert build_config["foo"] == "bar"
|
||||
|
||||
|
||||
async def test_update_component_build_config_async():
|
||||
class TestComponent(CustomComponent):
|
||||
@override
|
||||
async def update_build_config(
|
||||
self,
|
||||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
build_config["foo"] = "bar"
|
||||
return build_config
|
||||
|
||||
component = TestComponent()
|
||||
build_config = dotdict()
|
||||
build_config = await update_component_build_config(component, build_config, "", "")
|
||||
assert build_config["foo"] == "bar"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue