fix: add default models to Anthropic and make sure template is updated (#5839)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2025-01-21 12:25:47 -03:00 committed by GitHub
commit 050c12df35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 240 additions and 75 deletions

View file

@ -8,16 +8,7 @@ from typing import TYPE_CHECKING, Annotated
from uuid import UUID
import sqlalchemy as sa
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
from loguru import logger
@ -36,11 +27,7 @@ from langflow.api.v1.schemas import (
UploadFileResponse,
)
from langflow.custom.custom_component.component import Component
from langflow.custom.utils import (
build_custom_component_template,
get_instance_name,
update_component_build_config,
)
from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config
from langflow.events.event_manager import create_stream_tokens_event_manager
from langflow.exceptions.api import APIException, InvalidChatInputError
from langflow.exceptions.serialization import SerializationError
@ -55,16 +42,9 @@ from langflow.services.auth.utils import api_key_security, get_current_active_us
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.database.models.flow.utils import (
get_all_webhook_components_in_flow,
)
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import (
get_session_service,
get_settings_service,
get_task_service,
get_telemetry_service,
)
from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service
from langflow.services.settings.feature_flags import FEATURE_FLAGS
from langflow.services.telemetry.schema import RunPayload
from langflow.utils.version import get_version_info
@ -720,7 +700,6 @@ async def custom_component_update(
user_id=user.id,
)
template_data = code_request.model_dump().get("template", {}).copy()
component_node["tool_mode"] = code_request.tool_mode
if hasattr(cc_instance, "set_attributes"):
@ -749,12 +728,6 @@ async def custom_component_update(
)
component_node["template"] = updated_build_config
# Preserve previous field values by merging filtered template data into
# the component node's template. Only include entries where the value
# is a dictionary containing the key "value".
filtered_data = {k: v for k, v in template_data.items() if isinstance(v, dict) and "value" in v}
component_node["template"] |= filtered_data
if isinstance(cc_instance, Component):
await cc_instance.run_and_validate_update_outputs(
frontend_node=component_node,

View file

@ -8,6 +8,7 @@ from langflow.components.models.groq import GroqModel
from langflow.components.models.nvidia import NVIDIAModelComponent
from langflow.components.models.openai import OpenAIModelComponent
from langflow.inputs.inputs import InputTypes, SecretStrInput
from langflow.template.field.base import Input
class ModelProvidersDict(TypedDict):
@ -24,7 +25,7 @@ def get_filtered_inputs(component_class):
return [process_inputs(input_) for input_ in component_instance.inputs if input_.name not in base_input_names]
def process_inputs(component_data):
def process_inputs(component_data: Input):
if isinstance(component_data, SecretStrInput):
component_data.value = ""
component_data.load_from_db = False
@ -61,8 +62,8 @@ def add_combobox_true(component_input):
return component_input
def create_input_fields_dict(inputs, prefix):
return {f"{prefix}{input_.name}": input_ for input_ in inputs}
def create_input_fields_dict(inputs: list[Input], prefix: str) -> dict[str, dict]:
    """Serialize input fields into a name-keyed dictionary.

    Args:
        inputs: Input field definitions to serialize.
        prefix: String prepended to every input name to form the key
            (pass ``""`` for un-prefixed keys).

    Returns:
        Dict mapping ``f"{prefix}{name}"`` to each input's ``to_dict()``
        result. (Return annotation fixed: the values are the serialized
        dicts produced by ``to_dict()``, not ``Input`` instances.)
    """
    # NOTE(review): assumes Input.to_dict() returns a plain dict — confirm.
    return {f"{prefix}{input_.name}": input_.to_dict() for input_ in inputs}
def _get_openai_inputs_and_fields():
@ -73,7 +74,7 @@ def _get_openai_inputs_and_fields():
except ImportError as e:
msg = "OpenAI is not installed. Please install it with `pip install langchain-openai`."
raise ImportError(msg) from e
return openai_inputs, {input_.name: input_ for input_ in openai_inputs}
return openai_inputs, create_input_fields_dict(openai_inputs, "")
def _get_azure_inputs_and_fields():
@ -204,10 +205,4 @@ except ImportError:
MODEL_PROVIDERS = list(MODEL_PROVIDERS_DICT.keys())
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]]
MODEL_DYNAMIC_UPDATE_FIELDS = [
"api_key",
"model",
"tool_model_enabled",
"base_url",
"model_name",
]
MODEL_DYNAMIC_UPDATE_FIELDS = ["api_key", "model", "tool_model_enabled", "base_url", "model_name"]

