diff --git a/src/backend/base/langflow/base/models/model_input_constants.py b/src/backend/base/langflow/base/models/model_input_constants.py index ae9b79133..1753c4a4d 100644 --- a/src/backend/base/langflow/base/models/model_input_constants.py +++ b/src/backend/base/langflow/base/models/model_input_constants.py @@ -22,7 +22,7 @@ def set_advanced_true(component_input): def create_input_fields_dict(inputs, prefix): - return {f"{prefix}_{input_.name}": input_ for input_ in inputs} + return {f"{prefix}{input_.name}": input_ for input_ in inputs} OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent) @@ -35,11 +35,11 @@ AMAZON_BEDROCK_INPUTS = get_filtered_inputs(AmazonBedrockComponent) OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS} -AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure") -GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq") -ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic") -NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia") -AMAZON_BEDROCK_FIELDS = create_input_fields_dict(AMAZON_BEDROCK_INPUTS, "amazon_bedrock") +AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "") +GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "") +ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "") +NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "") +AMAZON_BEDROCK_FIELDS = create_input_fields_dict(AMAZON_BEDROCK_INPUTS, "") MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA", "Amazon Bedrock"] @@ -47,7 +47,7 @@ MODEL_PROVIDERS_DICT = { "Azure OpenAI": { "fields": AZURE_FIELDS, "inputs": AZURE_INPUTS, - "prefix": "azure_", + "prefix": "", "component_class": AzureChatOpenAIComponent(), }, "OpenAI": { @@ -60,19 +60,19 @@ MODEL_PROVIDERS_DICT = { "Anthropic": { "fields": ANTHROPIC_FIELDS, "inputs": ANTHROPIC_INPUTS, - "prefix": "anthropic_", + "prefix": "", "component_class": AnthropicModelComponent(), }, "NVIDIA": { "fields": NVIDIA_FIELDS, "inputs": NVIDIA_INPUTS, - "prefix": "nvidia_", + "prefix": "", "component_class": NVIDIAModelComponent(), }, "Amazon Bedrock": { "fields": AMAZON_BEDROCK_FIELDS, "inputs": AMAZON_BEDROCK_INPUTS, - "prefix": "amazon_bedrock_", + "prefix": "", "component_class": AmazonBedrockComponent(), }, } diff --git a/src/backend/base/langflow/components/agents/agent.py b/src/backend/base/langflow/components/agents/agent.py index e23876dd1..62b743795 100644 --- a/src/backend/base/langflow/components/agents/agent.py +++ b/src/backend/base/langflow/components/agents/agent.py @@ -129,8 +129,16 @@ class AgentComponent(ToolCallingAgentComponent): return build_config def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict: + # Iterate over all providers in the MODEL_PROVIDERS_DICT + # Existing logic for updating build_config if field_name == "agent_llm": - # Define provider configurations as (fields_to_add, fields_to_delete) + provider_info = MODEL_PROVIDERS_DICT.get(field_value) + if provider_info: + component_class = provider_info.get("component_class") + if component_class and hasattr(component_class, "update_build_config"): + # Call the component class's update_build_config method + build_config = component_class.update_build_config(build_config, field_value, field_name) + provider_configs: dict[str, tuple[dict, list[dict]]] = { provider: ( MODEL_PROVIDERS_DICT[provider]["fields"], @@ -142,7 +150,6 @@ class AgentComponent(ToolCallingAgentComponent): ) for provider in MODEL_PROVIDERS_DICT } - if field_value in provider_configs: fields_to_add, fields_to_delete = provider_configs[field_value] @@ -170,7 +177,6 @@ class AgentComponent(ToolCallingAgentComponent): input_types=["LanguageModel"], ) build_config.update({"agent_llm": custom_component.to_dict()}) - # Update input types for all fields build_config = self.update_input_types(build_config) @@ -192,5 +198,16 @@ class AgentComponent(ToolCallingAgentComponent): if missing_keys: msg = f"Missing required keys in build_config: {missing_keys}" raise ValueError(msg) + if isinstance(self.agent_llm, str) and self.agent_llm in MODEL_PROVIDERS_DICT: + provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm) + if provider_info: + component_class = provider_info.get("component_class") + prefix = provider_info.get("prefix") + if component_class and hasattr(component_class, "update_build_config"): + # Call each component class's update_build_config method + # remove the prefix from the field_name + if isinstance(field_name, str) and isinstance(prefix, str): + field_name = field_name.replace(prefix, "") + build_config = component_class.update_build_config(build_config, field_value, field_name) return build_config