feat: Add model_utils and model_constants intorducing PROVIDER_NAMES and MODEL_INFO variables for dynamic updation of model data (#4341)

feat: Add model_utils and model_constants

- Enhanced the initialization logic in model_constants to handle delayed imports and circular dependencies.
- Improved type hinting for better code clarity and maintainability.

Details:
- `get_model_info`: Retrieves comprehensive information about all available models, which is used to populate the `MODEL_INFO` dictionary.
- `MODEL_INFO`: A dictionary where each key is a model identifier, and the value is a dictionary containing details about the model, such as its `display_name` and configuration options.
- `PROVIDER_NAMES`: A list derived from `MODEL_INFO` that holds the names of model providers, providing a quick reference to all available model providers.
This commit is contained in:
Edwin Jose 2024-11-01 11:32:43 -04:00 committed by GitHub
commit 98ee051ca6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 138 additions and 4 deletions

View file

@ -1,3 +1,4 @@
import importlib
import json
import warnings
from abc import abstractmethod
@ -206,3 +207,63 @@ class LCModelComponent(Component):
@abstractmethod
def build_model(self) -> LanguageModel: # type: ignore[type-var]
"""Implement this method to build the model."""
def get_llm(self, provider_name: str, model_info: dict[str, dict[str, str | list[InputTypes]]]) -> LanguageModel:
"""Get LLM model based on provider name and inputs.
Args:
provider_name: Name of the model provider (e.g., "OpenAI", "Azure OpenAI")
inputs: Dictionary of input parameters for the model
model_info: Dictionary of model information
Returns:
Built LLM model instance
"""
try:
if provider_name not in [model.get("display_name") for model in model_info.values()]:
msg = f"Unknown model provider: {provider_name}"
raise ValueError(msg)
# Find the component class name from MODEL_INFO in a single iteration
component_info, module_name = next(
((info, key) for key, info in model_info.items() if info.get("display_name") == provider_name),
(None, None),
)
if not component_info:
msg = f"Component information not found for {provider_name}"
raise ValueError(msg)
component_inputs = component_info.get("inputs", [])
# Get the component class from the models module
# Ensure component_inputs is a list of the expected types
if not isinstance(component_inputs, list):
component_inputs = []
models_module = importlib.import_module("langflow.components.models")
component_class = getattr(models_module, str(module_name))
component = component_class()
return self.build_llm_model_from_inputs(component, component_inputs)
except Exception as e:
msg = f"Error building {provider_name} language model"
raise ValueError(msg) from e
def build_llm_model_from_inputs(
self, component: Component, inputs: list[InputTypes], prefix: str = ""
) -> LanguageModel:
"""Build LLM model from component and inputs.
Args:
component: LLM component instance
inputs: Dictionary of input parameters for the model
prefix: Prefix for the input names
Returns:
Built LLM model instance
"""
# Ensure prefix is a string
prefix = prefix or ""
# Filter inputs to only include valid component input names
input_data = {
str(component_input.name): getattr(self, f"{prefix}{component_input.name}", None)
for component_input in inputs
}
return component.set(**input_data).build_model()

View file

@ -0,0 +1,17 @@
class ModelConstants:
"""Class to hold model-related constants. To solve circular import issue."""
PROVIDER_NAMES: list[str] = []
MODEL_INFO: dict[str, dict[str, str | list]] = {} # Adjusted type hint
@staticmethod
def initialize():
from langflow.base.models.model_utils import get_model_info # Delayed import
model_info = get_model_info()
ModelConstants.MODEL_INFO = model_info
ModelConstants.PROVIDER_NAMES = [
str(model.get("display_name"))
for model in model_info.values()
if isinstance(model.get("display_name"), str)
]

View file

@ -0,0 +1,31 @@
import importlib
from langflow.base.models.model import LCModelComponent
from langflow.inputs.inputs import InputTypes
def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]:
"""Get inputs for all model components."""
model_inputs = {}
models_module = importlib.import_module("langflow.components.models")
model_component_names = getattr(models_module, "__all__", [])
for name in model_component_names:
if name in ("base", "DynamicLLMComponent"): # Skip the base module
continue
component_class = getattr(models_module, name)
if issubclass(component_class, LCModelComponent):
component = component_class()
base_input_names = {input_field.name for input_field in LCModelComponent._base_inputs}
input_fields_list = [
input_field for input_field in component.inputs if input_field.name not in base_input_names
]
component_display_name = component.display_name
model_inputs[name] = {
"display_name": component_display_name,
"inputs": input_fields_list,
"icon": component.icon,
}
return model_inputs

View file

@ -29,10 +29,10 @@ def build_vertex(self, vertex: Vertex) -> Vertex:
@celery_app.task(acks_late=True)
def process_graph_cached_task(
data_graph: dict[str, Any], # noqa: ARG001
inputs: dict | list[dict] | None = None, # noqa: ARG001
clear_cache=False, # noqa: ARG001, FBT002
session_id=None, # noqa: ARG001
data_graph: dict[str, Any],
inputs: dict | list[dict] | None = None,
clear_cache=False, # noqa: FBT002
session_id=None,
) -> dict[str, Any]:
msg = "This task is not implemented yet"
raise NotImplementedError(msg)

View file

@ -0,0 +1,25 @@
from src.backend.base.langflow.base.models.model_constants import ModelConstants
def test_provider_names():
# Initialize the ModelConstants
ModelConstants.initialize()
# Expected provider names
expected_provider_names = [
"AIML",
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Ollama",
"Vertex AI",
"Cohere",
"Google Generative AI",
"HuggingFace",
"OpenAI",
"Perplexity",
"Qianfan",
]
# Assert that the provider names match the expected list
assert expected_provider_names == ModelConstants.PROVIDER_NAMES