fix: Agent Component update_build_config, Nvidia and other model dynamic parameter loading (#4556)

* update in update_build_config

added support for dynamic models.

* [autofix.ci] apply automated fixes

* Add type checks for 'field_name' and 'prefix' in agent config update

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: Nadir J <31660040+NadirJ@users.noreply.github.com>
This commit is contained in:
Edwin Jose 2024-11-12 22:43:31 -05:00 committed by GitHub
commit 7b78a169d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 30 additions and 13 deletions

View file

@ -22,7 +22,7 @@ def set_advanced_true(component_input):
def create_input_fields_dict(inputs, prefix):
return {f"{prefix}_{input_.name}": input_ for input_ in inputs}
return {f"{prefix}{input_.name}": input_ for input_ in inputs}
OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent)
@ -35,11 +35,11 @@ AMAZON_BEDROCK_INPUTS = get_filtered_inputs(AmazonBedrockComponent)
OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS}
AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure")
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq")
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic")
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia")
AMAZON_BEDROCK_FIELDS = create_input_fields_dict(AMAZON_BEDROCK_INPUTS, "amazon_bedrock")
AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "")
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "")
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "")
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "")
AMAZON_BEDROCK_FIELDS = create_input_fields_dict(AMAZON_BEDROCK_INPUTS, "")
MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA", "Amazon Bedrock"]
@ -47,7 +47,7 @@ MODEL_PROVIDERS_DICT = {
"Azure OpenAI": {
"fields": AZURE_FIELDS,
"inputs": AZURE_INPUTS,
"prefix": "azure_",
"prefix": "",
"component_class": AzureChatOpenAIComponent(),
},
"OpenAI": {
@ -60,19 +60,19 @@ MODEL_PROVIDERS_DICT = {
"Anthropic": {
"fields": ANTHROPIC_FIELDS,
"inputs": ANTHROPIC_INPUTS,
"prefix": "anthropic_",
"prefix": "",
"component_class": AnthropicModelComponent(),
},
"NVIDIA": {
"fields": NVIDIA_FIELDS,
"inputs": NVIDIA_INPUTS,
"prefix": "nvidia_",
"prefix": "",
"component_class": NVIDIAModelComponent(),
},
"Amazon Bedrock": {
"fields": AMAZON_BEDROCK_FIELDS,
"inputs": AMAZON_BEDROCK_INPUTS,
"prefix": "amazon_bedrock_",
"prefix": "",
"component_class": AmazonBedrockComponent(),
},
}

View file

@ -129,8 +129,16 @@ class AgentComponent(ToolCallingAgentComponent):
return build_config
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict:
# Iterate over all providers in the MODEL_PROVIDERS_DICT
# Existing logic for updating build_config
if field_name == "agent_llm":
# Define provider configurations as (fields_to_add, fields_to_delete)
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
if provider_info:
component_class = provider_info.get("component_class")
if component_class and hasattr(component_class, "update_build_config"):
# Call the component class's update_build_config method
build_config = component_class.update_build_config(build_config, field_value, field_name)
provider_configs: dict[str, tuple[dict, list[dict]]] = {
provider: (
MODEL_PROVIDERS_DICT[provider]["fields"],
@ -142,7 +150,6 @@ class AgentComponent(ToolCallingAgentComponent):
)
for provider in MODEL_PROVIDERS_DICT
}
if field_value in provider_configs:
fields_to_add, fields_to_delete = provider_configs[field_value]
@ -170,7 +177,6 @@ class AgentComponent(ToolCallingAgentComponent):
input_types=["LanguageModel"],
)
build_config.update({"agent_llm": custom_component.to_dict()})
# Update input types for all fields
build_config = self.update_input_types(build_config)
@ -192,5 +198,16 @@ class AgentComponent(ToolCallingAgentComponent):
if missing_keys:
msg = f"Missing required keys in build_config: {missing_keys}"
raise ValueError(msg)
if isinstance(self.agent_llm, str) and self.agent_llm in MODEL_PROVIDERS_DICT:
provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm)
if provider_info:
component_class = provider_info.get("component_class")
prefix = provider_info.get("prefix")
if component_class and hasattr(component_class, "update_build_config"):
# Call each component class's update_build_config method
# remove the prefix from the field_name
if isinstance(field_name, str) and isinstance(prefix, str):
field_name = field_name.replace(prefix, "")
build_config = component_class.update_build_config(build_config, field_value, field_name)
return build_config