fix: Agent Component update_build_config, Nvidia and other model dynamic parameter loading (#4556)
* update in update_build_config added support for dynamic models. * [autofix.ci] apply automated fixes * Add type checks for 'field_name' and 'prefix' in agent config update --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> Co-authored-by: Nadir J <31660040+NadirJ@users.noreply.github.com>
This commit is contained in:
parent
0abc51d75a
commit
7b78a169d0
2 changed files with 30 additions and 13 deletions
|
|
@ -22,7 +22,7 @@ def set_advanced_true(component_input):
|
|||
|
||||
|
||||
def create_input_fields_dict(inputs, prefix):
|
||||
return {f"{prefix}_{input_.name}": input_ for input_ in inputs}
|
||||
return {f"{prefix}{input_.name}": input_ for input_ in inputs}
|
||||
|
||||
|
||||
OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent)
|
||||
|
|
@ -35,11 +35,11 @@ AMAZON_BEDROCK_INPUTS = get_filtered_inputs(AmazonBedrockComponent)
|
|||
OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS}
|
||||
|
||||
|
||||
AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure")
|
||||
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq")
|
||||
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic")
|
||||
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia")
|
||||
AMAZON_BEDROCK_FIELDS = create_input_fields_dict(AMAZON_BEDROCK_INPUTS, "amazon_bedrock")
|
||||
AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "")
|
||||
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "")
|
||||
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "")
|
||||
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "")
|
||||
AMAZON_BEDROCK_FIELDS = create_input_fields_dict(AMAZON_BEDROCK_INPUTS, "")
|
||||
|
||||
MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA", "Amazon Bedrock"]
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ MODEL_PROVIDERS_DICT = {
|
|||
"Azure OpenAI": {
|
||||
"fields": AZURE_FIELDS,
|
||||
"inputs": AZURE_INPUTS,
|
||||
"prefix": "azure_",
|
||||
"prefix": "",
|
||||
"component_class": AzureChatOpenAIComponent(),
|
||||
},
|
||||
"OpenAI": {
|
||||
|
|
@ -60,19 +60,19 @@ MODEL_PROVIDERS_DICT = {
|
|||
"Anthropic": {
|
||||
"fields": ANTHROPIC_FIELDS,
|
||||
"inputs": ANTHROPIC_INPUTS,
|
||||
"prefix": "anthropic_",
|
||||
"prefix": "",
|
||||
"component_class": AnthropicModelComponent(),
|
||||
},
|
||||
"NVIDIA": {
|
||||
"fields": NVIDIA_FIELDS,
|
||||
"inputs": NVIDIA_INPUTS,
|
||||
"prefix": "nvidia_",
|
||||
"prefix": "",
|
||||
"component_class": NVIDIAModelComponent(),
|
||||
},
|
||||
"Amazon Bedrock": {
|
||||
"fields": AMAZON_BEDROCK_FIELDS,
|
||||
"inputs": AMAZON_BEDROCK_INPUTS,
|
||||
"prefix": "amazon_bedrock_",
|
||||
"prefix": "",
|
||||
"component_class": AmazonBedrockComponent(),
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,8 +129,16 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
return build_config
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict:
|
||||
# Iterate over all providers in the MODEL_PROVIDERS_DICT
|
||||
# Existing logic for updating build_config
|
||||
if field_name == "agent_llm":
|
||||
# Define provider configurations as (fields_to_add, fields_to_delete)
|
||||
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
|
||||
if provider_info:
|
||||
component_class = provider_info.get("component_class")
|
||||
if component_class and hasattr(component_class, "update_build_config"):
|
||||
# Call the component class's update_build_config method
|
||||
build_config = component_class.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
provider_configs: dict[str, tuple[dict, list[dict]]] = {
|
||||
provider: (
|
||||
MODEL_PROVIDERS_DICT[provider]["fields"],
|
||||
|
|
@ -142,7 +150,6 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
)
|
||||
for provider in MODEL_PROVIDERS_DICT
|
||||
}
|
||||
|
||||
if field_value in provider_configs:
|
||||
fields_to_add, fields_to_delete = provider_configs[field_value]
|
||||
|
||||
|
|
@ -170,7 +177,6 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
input_types=["LanguageModel"],
|
||||
)
|
||||
build_config.update({"agent_llm": custom_component.to_dict()})
|
||||
|
||||
# Update input types for all fields
|
||||
build_config = self.update_input_types(build_config)
|
||||
|
||||
|
|
@ -192,5 +198,16 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
if missing_keys:
|
||||
msg = f"Missing required keys in build_config: {missing_keys}"
|
||||
raise ValueError(msg)
|
||||
if isinstance(self.agent_llm, str) and self.agent_llm in MODEL_PROVIDERS_DICT:
|
||||
provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm)
|
||||
if provider_info:
|
||||
component_class = provider_info.get("component_class")
|
||||
prefix = provider_info.get("prefix")
|
||||
if component_class and hasattr(component_class, "update_build_config"):
|
||||
# Call each component class's update_build_config method
|
||||
# remove the prefix from the field_name
|
||||
if isinstance(field_name, str) and isinstance(prefix, str):
|
||||
field_name = field_name.replace(prefix, "")
|
||||
build_config = component_class.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
return build_config
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue