feat: Add Multi-Model Provider Support to Agent Component (#4416)
* Add Multi-Model Provider Support to Agent Component
  - Integrated model provider constants from `model_input_constants.py` into the Agent component to support multiple LLM providers
  - Added dynamic field management for different model providers (OpenAI, Azure, Groq, Anthropic, NVIDIA)
  - Implemented a dropdown for model provider selection with automatic input field updates
* sorted list
* Update agent.py: keep the "Custom" option separate from the sorted provider list
* chore: remove unit test
---------
Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
parent
16924958da
commit
ae7d037fe1
6 changed files with 134 additions and 131 deletions
|
|
@@ -1,17 +0,0 @@
|
|||
class ModelConstants:
|
||||
"""Class to hold model-related constants. To solve circular import issue."""
|
||||
|
||||
PROVIDER_NAMES: list[str] = []
|
||||
MODEL_INFO: dict[str, dict[str, str | list]] = {} # Adjusted type hint
|
||||
|
||||
@staticmethod
|
||||
def initialize():
|
||||
from langflow.base.models.model_utils import get_model_info # Delayed import
|
||||
|
||||
model_info = get_model_info()
|
||||
ModelConstants.MODEL_INFO = model_info
|
||||
ModelConstants.PROVIDER_NAMES = [
|
||||
str(model.get("display_name"))
|
||||
for model in model_info.values()
|
||||
if isinstance(model.get("display_name"), str)
|
||||
]
|
||||
|
|
@@ -0,0 +1,71 @@
|
|||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.components.models.anthropic import AnthropicModelComponent
|
||||
from langflow.components.models.azure_openai import AzureChatOpenAIComponent
|
||||
from langflow.components.models.groq import GroqModel
|
||||
from langflow.components.models.nvidia import NVIDIAModelComponent
|
||||
from langflow.components.models.openai import OpenAIModelComponent
|
||||
|
||||
|
||||
def get_filtered_inputs(component_class):
    """Return *component_class*'s inputs minus the shared base inputs.

    The "temperature" input, when present, is marked advanced before being
    included, so it is hidden behind the advanced toggle in the UI.
    """
    base_names = {base_field.name for base_field in LCModelComponent._base_inputs}
    filtered = []
    for candidate in component_class().inputs:
        if candidate.name in base_names:
            continue
        if candidate.name == "temperature":
            candidate = set_advanced_true(candidate)
        filtered.append(candidate)
    return filtered
|
||||
|
||||
|
||||
def set_advanced_true(component_input):
    """Flag the given input as advanced, in place.

    Returns the same object so the call can be used inline in comprehensions.
    """
    component_input.advanced = True
    return component_input
|
||||
|
||||
|
||||
def create_input_fields_dict(inputs, prefix):
    """Key each input by ``<prefix>_<input name>`` and return the mapping."""
    fields = {}
    for model_input in inputs:
        fields[f"{prefix}_{model_input.name}"] = model_input
    return fields
|
||||
|
||||
|
||||
# Provider-specific inputs: each component's own inputs with the shared
# LCModelComponent base inputs removed and "temperature" marked advanced
# (see get_filtered_inputs).
OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent)
AZURE_INPUTS = get_filtered_inputs(AzureChatOpenAIComponent)
GROQ_INPUTS = get_filtered_inputs(GroqModel)
ANTHROPIC_INPUTS = get_filtered_inputs(AnthropicModelComponent)
NVIDIA_INPUTS = get_filtered_inputs(NVIDIAModelComponent)


# OpenAI is the default provider, so its fields keep their bare input names.
# create_input_fields_dict is not used here because it would prepend "_"
# (its key format is f"{prefix}_{name}", which with an empty prefix yields "_name").
OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS}


# Other providers' fields are namespaced as "<prefix>_<input name>" so that
# several providers' fields can coexist in a build_config without clashing.
AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure")
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq")
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic")
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia")

MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA"]

