feat: Add Multi-Model Provider Support to Agent Component (#4416)

* Add Multi-Model Provider Support to Agent Component

- Integrated model provider constants from `model_input_constants.py` into the Agent component to support multiple LLM providers
- Added dynamic field management for different model providers (OpenAI, Azure, Groq, Anthropic, NVIDIA)
- Implemented a dropdown for model provider selection with automatic input field updates

* sorted list

* Update agent.py

making custom separate from sort

* chore: remove unit test

---------

Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
Edwin Jose 2024-11-06 11:15:22 -05:00 committed by GitHub
commit ae7d037fe1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 134 additions and 131 deletions

View file

@ -1,17 +0,0 @@
class ModelConstants:
"""Class to hold model-related constants. To solve circular import issue."""
PROVIDER_NAMES: list[str] = []
MODEL_INFO: dict[str, dict[str, str | list]] = {} # Adjusted type hint
@staticmethod
def initialize():
from langflow.base.models.model_utils import get_model_info # Delayed import
model_info = get_model_info()
ModelConstants.MODEL_INFO = model_info
ModelConstants.PROVIDER_NAMES = [
str(model.get("display_name"))
for model in model_info.values()
if isinstance(model.get("display_name"), str)
]

View file

@ -0,0 +1,71 @@
from langflow.base.models.model import LCModelComponent
from langflow.components.models.anthropic import AnthropicModelComponent
from langflow.components.models.azure_openai import AzureChatOpenAIComponent
from langflow.components.models.groq import GroqModel
from langflow.components.models.nvidia import NVIDIAModelComponent
from langflow.components.models.openai import OpenAIModelComponent
def get_filtered_inputs(component_class):
base_input_names = {field.name for field in LCModelComponent._base_inputs}
return [
set_advanced_true(input_) if input_.name == "temperature" else input_
for input_ in component_class().inputs
if input_.name not in base_input_names
]
def set_advanced_true(component_input):
component_input.advanced = True
return component_input
def create_input_fields_dict(inputs, prefix):
return {f"{prefix}_{input_.name}": input_ for input_ in inputs}
OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent)
AZURE_INPUTS = get_filtered_inputs(AzureChatOpenAIComponent)
GROQ_INPUTS = get_filtered_inputs(GroqModel)
ANTHROPIC_INPUTS = get_filtered_inputs(AnthropicModelComponent)
NVIDIA_INPUTS = get_filtered_inputs(NVIDIAModelComponent)
OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS}
AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure")
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq")
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic")
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia")
MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA"]
MODEL_PROVIDERS_DICT = {
"Azure OpenAI": {
"fields": AZURE_FIELDS,
"inputs": AZURE_INPUTS,
"prefix": "azure_",
"component_class": AzureChatOpenAIComponent(),
},
"OpenAI": {
"fields": OPENAI_FIELDS,
"inputs": OPENAI_INPUTS,
"prefix": "",
"component_class": OpenAIModelComponent(),
},
"Groq": {"fields": GROQ_FIELDS, "inputs": GROQ_INPUTS, "prefix": "groq_", "component_class": GroqModel()},
"Anthropic": {
"fields": ANTHROPIC_FIELDS,
"inputs": ANTHROPIC_INPUTS,
"prefix": "anthropic_",
"component_class": AnthropicModelComponent(),
},
"NVIDIA": {
"fields": NVIDIA_FIELDS,
"inputs": NVIDIA_INPUTS,
"prefix": "nvidia_",
"component_class": NVIDIAModelComponent(),
},
}
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]]

View file

@ -1,31 +0,0 @@
import importlib
from langflow.base.models.model import LCModelComponent
from langflow.inputs.inputs import InputTypes
def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]:
"""Get inputs for all model components."""
model_inputs = {}
models_module = importlib.import_module("langflow.components.models")
model_component_names = getattr(models_module, "__all__", [])
for name in model_component_names:
if name in ("base", "DynamicLLMComponent"): # Skip the base module
continue
component_class = getattr(models_module, name)
if issubclass(component_class, LCModelComponent):
component = component_class()
base_input_names = {input_field.name for input_field in LCModelComponent._base_inputs}
input_fields_list = [
input_field for input_field in component.inputs if input_field.name not in base_input_names
]
component_display_name = component.display_name
model_inputs[name] = {
"display_name": component_display_name,
"inputs": input_fields_list,
"icon": component.icon,
}
return model_inputs

View file