View file

@ -9,9 +9,7 @@ from langflow.base.models.model_input_constants import (
from langflow.base.models.model_utils import get_model_name
from langflow.components.helpers import CurrentDateComponent
from langflow.components.helpers.memory import MemoryComponent
from langflow.components.langchain_utilities.tool_calling import (
ToolCallingAgentComponent,
)
from langflow.components.langchain_utilities.tool_calling import ToolCallingAgentComponent
from langflow.custom.utils import update_component_build_config
from langflow.io import BoolInput, DropdownInput, MultilineInput, Output
from langflow.logging import logger
@ -121,6 +119,8 @@ class AgentComponent(ToolCallingAgentComponent):
memory_kwargs = {
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
}
# filter out empty values
memory_kwargs = {k: v for k, v in memory_kwargs.items() if v}
return await MemoryComponent().set(**memory_kwargs).retrieve_messages()
@ -177,13 +177,14 @@ class AgentComponent(ToolCallingAgentComponent):
# Iterate over all providers in the MODEL_PROVIDERS_DICT
# Existing logic for updating build_config
if field_name in ("agent_llm",):
build_config["agent_llm"]["value"] = field_value
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
if provider_info:
component_class = provider_info.get("component_class")
if component_class and hasattr(component_class, "update_build_config"):
# Call the component class's update_build_config method
build_config = await update_component_build_config(
component_class, build_config, field_value, field_name
component_class, build_config, field_value, "model_name"
)
provider_configs: dict[str, tuple[dict, list[dict]]] = {
@ -261,6 +262,6 @@ class AgentComponent(ToolCallingAgentComponent):
if isinstance(field_name, str) and isinstance(prefix, str):
field_name = field_name.replace(prefix, "")
build_config = await update_component_build_config(
component_class, build_config, field_value, field_name
component_class, build_config, field_value, "model_name"
)
return build_config
return {k: v.to_dict() if hasattr(v, "to_dict") else v for k, v in build_config.items()}

View file

@ -29,8 +29,9 @@ class AnthropicModelComponent(LCModelComponent):
DropdownInput(
name="model_name",
display_name="Model Name",
options=[],
options=ANTHROPIC_MODELS,
refresh_button=True,
value=ANTHROPIC_MODELS[0],
),
SecretStrInput(
name="api_key",
@ -138,14 +139,16 @@ class AnthropicModelComponent(LCModelComponent):
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value:
try:
if len(self.api_key) != 0:
if len(self.api_key) == 0:
ids = ANTHROPIC_MODELS
else:
try:
ids = self.get_models(tool_model_enabled=self.tool_model_enabled)
except (ImportError, ValueError, requests.exceptions.RequestException) as e:
logger.exception(f"Error getting model names: {e}")
ids = ANTHROPIC_MODELS
build_config["model_name"]["options"] = ids
build_config["model_name"]["value"] = ids[0]
build_config["model_name"]["options"] = ids
build_config["model_name"]["value"] = ids[0]
except Exception as e:
msg = f"Error getting model names: {e}"
raise ValueError(msg) from e

View file

@ -58,7 +58,8 @@ class GroqModel(LCModelComponent):
name="model_name",
display_name="Model",
info="The name of the model to use.",
options=[],
options=GROQ_MODELS,
value=GROQ_MODELS[0],
refresh_button=True,
real_time_refresh=True,
),

View file

@ -694,7 +694,11 @@ class Component(CustomComponent):
name, f"Input is connected to {input_value.__self__.display_name}.{input_value.__name__}"
)
raise ValueError(msg)
self._inputs[name].value = value
try:
self._inputs[name].value = value
except Exception as e:
msg = f"Error setting input value for {name}: {e}"
raise ValueError(msg) from e
if hasattr(self._inputs[name], "load_from_db"):
self._inputs[name].load_from_db = False
else:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -254,7 +254,10 @@ async def astore_message(
return []
if not message.session_id or not message.sender or not message.sender_name:
msg = "All of session_id, sender, and sender_name must be provided."
msg = (
f"All of session_id, sender, and sender_name must be provided. Session ID: {message.session_id},"
f" Sender: {message.sender}, Sender Name: {message.sender_name}"
)
raise ValueError(msg)
if hasattr(message, "id") and message.id:
# if message has an id and exist in the database, update it

View file

@ -1,6 +1,12 @@
import inspect
from typing import Any
from unittest.mock import Mock
from uuid import uuid4
import pytest
from langflow.custom.custom_component.component import Component
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from typing_extensions import TypedDict
from tests.constants import SUPPORTED_VERSIONS
@ -45,9 +51,20 @@ class ComponentTestBase:
msg = f"{self.__class__.__name__} must implement the file_names_mapping fixture"
raise NotImplementedError(msg)
def component_setup(self, component_class: type[Any], default_kwargs: dict[str, Any]) -> Component:
    """Build a component instance wired to a mocked vertex for testing.

    The component is constructed from its own source code (passed as
    ``_code``) plus the provided kwargs, then given a mock ``Vertex``
    whose mock ``Graph`` carries fresh ``session_id`` / ``flow_id`` UUIDs.

    Args:
        component_class: The component class under test.
        default_kwargs: Keyword arguments forwarded to the constructor.

    Returns:
        The constructed component instance with ``_vertex`` attached.
    """
    mock_vertex = Mock(spec=Vertex)
    mock_vertex.graph = Mock(spec=Graph)
    # Fresh UUIDs per call — presumably components read these IDs off
    # self._vertex.graph at runtime; TODO confirm against Component internals.
    mock_vertex.graph.session_id = str(uuid4())
    mock_vertex.graph.flow_id = str(uuid4())
    # Feed the class's own source in as _code so template building sees real code.
    source_code = inspect.getsource(component_class)
    component_instance = component_class(_code=source_code, **default_kwargs)
    component_instance._vertex = mock_vertex
    return component_instance
def test_latest_version(self, component_class: type[Any], default_kwargs: dict[str, Any]) -> None:
"""Test that the component works with the latest version."""
result = component_class(**default_kwargs)()
component_instance = self.component_setup(component_class, default_kwargs)
result = component_instance()
assert result is not None, "Component returned None for the latest version."
def test_all_versions_have_a_file_name_defined(self, file_names_mapping: list[VersionComponentMapping]) -> None:

View file

@ -1,9 +1,12 @@
import inspect
from typing import Any
from aiofile import async_open
from fastapi import status
from httpx import AsyncClient
from langflow.api.v1.schemas import UpdateCustomComponentRequest
from langflow.components.agents.agent import AgentComponent
from langflow.custom.utils import build_custom_component_template
async def test_get_version(client: AsyncClient):
@ -46,3 +49,59 @@ async def test_update_component_outputs(client: AsyncClient, logged_in_headers:
assert response.status_code == status.HTTP_200_OK
output_names = [output["name"] for output in result["outputs"]]
assert "tool_output" in output_names
async def test_update_component_model_name_options(client: AsyncClient, logged_in_headers: dict):
    """Test that model_name options are updated when selecting a provider.

    Switching the agent's provider to Anthropic must repopulate the
    ``model_name`` options; switching to Custom must remove the field.
    """
    component = AgentComponent()
    component_node, cc_instance = build_custom_component_template(
        component,
    )
    # Initial template with OpenAI as the provider
    template = component_node["template"]
    current_model_names = template["model_name"]["options"]
    # load the code from the file at langflow.components.agents.agent.py asynchronously
    # we are at src/backend/tests/unit/api/v1/test_endpoints.py
    # find the file by using the class AgentComponent
    agent_component_file = inspect.getsourcefile(AgentComponent)
    async with async_open(agent_component_file, encoding="utf-8") as f:
        code = await f.read()
    # Create the request to update the component
    request = UpdateCustomComponentRequest(
        code=code,
        frontend_node=component_node,
        field="agent_llm",
        field_value="Anthropic",
        template=template,
    )
    # Make the request to update the component
    response = await client.post("api/v1/custom_component/update", json=request.model_dump(), headers=logged_in_headers)
    result = response.json()
    # Verify the response
    assert response.status_code == status.HTTP_200_OK, f"Response: {response.json()}"
    assert "template" in result
    assert "model_name" in result["template"]
    assert isinstance(result["template"]["model_name"]["options"], list)
    assert len(result["template"]["model_name"]["options"]) > 0, (
        f"Model names: {result['template']['model_name']['options']}"
    )
    # Options must differ from the OpenAI defaults captured above.
    assert current_model_names != result["template"]["model_name"]["options"], (
        f"Current model names: {current_model_names}, New model names: {result['template']['model_name']['options']}"
    )
    # Now test with Custom provider
    template["agent_llm"]["value"] = "Custom"
    request.field_value = "Custom"
    request.template = template
    response = await client.post("api/v1/custom_component/update", json=request.model_dump(), headers=logged_in_headers)
    result = response.json()
    # Verify that model_name is not present for Custom provider
    assert response.status_code == status.HTTP_200_OK
    assert "template" in result
    assert "model_name" not in result["template"]

View file

@ -1,8 +1,99 @@
import os
from typing import Any
from uuid import uuid4
import pytest
from langflow.base.models.model_input_constants import MODEL_PROVIDERS_DICT
from langflow.components.agents.agent import AgentComponent
from langflow.components.tools.calculator import CalculatorToolComponent
from langflow.custom import Component
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
from tests.base import ComponentTestBaseWithoutClient
from tests.unit.mock_language_model import MockLanguageModel
class TestAgentComponent(ComponentTestBaseWithoutClient):
@pytest.fixture
def component_class(self):
    """Component class under test: AgentComponent."""
    return AgentComponent
@pytest.fixture
def file_names_mapping(self):
    """No versioned starter-project files to check for this component."""
    return []
def component_setup(self, component_class: type[Any], default_kwargs: dict[str, Any]) -> Component:
    """Set up the agent component with output post-processing disabled."""
    component_instance = super().component_setup(component_class, default_kwargs)
    # Mock _should_process_output so every output skips post-processing in tests.
    component_instance._should_process_output = lambda output: False  # noqa: ARG005
    return component_instance
@pytest.fixture
def default_kwargs(self):
    """Default constructor kwargs for AgentComponent.

    Uses a MockLanguageModel as the LLM so the tests need no provider
    credentials or network access.
    """
    return {
        "_type": "Agent",
        "add_current_date_tool": True,
        "agent_description": "A helpful agent",
        "agent_llm": MockLanguageModel(),
        "handle_parsing_errors": True,
        "input_value": "",
        "max_iterations": 10,
        "system_prompt": "You are a helpful assistant.",
        "tools": [],
        "verbose": True,
        "session_id": str(uuid4()),
        "sender": MESSAGE_SENDER_AI,
        "sender_name": MESSAGE_SENDER_NAME_AI,
    }
async def test_build_config_update(self, component_class, default_kwargs):
    """Verify update_build_config for the OpenAI, Anthropic, and Custom providers.

    For each provider choice it checks that ``agent_llm`` keeps the full
    provider option list, and that the ``model_name`` field is populated
    (OpenAI/Anthropic) or removed entirely (Custom).
    """
    component = self.component_setup(component_class, default_kwargs)
    frontend_node = component.to_frontend_node()
    build_config = frontend_node["data"]["node"]["template"]
    # Test updating build config for OpenAI
    component.set(agent_llm="OpenAI")
    updated_config = await component.update_build_config(build_config, "OpenAI", "agent_llm")
    assert "agent_llm" in updated_config
    assert updated_config["agent_llm"]["value"] == "OpenAI"
    assert isinstance(updated_config["agent_llm"]["options"], list)
    assert len(updated_config["agent_llm"]["options"]) > 0
    assert all(provider in updated_config["agent_llm"]["options"] for provider in MODEL_PROVIDERS_DICT)
    assert "Custom" in updated_config["agent_llm"]["options"]
    # Verify model_name field is populated for OpenAI
    assert "model_name" in updated_config
    model_name_dict = updated_config["model_name"]
    assert isinstance(model_name_dict["options"], list)
    assert len(model_name_dict["options"]) > 0  # OpenAI should have available models
    assert "gpt-4o" in model_name_dict["options"]
    # Test Anthropic
    component.set(agent_llm="Anthropic")
    updated_config = await component.update_build_config(build_config, "Anthropic", "agent_llm")
    assert "agent_llm" in updated_config
    assert updated_config["agent_llm"]["value"] == "Anthropic"
    assert isinstance(updated_config["agent_llm"]["options"], list)
    assert len(updated_config["agent_llm"]["options"]) > 0
    assert all(provider in updated_config["agent_llm"]["options"] for provider in MODEL_PROVIDERS_DICT)
    assert "Anthropic" in updated_config["agent_llm"]["options"]
    assert updated_config["agent_llm"]["input_types"] == []
    # Anthropic defaults should include at least one "sonnet" model
    assert any("sonnet" in option.lower() for option in updated_config["model_name"]["options"]), (
        f"Options: {updated_config['model_name']['options']}"
    )
    # Test updating build config for Custom
    updated_config = await component.update_build_config(build_config, "Custom", "agent_llm")
    assert "agent_llm" in updated_config
    assert updated_config["agent_llm"]["value"] == "Custom"
    assert isinstance(updated_config["agent_llm"]["options"], list)
    assert len(updated_config["agent_llm"]["options"]) > 0
    assert all(provider in updated_config["agent_llm"]["options"] for provider in MODEL_PROVIDERS_DICT)
    assert "Custom" in updated_config["agent_llm"]["options"]
    # Custom provider accepts a connected LanguageModel instead of a dropdown choice.
    assert updated_config["agent_llm"]["input_types"] == ["LanguageModel"]
    # Verify model_name field is cleared for Custom
    assert "model_name" not in updated_config
@pytest.mark.api_key_required

View file

@ -1,17 +1,21 @@
from unittest.mock import MagicMock
from langchain_core.language_models import BaseLanguageModel
from pydantic import BaseModel, Field
from typing_extensions import override
class MockLanguageModel(BaseLanguageModel):
class MockLanguageModel(BaseLanguageModel, BaseModel):
"""A mock language model for testing purposes."""
def __init__(self, response_generator=None):
tools: list = Field(default_factory=list)
response_generator: callable = Field(default_factory=lambda: lambda msg: f"Response for {msg}")
def __init__(self, response_generator=None, **kwargs):
"""Initialize the mock model with an optional response generator function."""
super().__init__()
# Use object's __dict__ to bypass pydantic validation
object.__setattr__(self, "_response_generator", response_generator or (lambda msg: f"Response for {msg}"))
super().__init__(**kwargs)
if response_generator:
self.response_generator = response_generator
@override
def with_config(self, *args, **kwargs):
@ -30,7 +34,7 @@ class MockLanguageModel(BaseLanguageModel):
for msg_list in messages:
content = msg_list[-1]["content"] if isinstance(msg_list, list) else msg_list
mock_response = MagicMock()
mock_response.content = self._response_generator(content)
mock_response.content = self.response_generator(content)
responses.append(mock_response)
return responses
@ -61,3 +65,8 @@ class MockLanguageModel(BaseLanguageModel):
@override
async def apredict_messages(self, *args, **kwargs):
raise NotImplementedError
def bind_tools(self, tools):
    """Record *tools* on the mock and return self for chaining.

    Mirrors the fluent ``bind_tools`` API so agent code under test can
    call it without a real provider model.
    """
    self.tools = tools
    return self