diff --git a/src/backend/base/langflow/base/models/model_constants.py b/src/backend/base/langflow/base/models/model_constants.py deleted file mode 100644 index aa1f05766..000000000 --- a/src/backend/base/langflow/base/models/model_constants.py +++ /dev/null @@ -1,17 +0,0 @@ -class ModelConstants: - """Class to hold model-related constants. To solve circular import issue.""" - - PROVIDER_NAMES: list[str] = [] - MODEL_INFO: dict[str, dict[str, str | list]] = {} # Adjusted type hint - - @staticmethod - def initialize(): - from langflow.base.models.model_utils import get_model_info # Delayed import - - model_info = get_model_info() - ModelConstants.MODEL_INFO = model_info - ModelConstants.PROVIDER_NAMES = [ - str(model.get("display_name")) - for model in model_info.values() - if isinstance(model.get("display_name"), str) - ] diff --git a/src/backend/base/langflow/base/models/model_input_constants.py b/src/backend/base/langflow/base/models/model_input_constants.py new file mode 100644 index 000000000..83099d44e --- /dev/null +++ b/src/backend/base/langflow/base/models/model_input_constants.py @@ -0,0 +1,71 @@ +from langflow.base.models.model import LCModelComponent +from langflow.components.models.anthropic import AnthropicModelComponent +from langflow.components.models.azure_openai import AzureChatOpenAIComponent +from langflow.components.models.groq import GroqModel +from langflow.components.models.nvidia import NVIDIAModelComponent +from langflow.components.models.openai import OpenAIModelComponent + + +def get_filtered_inputs(component_class): + base_input_names = {field.name for field in LCModelComponent._base_inputs} + return [ + set_advanced_true(input_) if input_.name == "temperature" else input_ + for input_ in component_class().inputs + if input_.name not in base_input_names + ] + + +def set_advanced_true(component_input): + component_input.advanced = True + return component_input + + +def create_input_fields_dict(inputs, prefix): + return {f"{prefix}_{input_.name}": input_ for input_ in inputs} + + +OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent) +AZURE_INPUTS = get_filtered_inputs(AzureChatOpenAIComponent) +GROQ_INPUTS = get_filtered_inputs(GroqModel) +ANTHROPIC_INPUTS = get_filtered_inputs(AnthropicModelComponent) +NVIDIA_INPUTS = get_filtered_inputs(NVIDIAModelComponent) + + +OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS} + + +AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure") +GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq") +ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic") +NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia") + +MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA"] + +MODEL_PROVIDERS_DICT = { + "Azure OpenAI": { + "fields": AZURE_FIELDS, + "inputs": AZURE_INPUTS, + "prefix": "azure_", + "component_class": AzureChatOpenAIComponent(), + }, + "OpenAI": { + "fields": OPENAI_FIELDS, + "inputs": OPENAI_INPUTS, + "prefix": "", + "component_class": OpenAIModelComponent(), + }, + "Groq": {"fields": GROQ_FIELDS, "inputs": GROQ_INPUTS, "prefix": "groq_", "component_class": GroqModel()}, + "Anthropic": { + "fields": ANTHROPIC_FIELDS, + "inputs": ANTHROPIC_INPUTS, + "prefix": "anthropic_", + "component_class": AnthropicModelComponent(), + }, + "NVIDIA": { + "fields": NVIDIA_FIELDS, + "inputs": NVIDIA_INPUTS, + "prefix": "nvidia_", + "component_class": NVIDIAModelComponent(), + }, +} +ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]] diff --git a/src/backend/base/langflow/base/models/model_utils.py b/src/backend/base/langflow/base/models/model_utils.py deleted file mode 100644 index 9157d873c..000000000 --- a/src/backend/base/langflow/base/models/model_utils.py +++ /dev/null @@ -1,31 +0,0 @@ -import importlib - -from langflow.base.models.model import LCModelComponent -from langflow.inputs.inputs import InputTypes - - -def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]: - """Get inputs for all model components.""" - model_inputs = {} - models_module = importlib.import_module("langflow.components.models") - model_component_names = getattr(models_module, "__all__", []) - - for name in model_component_names: - if name in ("base", "DynamicLLMComponent"): # Skip the base module - continue - - component_class = getattr(models_module, name) - if issubclass(component_class, LCModelComponent): - component = component_class() - base_input_names = {input_field.name for input_field in LCModelComponent._base_inputs} - input_fields_list = [ - input_field for input_field in component.inputs if input_field.name not in base_input_names - ] - component_display_name = component.display_name - model_inputs[name] = { - "display_name": component_display_name, - "inputs": input_fields_list, - "icon": component.icon, - } - - return model_inputs diff --git a/src/backend/base/langflow/components/agents/agent.py b/src/backend/base/langflow/components/agents/agent.py index bd4aaf2ba..19d748a5f 100644 --- a/src/backend/base/langflow/components/agents/agent.py +++ b/src/backend/base/langflow/components/agents/agent.py @@ -1,9 +1,7 @@ from langflow.base.agents.agent import LCToolsAgentComponent -from langflow.base.models.model import LCModelComponent +from langflow.base.models.model_input_constants import ALL_PROVIDER_FIELDS, MODEL_PROVIDERS_DICT from langflow.components.agents.tool_calling import ToolCallingAgentComponent from langflow.components.helpers.memory import MemoryComponent -from langflow.components.models.azure_openai import AzureChatOpenAIComponent -from langflow.components.models.openai import OpenAIModelComponent from langflow.io import ( DropdownInput, MultilineInput, @@ -25,30 +23,19 @@ class AgentComponent(ToolCallingAgentComponent): beta = True name = "Agent" - azure_inputs = [ - set_advanced_true(component_input) if component_input.name == "temperature" else component_input - for component_input in AzureChatOpenAIComponent().inputs - if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs] - ] - openai_inputs = [ - set_advanced_true(component_input) if component_input.name == "temperature" else component_input - for component_input in OpenAIModelComponent().inputs - if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs] - ] - memory_inputs = [set_advanced_true(component_input) for component_input in MemoryComponent().inputs] inputs = [ DropdownInput( name="agent_llm", display_name="Model Provider", - options=["Azure OpenAI", "OpenAI", "Custom"], + options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"], value="OpenAI", real_time_refresh=True, refresh_button=True, input_types=[], ), - *openai_inputs, + *MODEL_PROVIDERS_DICT["OpenAI"]["inputs"], MultilineInput( name="system_prompt", display_name="Agent Instructions", @@ -86,71 +73,89 @@ class AgentComponent(ToolCallingAgentComponent): return MemoryComponent().set(**memory_kwargs).retrieve_messages() def get_llm(self): - try: - if self.agent_llm == "OpenAI": - return self._build_llm_model(OpenAIModelComponent(), self.openai_inputs) - if self.agent_llm == "Azure OpenAI": - return self._build_llm_model(AzureChatOpenAIComponent(), self.azure_inputs, prefix="azure_param_") - except Exception as e: - msg = f"Error building {self.agent_llm} language model" - raise ValueError(msg) from e + if isinstance(self.agent_llm, str): + try: + provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm) + if provider_info: + component_class = provider_info.get("component_class") + inputs = provider_info.get("inputs") + prefix = provider_info.get("prefix", "") + return self._build_llm_model(component_class, inputs, prefix) + except Exception as e: + msg = f"Error building {self.agent_llm} language model" + raise ValueError(msg) from e return self.agent_llm def _build_llm_model(self, component, inputs, prefix=""): - return component.set( - **{component_input.name: getattr(self, f"{prefix}{component_input.name}") for component_input in inputs} - ).build_model() + model_kwargs = {input_.name: getattr(self, f"{prefix}{input_.name}") for input_ in inputs} + return component.set(**model_kwargs).build_model() - def delete_fields(self, build_config, fields): + def delete_fields(self, build_config: dotdict, fields: dict | list[str]) -> None: + """Delete specified fields from build_config.""" for field in fields: build_config.pop(field, None) - def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None): + def update_input_types(self, build_config: dotdict) -> dotdict: + """Update input types for all fields in build_config.""" + for key, value in build_config.items(): + if isinstance(value, dict): + if value.get("input_types") is None: + build_config[key]["input_types"] = [] + elif hasattr(value, "input_types") and value.input_types is None: + value.input_types = [] + return build_config + + def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict: if field_name == "agent_llm": - openai_fields = {component_input.name: component_input for component_input in self.openai_inputs} - azure_fields = { - f"azure_param_{component_input.name}": component_input for component_input in self.azure_inputs + # Define provider configurations as (fields_to_add, fields_to_delete) + provider_configs: dict[str, tuple[dict, list[dict]]] = { + provider: ( + MODEL_PROVIDERS_DICT[provider]["fields"], + [ + MODEL_PROVIDERS_DICT[other_provider]["fields"] + for other_provider in MODEL_PROVIDERS_DICT + if other_provider != provider + ], + ) + for provider in MODEL_PROVIDERS_DICT } - if field_value == "OpenAI": - self.delete_fields(build_config, {**azure_fields}) - if not any(field in build_config for field in openai_fields): - build_config.update(openai_fields) - build_config["agent_llm"]["input_types"] = [] - build_config = self.update_input_types(build_config) + if field_value in provider_configs: + fields_to_add, fields_to_delete = provider_configs[field_value] - elif field_value == "Azure OpenAI": - self.delete_fields(build_config, {**openai_fields}) - build_config.update(azure_fields) + # Delete fields from other providers + for fields in fields_to_delete: + self.delete_fields(build_config, fields) + + # Add provider-specific fields + if field_value == "OpenAI" and not any(field in build_config for field in fields_to_add): + build_config.update(fields_to_add) + else: + build_config.update(fields_to_add) + # Reset input types for agent_llm build_config["agent_llm"]["input_types"] = [] - build_config = self.update_input_types(build_config) elif field_value == "Custom": - self.delete_fields(build_config, {**openai_fields}) - self.delete_fields(build_config, {**azure_fields}) - new_component = DropdownInput( + # Delete all provider fields + self.delete_fields(build_config, ALL_PROVIDER_FIELDS) + # Update with custom component + custom_component = DropdownInput( name="agent_llm", display_name="Language Model", - options=["Azure OpenAI", "OpenAI", "Custom"], + options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"], value="Custom", real_time_refresh=True, input_types=["LanguageModel"], ) - build_config.update({"agent_llm": new_component.to_dict()}) - build_config = self.update_input_types(build_config) + build_config.update({"agent_llm": custom_component.to_dict()}) + + # Update input types for all fields + build_config = self.update_input_types(build_config) + + # Validate required keys default_keys = ["code", "_type", "agent_llm", "tools", "input_value"] missing_keys = [key for key in default_keys if key not in build_config] if missing_keys: msg = f"Missing required keys in build_config: {missing_keys}" raise ValueError(msg) - return build_config - def update_input_types(self, build_config): - for key, value in build_config.items(): - # Check if the value is a dictionary - if isinstance(value, dict): - if value.get("input_types") is None: - build_config[key]["input_types"] = [] - # Check if the value has an attribute 'input_types' and it is None - elif hasattr(value, "input_types") and value.input_types is None: - value.input_types = [] return build_config diff --git a/src/backend/tests/unit/base/models/__init__.py b/src/backend/tests/unit/base/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/backend/tests/unit/base/models/test_model_constants.py b/src/backend/tests/unit/base/models/test_model_constants.py deleted file mode 100644 index 27d0f4704..000000000 --- a/src/backend/tests/unit/base/models/test_model_constants.py +++ /dev/null @@ -1,25 +0,0 @@ -from src.backend.base.langflow.base.models.model_constants import ModelConstants - - -def test_provider_names(): - # Initialize the ModelConstants - ModelConstants.initialize() - - # Expected provider names - expected_provider_names = [ - "AIML", - "Amazon Bedrock", - "Anthropic", - "Azure OpenAI", - "Ollama", - "Vertex AI", - "Cohere", - "Google Generative AI", - "HuggingFace", - "OpenAI", - "Perplexity", - "Qianfan", - ] - - # Assert that the provider names match the expected list - assert expected_provider_names == ModelConstants.PROVIDER_NAMES