@ -1,9 +1,7 @@
from langflow.base.agents.agent import LCToolsAgentComponent
from langflow.base.models.model import LCModelComponent
from langflow.base.models.model_input_constants import ALL_PROVIDER_FIELDS, MODEL_PROVIDERS_DICT
from langflow.components.agents.tool_calling import ToolCallingAgentComponent
from langflow.components.helpers.memory import MemoryComponent
from langflow.components.models.azure_openai import AzureChatOpenAIComponent
from langflow.components.models.openai import OpenAIModelComponent
from langflow.io import (
DropdownInput,
MultilineInput,
@ -25,30 +23,19 @@ class AgentComponent(ToolCallingAgentComponent):
beta = True
name = "Agent"
azure_inputs = [
set_advanced_true(component_input) if component_input.name == "temperature" else component_input
for component_input in AzureChatOpenAIComponent().inputs
if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs]
]
openai_inputs = [
set_advanced_true(component_input) if component_input.name == "temperature" else component_input
for component_input in OpenAIModelComponent().inputs
if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs]
]
memory_inputs = [set_advanced_true(component_input) for component_input in MemoryComponent().inputs]
inputs = [
DropdownInput(
name="agent_llm",
display_name="Model Provider",
options=["Azure OpenAI", "OpenAI", "Custom"],
options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"],
value="OpenAI",
real_time_refresh=True,
refresh_button=True,
input_types=[],
),
*openai_inputs,
*MODEL_PROVIDERS_DICT["OpenAI"]["inputs"],
MultilineInput(
name="system_prompt",
display_name="Agent Instructions",
@ -86,71 +73,89 @@ class AgentComponent(ToolCallingAgentComponent):
return MemoryComponent().set(**memory_kwargs).retrieve_messages()
def get_llm(self):
try:
if self.agent_llm == "OpenAI":
return self._build_llm_model(OpenAIModelComponent(), self.openai_inputs)
if self.agent_llm == "Azure OpenAI":
return self._build_llm_model(AzureChatOpenAIComponent(), self.azure_inputs, prefix="azure_param_")
except Exception as e:
msg = f"Error building {self.agent_llm} language model"
raise ValueError(msg) from e
if isinstance(self.agent_llm, str):
try:
provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm)
if provider_info:
component_class = provider_info.get("component_class")
inputs = provider_info.get("inputs")
prefix = provider_info.get("prefix", "")
return self._build_llm_model(component_class, inputs, prefix)
except Exception as e:
msg = f"Error building {self.agent_llm} language model"
raise ValueError(msg) from e
return self.agent_llm
def _build_llm_model(self, component, inputs, prefix=""):
return component.set(
**{component_input.name: getattr(self, f"{prefix}{component_input.name}") for component_input in inputs}
).build_model()
model_kwargs = {input_.name: getattr(self, f"{prefix}{input_.name}") for input_ in inputs}
return component.set(**model_kwargs).build_model()
def delete_fields(self, build_config, fields):
def delete_fields(self, build_config: dotdict, fields: dict | list[str]) -> None:
"""Delete specified fields from build_config."""
for field in fields:
build_config.pop(field, None)
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None):
def update_input_types(self, build_config: dotdict) -> dotdict:
"""Update input types for all fields in build_config."""
for key, value in build_config.items():
if isinstance(value, dict):
if value.get("input_types") is None:
build_config[key]["input_types"] = []
elif hasattr(value, "input_types") and value.input_types is None:
value.input_types = []
return build_config
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict:
if field_name == "agent_llm":
openai_fields = {component_input.name: component_input for component_input in self.openai_inputs}
azure_fields = {
f"azure_param_{component_input.name}": component_input for component_input in self.azure_inputs
# Define provider configurations as (fields_to_add, fields_to_delete)
provider_configs: dict[str, tuple[dict, list[dict]]] = {
provider: (
MODEL_PROVIDERS_DICT[provider]["fields"],
[
MODEL_PROVIDERS_DICT[other_provider]["fields"]
for other_provider in MODEL_PROVIDERS_DICT
if other_provider != provider
],
)
for provider in MODEL_PROVIDERS_DICT
}
if field_value == "OpenAI":
self.delete_fields(build_config, {**azure_fields})
if not any(field in build_config for field in openai_fields):
build_config.update(openai_fields)
build_config["agent_llm"]["input_types"] = []
build_config = self.update_input_types(build_config)
if field_value in provider_configs:
fields_to_add, fields_to_delete = provider_configs[field_value]
elif field_value == "Azure OpenAI":
self.delete_fields(build_config, {**openai_fields})
build_config.update(azure_fields)
# Delete fields from other providers
for fields in fields_to_delete:
self.delete_fields(build_config, fields)
# Add provider-specific fields
if field_value == "OpenAI" and not any(field in build_config for field in fields_to_add):
build_config.update(fields_to_add)
else:
build_config.update(fields_to_add)
# Reset input types for agent_llm
build_config["agent_llm"]["input_types"] = []
build_config = self.update_input_types(build_config)
elif field_value == "Custom":
self.delete_fields(build_config, {**openai_fields})
self.delete_fields(build_config, {**azure_fields})
new_component = DropdownInput(
# Delete all provider fields
self.delete_fields(build_config, ALL_PROVIDER_FIELDS)
# Update with custom component
custom_component = DropdownInput(
name="agent_llm",
display_name="Language Model",
options=["Azure OpenAI", "OpenAI", "Custom"],
options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"],
value="Custom",
real_time_refresh=True,
input_types=["LanguageModel"],
)
build_config.update({"agent_llm": new_component.to_dict()})
build_config = self.update_input_types(build_config)
build_config.update({"agent_llm": custom_component.to_dict()})
# Update input types for all fields
build_config = self.update_input_types(build_config)
# Validate required keys
default_keys = ["code", "_type", "agent_llm", "tools", "input_value"]
missing_keys = [key for key in default_keys if key not in build_config]
if missing_keys:
msg = f"Missing required keys in build_config: {missing_keys}"
raise ValueError(msg)
return build_config
def update_input_types(self, build_config):
for key, value in build_config.items():
# Check if the value is a dictionary
if isinstance(value, dict):
if value.get("input_types") is None:
build_config[key]["input_types"] = []
# Check if the value has an attribute 'input_types' and it is None
elif hasattr(value, "input_types") and value.input_types is None:
value.input_types = []
return build_config

View file

@ -1,25 +0,0 @@
from src.backend.base.langflow.base.models.model_constants import ModelConstants
def test_provider_names():
# Initialize the ModelConstants
ModelConstants.initialize()
# Expected provider names
expected_provider_names = [
"AIML",
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Ollama",
"Vertex AI",
"Cohere",
"Google Generative AI",
"HuggingFace",
"OpenAI",
"Perplexity",
"Qianfan",
]
# Assert that the provider names match the expected list
assert expected_provider_names == ModelConstants.PROVIDER_NAMES