# Per-provider registry: the prefixed field dict, the raw inputs, the
# attribute prefix used to read values back off the component (note the
# trailing underscore, matching the keys in "fields"; empty for OpenAI),
# and a ready-made component instance used to build the model.
MODEL_PROVIDERS_DICT = {
    "Azure OpenAI": {
        "fields": AZURE_FIELDS,
        "inputs": AZURE_INPUTS,
        "prefix": "azure_",
        "component_class": AzureChatOpenAIComponent(),
    },
    "OpenAI": {
        "fields": OPENAI_FIELDS,
        "inputs": OPENAI_INPUTS,
        "prefix": "",
        "component_class": OpenAIModelComponent(),
    },
    "Groq": {"fields": GROQ_FIELDS, "inputs": GROQ_INPUTS, "prefix": "groq_", "component_class": GroqModel()},
    "Anthropic": {
        "fields": ANTHROPIC_FIELDS,
        "inputs": ANTHROPIC_INPUTS,
        "prefix": "anthropic_",
        "component_class": AnthropicModelComponent(),
    },
    "NVIDIA": {
        "fields": NVIDIA_FIELDS,
        "inputs": NVIDIA_INPUTS,
        "prefix": "nvidia_",
        "component_class": NVIDIAModelComponent(),
    },
}
# Flat list of every provider field name across all providers; used to strip
# all provider-specific fields from a build_config in one pass.
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]]
|
||||
|
|
@@ -1,31 +0,0 @@
|
|||
import importlib
|
||||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
|
||||
|
||||
def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]:
    """Collect display name, inputs, and icon for every exported model component."""
    info: dict[str, dict[str, str | list[InputTypes]]] = {}
    models_module = importlib.import_module("langflow.components.models")
    exported_names = getattr(models_module, "__all__", [])

    for export_name in exported_names:
        # "base" is the base module and DynamicLLMComponent is not a concrete provider.
        if export_name in ("base", "DynamicLLMComponent"):
            continue

        candidate = getattr(models_module, export_name)
        if not issubclass(candidate, LCModelComponent):
            continue

        instance = candidate()
        # Drop the inputs shared by every model component; keep only the
        # provider-specific ones.
        shared_names = {shared_field.name for shared_field in LCModelComponent._base_inputs}
        specific_inputs = [field for field in instance.inputs if field.name not in shared_names]
        info[export_name] = {
            "display_name": instance.display_name,
            "inputs": specific_inputs,
            "icon": instance.icon,
        }

    return info
