ref: Add and use update_component_build_config utility (#5226)

* Add and use update_component_build_config utility

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Christophe Bornet 2024-12-13 00:58:19 +01:00 committed by GitHub
commit 384ac5e80e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 78 additions and 50 deletions

View file

@ -33,7 +33,7 @@ from langflow.api.v1.schemas import (
UploadFileResponse,
)
from langflow.custom.custom_component.component import Component
from langflow.custom.utils import build_custom_component_template, get_instance_name
from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config
from langflow.exceptions.api import APIException, InvalidChatInputError
from langflow.graph.graph.base import Graph
from langflow.graph.schema import RunOutputs
@ -633,7 +633,8 @@ async def custom_component_update(
params = await update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields)
cc_instance.set_attributes(params)
updated_build_config = code_request.get_template()
await cc_instance.aupdate_build_config(
await update_component_build_config(
cc_instance,
build_config=updated_build_config,
field_value=code_request.field_value,
field_name=code_request.field,

View file

@ -11,6 +11,7 @@ from langflow.components.helpers.memory import MemoryComponent
from langflow.components.langchain_utilities.tool_calling import (
ToolCallingAgentComponent,
)
from langflow.custom.utils import update_component_build_config
from langflow.io import BoolInput, DropdownInput, MultilineInput, Output
from langflow.schema.dotdict import dotdict
from langflow.schema.message import Message
@ -136,7 +137,7 @@ class AgentComponent(ToolCallingAgentComponent):
value.input_types = []
return build_config
async def aupdate_build_config(
async def update_build_config(
self, build_config: dotdict, field_value: str, field_name: str | None = None
) -> dotdict:
# Iterate over all providers in the MODEL_PROVIDERS_DICT
@ -145,9 +146,11 @@ class AgentComponent(ToolCallingAgentComponent):
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
if provider_info:
component_class = provider_info.get("component_class")
if component_class and hasattr(component_class, "aupdate_build_config"):
# Call the component class's aupdate_build_config method
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
if component_class and hasattr(component_class, "update_build_config"):
# Call the component class's update_build_config method
build_config = await update_component_build_config(
component_class, build_config, field_value, field_name
)
provider_configs: dict[str, tuple[dict, list[dict]]] = {
provider: (
@ -213,11 +216,13 @@ class AgentComponent(ToolCallingAgentComponent):
if provider_info:
component_class = provider_info.get("component_class")
prefix = provider_info.get("prefix")
if component_class and hasattr(component_class, "aupdate_build_config"):
# Call each component class's aupdate_build_config method
if component_class and hasattr(component_class, "update_build_config"):
# Call each component class's update_build_config method
# remove the prefix from the field_name
if isinstance(field_name, str) and isinstance(prefix, str):
field_name = field_name.replace(prefix, "")
build_config = await component_class.aupdate_build_config(build_config, field_value, field_name)
build_config = await update_component_build_config(
component_class, build_config, field_value, field_name
)
return build_config

View file

@ -35,7 +35,7 @@ class SubFlowComponent(CustomComponent):
return flow_data
return None
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
logger.debug(f"Updating build config with field value {field_value} and field name {field_name}")
if field_name == "flow_name":
build_config["flow_name"]["options"] = await self.get_flow_names()

View file

@ -16,7 +16,7 @@ class LMStudioEmbeddingsComponent(LCEmbeddingsModel):
icon = "LMStudio"
@override
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "model":
base_url_dict = build_config.get("base_url", {})
base_url_load_from_db = base_url_dict.get("load_from_db", False)

View file

@ -42,7 +42,7 @@ class FlowToolComponent(LCToolComponent):
return None
@override
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "flow_name":
build_config["flow_name"]["options"] = self.get_flow_names()

View file

@ -23,7 +23,7 @@ class RunFlowComponent(Component):
return [flow_data.data["name"] for flow_data in flow_data]
@override
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "flow_name":
build_config["flow_name"]["options"] = await self.get_flow_names()

View file

@ -29,7 +29,7 @@ class SubFlowComponent(Component):
return flow_data
return None
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "flow_name":
build_config["flow_name"]["options"] = await self.get_flow_names()

View file

@ -19,7 +19,7 @@ class LMStudioModelComponent(LCModelComponent):
name = "LMStudioModel"
@override
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "model_name":
base_url_dict = build_config.get("base_url", {})
base_url_load_from_db = base_url_dict.get("load_from_db", False)

View file

@ -16,7 +16,7 @@ class ChatOllamaComponent(LCModelComponent):
icon = "Ollama"
name = "OllamaModel"
async def aupdate_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "mirostat":
if field_value == "Disabled":
build_config["mirostat_eta"]["advanced"] = True

View file

@ -98,7 +98,7 @@ class PythonCodeStructuredTool(LCToolComponent):
]
@override
async def aupdate_build_config(
async def update_build_config(
self, build_config: dotdict, field_value: Any, field_name: str | None = None
) -> dotdict:
if field_name is None:
@ -231,7 +231,7 @@ class PythonCodeStructuredTool(LCToolComponent):
async def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
"""This function is called after the code validation is done."""
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)
frontend_node["template"] = await self.aupdate_build_config(
frontend_node["template"] = await self.update_build_config(
frontend_node["template"],
frontend_node["template"]["tool_code"]["value"],
"tool_code",
@ -240,7 +240,7 @@ class PythonCodeStructuredTool(LCToolComponent):
for key in frontend_node["template"]:
if key in self.DEFAULT_KEYS:
continue
frontend_node["template"] = await self.aupdate_build_config(
frontend_node["template"] = await self.update_build_config(
frontend_node["template"], frontend_node["template"][key]["value"], key
)
frontend_node = await super().post_code_processing(new_frontend_node, current_frontend_node)

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import uuid
from collections.abc import Callable, Sequence
from pathlib import Path
@ -233,19 +232,10 @@ class CustomComponent(BaseComponent):
field_value: Any,
field_name: str | None = None,
):
if type(self).aupdate_build_config != CustomComponent.aupdate_build_config:
raise NotImplementedError
build_config[field_name]["value"] = field_value
return build_config
"""Updates the build configuration for the custom component.
async def aupdate_build_config(
self,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
if type(self).update_build_config != CustomComponent.update_build_config:
return await asyncio.to_thread(self.update_build_config, build_config, field_value, field_name)
Do not call directly as implementation can be a coroutine.
"""
build_config[field_name]["value"] = field_value
return build_config

View file

@ -1,5 +1,7 @@
import ast
import asyncio
import contextlib
import inspect
import re
import traceback
from typing import Any
@ -547,3 +549,14 @@ def get_instance_name(instance):
if hasattr(instance, "name") and instance.name:
name = instance.name
return name
async def update_component_build_config(
component: CustomComponent,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
if inspect.iscoroutinefunction(component.update_build_config):
return await component.update_build_config(build_config, field_value, field_name)
return await asyncio.to_thread(component.update_build_config, build_config, field_value, field_name)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -49,7 +49,7 @@ async def test_update_build_config_mirostat_disabled(component):
field_value = "Disabled"
field_name = "mirostat"
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is True
assert updated_config["mirostat_tau"]["advanced"] is True
@ -65,7 +65,7 @@ async def test_update_build_config_mirostat_enabled(component):
field_value = "Mirostat 2.0"
field_name = "mirostat"
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is False
assert updated_config["mirostat_tau"]["advanced"] is False
@ -88,7 +88,7 @@ async def test_update_build_config_model_name(mock_get, component):
field_value = None
field_name = "model_name"
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["model_name"]["options"] == ["model1", "model2"]
@ -98,12 +98,12 @@ async def test_update_build_config_keep_alive(component):
field_value = "Keep"
field_name = "keep_alive_flag"
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "-1"
assert updated_config["keep_alive"]["advanced"] is True
field_value = "Immediately"
updated_config = await component.aupdate_build_config(build_config, field_value, field_name)
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "0"
assert updated_config["keep_alive"]["advanced"] is True

View file

@ -7,6 +7,7 @@ from langflow.components.custom_component import CustomComponent
from langflow.components.inputs import ChatInput
from langflow.components.models import OpenAIModelComponent
from langflow.components.outputs import ChatOutput
from langflow.custom.utils import update_component_build_config
from langflow.schema import dotdict
from langflow.template import Output
from typing_extensions import override
@ -65,7 +66,7 @@ def test_set_required_inputs_various_components():
assert _assert_all_outputs_have_different_required_inputs(agent.outputs)
async def test_update_build_config_backward_compatibility():
async def test_update_component_build_config_sync():
class TestComponent(CustomComponent):
@override
def update_build_config(
@ -79,5 +80,23 @@ async def test_update_build_config_backward_compatibility():
component = TestComponent()
build_config = dotdict()
build_config = await component.aupdate_build_config(build_config, "", "")
build_config = await update_component_build_config(component, build_config, "", "")
assert build_config["foo"] == "bar"
async def test_update_component_build_config_async():
class TestComponent(CustomComponent):
@override
async def update_build_config(
self,
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
build_config["foo"] = "bar"
return build_config
component = TestComponent()
build_config = dotdict()
build_config = await update_component_build_config(component, build_config, "", "")
assert build_config["foo"] == "bar"