feat: Add model_utils and model_constants intorducing PROVIDER_NAMES and MODEL_INFO variables for dynamic updation of model data (#4341)
feat: Add model_utils and model_constants - Enhanced the initialization logic in model_constants to handle delayed imports and circular dependencies. - Improved type hinting for better code clarity and maintainability. Details: - `get_model_info`: Retrieves comprehensive information about all available models, which is used to populate the `MODEL_INFO` dictionary. - `MODEL_INFO`: A dictionary where each key is a model identifier, and the value is a dictionary containing details about the model, such as its `display_name` and configuration options. - `PROVIDER_NAMES`: A list derived from `MODEL_INFO` that holds the names of model providers, providing a quick reference to all available model providers.
This commit is contained in:
parent
10a63ff07e
commit
98ee051ca6
6 changed files with 138 additions and 4 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import importlib
|
||||
import json
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
|
|
@ -206,3 +207,63 @@ class LCModelComponent(Component):
|
|||
@abstractmethod
|
||||
def build_model(self) -> LanguageModel: # type: ignore[type-var]
|
||||
"""Implement this method to build the model."""
|
||||
|
||||
def get_llm(self, provider_name: str, model_info: dict[str, dict[str, str | list[InputTypes]]]) -> LanguageModel:
|
||||
"""Get LLM model based on provider name and inputs.
|
||||
|
||||
Args:
|
||||
provider_name: Name of the model provider (e.g., "OpenAI", "Azure OpenAI")
|
||||
inputs: Dictionary of input parameters for the model
|
||||
model_info: Dictionary of model information
|
||||
|
||||
Returns:
|
||||
Built LLM model instance
|
||||
"""
|
||||
try:
|
||||
if provider_name not in [model.get("display_name") for model in model_info.values()]:
|
||||
msg = f"Unknown model provider: {provider_name}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Find the component class name from MODEL_INFO in a single iteration
|
||||
component_info, module_name = next(
|
||||
((info, key) for key, info in model_info.items() if info.get("display_name") == provider_name),
|
||||
(None, None),
|
||||
)
|
||||
if not component_info:
|
||||
msg = f"Component information not found for {provider_name}"
|
||||
raise ValueError(msg)
|
||||
component_inputs = component_info.get("inputs", [])
|
||||
# Get the component class from the models module
|
||||
# Ensure component_inputs is a list of the expected types
|
||||
if not isinstance(component_inputs, list):
|
||||
component_inputs = []
|
||||
models_module = importlib.import_module("langflow.components.models")
|
||||
component_class = getattr(models_module, str(module_name))
|
||||
component = component_class()
|
||||
|
||||
return self.build_llm_model_from_inputs(component, component_inputs)
|
||||
except Exception as e:
|
||||
msg = f"Error building {provider_name} language model"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
def build_llm_model_from_inputs(
|
||||
self, component: Component, inputs: list[InputTypes], prefix: str = ""
|
||||
) -> LanguageModel:
|
||||
"""Build LLM model from component and inputs.
|
||||
|
||||
Args:
|
||||
component: LLM component instance
|
||||
inputs: Dictionary of input parameters for the model
|
||||
prefix: Prefix for the input names
|
||||
Returns:
|
||||
Built LLM model instance
|
||||
"""
|
||||
# Ensure prefix is a string
|
||||
prefix = prefix or ""
|
||||
# Filter inputs to only include valid component input names
|
||||
input_data = {
|
||||
str(component_input.name): getattr(self, f"{prefix}{component_input.name}", None)
|
||||
for component_input in inputs
|
||||
}
|
||||
|
||||
return component.set(**input_data).build_model()
|
||||
|
|
|
|||
17
src/backend/base/langflow/base/models/model_constants.py
Normal file
17
src/backend/base/langflow/base/models/model_constants.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
class ModelConstants:
|
||||
"""Class to hold model-related constants. To solve circular import issue."""
|
||||
|
||||
PROVIDER_NAMES: list[str] = []
|
||||
MODEL_INFO: dict[str, dict[str, str | list]] = {} # Adjusted type hint
|
||||
|
||||
@staticmethod
|
||||
def initialize():
|
||||
from langflow.base.models.model_utils import get_model_info # Delayed import
|
||||
|
||||
model_info = get_model_info()
|
||||
ModelConstants.MODEL_INFO = model_info
|
||||
ModelConstants.PROVIDER_NAMES = [
|
||||
str(model.get("display_name"))
|
||||
for model in model_info.values()
|
||||
if isinstance(model.get("display_name"), str)
|
||||
]
|
||||
31
src/backend/base/langflow/base/models/model_utils.py
Normal file
31
src/backend/base/langflow/base/models/model_utils.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import importlib
|
||||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
|
||||
|
||||
def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]:
|
||||
"""Get inputs for all model components."""
|
||||
model_inputs = {}
|
||||
models_module = importlib.import_module("langflow.components.models")
|
||||
model_component_names = getattr(models_module, "__all__", [])
|
||||
|
||||
for name in model_component_names:
|
||||
if name in ("base", "DynamicLLMComponent"): # Skip the base module
|
||||
continue
|
||||
|
||||
component_class = getattr(models_module, name)
|
||||
if issubclass(component_class, LCModelComponent):
|
||||
component = component_class()
|
||||
base_input_names = {input_field.name for input_field in LCModelComponent._base_inputs}
|
||||
input_fields_list = [
|
||||
input_field for input_field in component.inputs if input_field.name not in base_input_names
|
||||
]
|
||||
component_display_name = component.display_name
|
||||
model_inputs[name] = {
|
||||
"display_name": component_display_name,
|
||||
"inputs": input_fields_list,
|
||||
"icon": component.icon,
|
||||
}
|
||||
|
||||
return model_inputs
|
||||
|
|
@ -29,10 +29,10 @@ def build_vertex(self, vertex: Vertex) -> Vertex:
|
|||
|
||||
@celery_app.task(acks_late=True)
|
||||
def process_graph_cached_task(
|
||||
data_graph: dict[str, Any], # noqa: ARG001
|
||||
inputs: dict | list[dict] | None = None, # noqa: ARG001
|
||||
clear_cache=False, # noqa: ARG001, FBT002
|
||||
session_id=None, # noqa: ARG001
|
||||
data_graph: dict[str, Any],
|
||||
inputs: dict | list[dict] | None = None,
|
||||
clear_cache=False, # noqa: FBT002
|
||||
session_id=None,
|
||||
) -> dict[str, Any]:
|
||||
msg = "This task is not implemented yet"
|
||||
raise NotImplementedError(msg)
|
||||
|
|
|
|||
0
src/backend/tests/unit/base/models/__init__.py
Normal file
0
src/backend/tests/unit/base/models/__init__.py
Normal file
25
src/backend/tests/unit/base/models/test_model_constants.py
Normal file
25
src/backend/tests/unit/base/models/test_model_constants.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from src.backend.base.langflow.base.models.model_constants import ModelConstants
|
||||
|
||||
|
||||
def test_provider_names():
|
||||
# Initialize the ModelConstants
|
||||
ModelConstants.initialize()
|
||||
|
||||
# Expected provider names
|
||||
expected_provider_names = [
|
||||
"AIML",
|
||||
"Amazon Bedrock",
|
||||
"Anthropic",
|
||||
"Azure OpenAI",
|
||||
"Ollama",
|
||||
"Vertex AI",
|
||||
"Cohere",
|
||||
"Google Generative AI",
|
||||
"HuggingFace",
|
||||
"OpenAI",
|
||||
"Perplexity",
|
||||
"Qianfan",
|
||||
]
|
||||
|
||||
# Assert that the provider names match the expected list
|
||||
assert expected_provider_names == ModelConstants.PROVIDER_NAMES
|
||||
Loading…
Add table
Add a link
Reference in a new issue