|
||||
|
|
@@ -1,9 +1,7 @@
|
|||
from langflow.base.agents.agent import LCToolsAgentComponent
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.base.models.model_input_constants import ALL_PROVIDER_FIELDS, MODEL_PROVIDERS_DICT
|
||||
from langflow.components.agents.tool_calling import ToolCallingAgentComponent
|
||||
from langflow.components.helpers.memory import MemoryComponent
|
||||
from langflow.components.models.azure_openai import AzureChatOpenAIComponent
|
||||
from langflow.components.models.openai import OpenAIModelComponent
|
||||
from langflow.io import (
|
||||
DropdownInput,
|
||||
MultilineInput,
|
||||
|
|
@@ -25,30 +23,19 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
beta = True
|
||||
name = "Agent"
|
||||
|
||||
azure_inputs = [
|
||||
set_advanced_true(component_input) if component_input.name == "temperature" else component_input
|
||||
for component_input in AzureChatOpenAIComponent().inputs
|
||||
if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs]
|
||||
]
|
||||
openai_inputs = [
|
||||
set_advanced_true(component_input) if component_input.name == "temperature" else component_input
|
||||
for component_input in OpenAIModelComponent().inputs
|
||||
if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs]
|
||||
]
|
||||
|
||||
memory_inputs = [set_advanced_true(component_input) for component_input in MemoryComponent().inputs]
|
||||
|
||||
inputs = [
|
||||
DropdownInput(
|
||||
name="agent_llm",
|
||||
display_name="Model Provider",
|
||||
options=["Azure OpenAI", "OpenAI", "Custom"],
|
||||
options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"],
|
||||
value="OpenAI",
|
||||
real_time_refresh=True,
|
||||
refresh_button=True,
|
||||
input_types=[],
|
||||
),
|
||||
*openai_inputs,
|
||||
*MODEL_PROVIDERS_DICT["OpenAI"]["inputs"],
|
||||
MultilineInput(
|
||||
name="system_prompt",
|
||||
display_name="Agent Instructions",
|
||||
|
|
@@ -86,71 +73,89 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
return MemoryComponent().set(**memory_kwargs).retrieve_messages()
|
||||
|
||||
def get_llm(self):
    """Resolve ``agent_llm`` into a built language model.

    When ``agent_llm`` is a provider name, the matching MODEL_PROVIDERS_DICT
    entry supplies the component instance, its inputs, and the attribute
    prefix used to read configured values off ``self``. An unrecognized
    provider string falls through and is returned unchanged; a non-string
    value is also returned as-is (presumably an already-connected model from
    the "Custom" option — confirm against the dropdown's input_types).

    Raises:
        ValueError: if building the provider's model fails, chaining the
            original exception.
    """
    if isinstance(self.agent_llm, str):
        try:
            provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm)
            if provider_info:
                component_class = provider_info.get("component_class")
                inputs = provider_info.get("inputs")
                prefix = provider_info.get("prefix", "")
                return self._build_llm_model(component_class, inputs, prefix)
        except Exception as e:
            # Wrap any provider-specific failure with the provider name for context.
            msg = f"Error building {self.agent_llm} language model"
            raise ValueError(msg) from e
    return self.agent_llm
|
||||
|
||||
def _build_llm_model(self, component, inputs, prefix=""):
    """Copy the agent's configured values onto *component* and build the model.

    Each input's value is read from ``self`` under the attribute name
    ``f"{prefix}{input name}"`` — the same naming used by the provider
    field dicts in model_input_constants.
    """
    model_kwargs = {input_.name: getattr(self, f"{prefix}{input_.name}") for input_ in inputs}
    return component.set(**model_kwargs).build_model()
|
||||
|
||||
def delete_fields(self, build_config: dotdict, fields: dict | list[str]) -> None:
    """Delete specified fields from build_config."""
    # pop with a default so keys that are already absent are ignored
    # rather than raising KeyError.
    for field in fields:
        build_config.pop(field, None)
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None):
|
||||
def update_input_types(self, build_config: dotdict) -> dotdict:
    """Update input types for all fields in build_config.

    Normalizes any ``input_types`` that is None to an empty list; fields may
    appear either as serialized dicts or as live input objects, and both
    shapes are handled.
    """
    for key, value in build_config.items():
        if isinstance(value, dict):
            # Serialized field: fix the dict entry in place.
            if value.get("input_types") is None:
                build_config[key]["input_types"] = []
        elif hasattr(value, "input_types") and value.input_types is None:
            # Live input object: mutate its attribute directly.
            value.input_types = []
    return build_config
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict:
    """Rebuild provider-specific inputs when the ``agent_llm`` dropdown changes.

    Selecting a provider removes every other provider's fields and installs
    that provider's own; selecting "Custom" removes all provider fields and
    replaces the dropdown with one that accepts a LanguageModel connection.

    Raises:
        ValueError: if any required base key is missing from build_config
            after the update.
    """
    if field_name == "agent_llm":
        # Define provider configurations as (fields_to_add, fields_to_delete)
        provider_configs: dict[str, tuple[dict, list[dict]]] = {
            provider: (
                MODEL_PROVIDERS_DICT[provider]["fields"],
                [
                    MODEL_PROVIDERS_DICT[other_provider]["fields"]
                    for other_provider in MODEL_PROVIDERS_DICT
                    if other_provider != provider
                ],
            )
            for provider in MODEL_PROVIDERS_DICT
        }

        if field_value in provider_configs:
            fields_to_add, fields_to_delete = provider_configs[field_value]

            # Delete fields from other providers
            for fields in fields_to_delete:
                self.delete_fields(build_config, fields)

            # Add the selected provider's fields. (The previous code special-
            # cased "OpenAI" with an `if ... else` whose two branches performed
            # this identical update, so the conditional was dead and has been
            # collapsed.)
            build_config.update(fields_to_add)

            # Reset input types for agent_llm
            build_config["agent_llm"]["input_types"] = []
        elif field_value == "Custom":
            # Delete all provider fields
            self.delete_fields(build_config, ALL_PROVIDER_FIELDS)
            # Replace the dropdown with one accepting a LanguageModel connection
            custom_component = DropdownInput(
                name="agent_llm",
                display_name="Language Model",
                options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"],
                value="Custom",
                real_time_refresh=True,
                input_types=["LanguageModel"],
            )
            build_config.update({"agent_llm": custom_component.to_dict()})

        # Update input types for all fields
        build_config = self.update_input_types(build_config)

        # Validate required keys
        default_keys = ["code", "_type", "agent_llm", "tools", "input_value"]
        missing_keys = [key for key in default_keys if key not in build_config]
        if missing_keys:
            msg = f"Missing required keys in build_config: {missing_keys}"
            raise ValueError(msg)
    return build_config
|
||||
|
||||
def update_input_types(self, build_config):
    """Normalize ``input_types`` to an empty list for every field in build_config."""
    for key, value in build_config.items():
        # Check if the value is a dictionary
        if isinstance(value, dict):
            if value.get("input_types") is None:
                build_config[key]["input_types"] = []
        # Check if the value has an attribute 'input_types' and it is None
        elif hasattr(value, "input_types") and value.input_types is None:
            value.input_types = []
    return build_config
|
||||
|
|
|
|||
|
|
@@ -1,25 +0,0 @@
|
|||
from src.backend.base.langflow.base.models.model_constants import ModelConstants
|
||||
|
||||
|
||||
def test_provider_names():
    """ModelConstants.initialize() should publish the expected provider names."""
    ModelConstants.initialize()

    expected = [
        "AIML",
        "Amazon Bedrock",
        "Anthropic",
        "Azure OpenAI",
        "Ollama",
        "Vertex AI",
        "Cohere",
        "Google Generative AI",
        "HuggingFace",
        "OpenAI",
        "Perplexity",
        "Qianfan",
    ]

    # Order matters: PROVIDER_NAMES must match the expected list exactly.
    assert ModelConstants.PROVIDER_NAMES == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue