fix: add default models to Anthropic and make sure template is updated (#5839)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
64d82d407a
commit
050c12df35
19 changed files with 240 additions and 75 deletions
|
|
@ -8,16 +8,7 @@ from typing import TYPE_CHECKING, Annotated
|
|||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
|
@ -36,11 +27,7 @@ from langflow.api.v1.schemas import (
|
|||
UploadFileResponse,
|
||||
)
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.custom.utils import (
|
||||
build_custom_component_template,
|
||||
get_instance_name,
|
||||
update_component_build_config,
|
||||
)
|
||||
from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config
|
||||
from langflow.events.event_manager import create_stream_tokens_event_manager
|
||||
from langflow.exceptions.api import APIException, InvalidChatInputError
|
||||
from langflow.exceptions.serialization import SerializationError
|
||||
|
|
@ -55,16 +42,9 @@ from langflow.services.auth.utils import api_key_security, get_current_active_us
|
|||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.flow.model import FlowRead
|
||||
from langflow.services.database.models.flow.utils import (
|
||||
get_all_webhook_components_in_flow,
|
||||
)
|
||||
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import (
|
||||
get_session_service,
|
||||
get_settings_service,
|
||||
get_task_service,
|
||||
get_telemetry_service,
|
||||
)
|
||||
from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
from langflow.services.telemetry.schema import RunPayload
|
||||
from langflow.utils.version import get_version_info
|
||||
|
|
@ -720,7 +700,6 @@ async def custom_component_update(
|
|||
user_id=user.id,
|
||||
)
|
||||
|
||||
template_data = code_request.model_dump().get("template", {}).copy()
|
||||
component_node["tool_mode"] = code_request.tool_mode
|
||||
|
||||
if hasattr(cc_instance, "set_attributes"):
|
||||
|
|
@ -749,12 +728,6 @@ async def custom_component_update(
|
|||
)
|
||||
component_node["template"] = updated_build_config
|
||||
|
||||
# Preserve previous field values by merging filtered template data into
|
||||
# the component node's template. Only include entries where the value
|
||||
# is a dictionary containing the key "value".
|
||||
filtered_data = {k: v for k, v in template_data.items() if isinstance(v, dict) and "value" in v}
|
||||
component_node["template"] |= filtered_data
|
||||
|
||||
if isinstance(cc_instance, Component):
|
||||
await cc_instance.run_and_validate_update_outputs(
|
||||
frontend_node=component_node,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from langflow.components.models.groq import GroqModel
|
|||
from langflow.components.models.nvidia import NVIDIAModelComponent
|
||||
from langflow.components.models.openai import OpenAIModelComponent
|
||||
from langflow.inputs.inputs import InputTypes, SecretStrInput
|
||||
from langflow.template.field.base import Input
|
||||
|
||||
|
||||
class ModelProvidersDict(TypedDict):
|
||||
|
|
@ -24,7 +25,7 @@ def get_filtered_inputs(component_class):
|
|||
return [process_inputs(input_) for input_ in component_instance.inputs if input_.name not in base_input_names]
|
||||
|
||||
|
||||
def process_inputs(component_data):
|
||||
def process_inputs(component_data: Input):
|
||||
if isinstance(component_data, SecretStrInput):
|
||||
component_data.value = ""
|
||||
component_data.load_from_db = False
|
||||
|
|
@ -61,8 +62,8 @@ def add_combobox_true(component_input):
|
|||
return component_input
|
||||
|
||||
|
||||
def create_input_fields_dict(inputs, prefix):
|
||||
return {f"{prefix}{input_.name}": input_ for input_ in inputs}
|
||||
def create_input_fields_dict(inputs: list[Input], prefix: str) -> dict[str, Input]:
|
||||
return {f"{prefix}{input_.name}": input_.to_dict() for input_ in inputs}
|
||||
|
||||
|
||||
def _get_openai_inputs_and_fields():
|
||||
|
|
@ -73,7 +74,7 @@ def _get_openai_inputs_and_fields():
|
|||
except ImportError as e:
|
||||
msg = "OpenAI is not installed. Please install it with `pip install langchain-openai`."
|
||||
raise ImportError(msg) from e
|
||||
return openai_inputs, {input_.name: input_ for input_ in openai_inputs}
|
||||
return openai_inputs, create_input_fields_dict(openai_inputs, "")
|
||||
|
||||
|
||||
def _get_azure_inputs_and_fields():
|
||||
|
|
@ -204,10 +205,4 @@ except ImportError:
|
|||
MODEL_PROVIDERS = list(MODEL_PROVIDERS_DICT.keys())
|
||||
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]]
|
||||
|
||||
MODEL_DYNAMIC_UPDATE_FIELDS = [
|
||||
"api_key",
|
||||
"model",
|
||||
"tool_model_enabled",
|
||||
"base_url",
|
||||
"model_name",
|
||||
]
|
||||
MODEL_DYNAMIC_UPDATE_FIELDS = ["api_key", "model", "tool_model_enabled", "base_url", "model_name"]
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@ from langflow.base.models.model_input_constants import (
|
|||
from langflow.base.models.model_utils import get_model_name
|
||||
from langflow.components.helpers import CurrentDateComponent
|
||||
from langflow.components.helpers.memory import MemoryComponent
|
||||
from langflow.components.langchain_utilities.tool_calling import (
|
||||
ToolCallingAgentComponent,
|
||||
)
|
||||
from langflow.components.langchain_utilities.tool_calling import ToolCallingAgentComponent
|
||||
from langflow.custom.utils import update_component_build_config
|
||||
from langflow.io import BoolInput, DropdownInput, MultilineInput, Output
|
||||
from langflow.logging import logger
|
||||
|
|
@ -121,6 +119,8 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
memory_kwargs = {
|
||||
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
|
||||
}
|
||||
# filter out empty values
|
||||
memory_kwargs = {k: v for k, v in memory_kwargs.items() if v}
|
||||
|
||||
return await MemoryComponent().set(**memory_kwargs).retrieve_messages()
|
||||
|
||||
|
|
@ -177,13 +177,14 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
# Iterate over all providers in the MODEL_PROVIDERS_DICT
|
||||
# Existing logic for updating build_config
|
||||
if field_name in ("agent_llm",):
|
||||
build_config["agent_llm"]["value"] = field_value
|
||||
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
|
||||
if provider_info:
|
||||
component_class = provider_info.get("component_class")
|
||||
if component_class and hasattr(component_class, "update_build_config"):
|
||||
# Call the component class's update_build_config method
|
||||
build_config = await update_component_build_config(
|
||||
component_class, build_config, field_value, field_name
|
||||
component_class, build_config, field_value, "model_name"
|
||||
)
|
||||
|
||||
provider_configs: dict[str, tuple[dict, list[dict]]] = {
|
||||
|
|
@ -261,6 +262,6 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
if isinstance(field_name, str) and isinstance(prefix, str):
|
||||
field_name = field_name.replace(prefix, "")
|
||||
build_config = await update_component_build_config(
|
||||
component_class, build_config, field_value, field_name
|
||||
component_class, build_config, field_value, "model_name"
|
||||
)
|
||||
return build_config
|
||||
return {k: v.to_dict() if hasattr(v, "to_dict") else v for k, v in build_config.items()}
|
||||
|
|
|
|||
|
|
@ -29,8 +29,9 @@ class AnthropicModelComponent(LCModelComponent):
|
|||
DropdownInput(
|
||||
name="model_name",
|
||||
display_name="Model Name",
|
||||
options=[],
|
||||
options=ANTHROPIC_MODELS,
|
||||
refresh_button=True,
|
||||
value=ANTHROPIC_MODELS[0],
|
||||
),
|
||||
SecretStrInput(
|
||||
name="api_key",
|
||||
|
|
@ -138,14 +139,16 @@ class AnthropicModelComponent(LCModelComponent):
|
|||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value:
|
||||
try:
|
||||
if len(self.api_key) != 0:
|
||||
if len(self.api_key) == 0:
|
||||
ids = ANTHROPIC_MODELS
|
||||
else:
|
||||
try:
|
||||
ids = self.get_models(tool_model_enabled=self.tool_model_enabled)
|
||||
except (ImportError, ValueError, requests.exceptions.RequestException) as e:
|
||||
logger.exception(f"Error getting model names: {e}")
|
||||
ids = ANTHROPIC_MODELS
|
||||
build_config["model_name"]["options"] = ids
|
||||
build_config["model_name"]["value"] = ids[0]
|
||||
build_config["model_name"]["options"] = ids
|
||||
build_config["model_name"]["value"] = ids[0]
|
||||
except Exception as e:
|
||||
msg = f"Error getting model names: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
|
|
|||
|
|
@ -58,7 +58,8 @@ class GroqModel(LCModelComponent):
|
|||
name="model_name",
|
||||
display_name="Model",
|
||||
info="The name of the model to use.",
|
||||
options=[],
|
||||
options=GROQ_MODELS,
|
||||
value=GROQ_MODELS[0],
|
||||
refresh_button=True,
|
||||
real_time_refresh=True,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -694,7 +694,11 @@ class Component(CustomComponent):
|
|||
name, f"Input is connected to {input_value.__self__.display_name}.{input_value.__name__}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
self._inputs[name].value = value
|
||||
try:
|
||||
self._inputs[name].value = value
|
||||
except Exception as e:
|
||||
msg = f"Error setting input value for {name}: {e}"
|
||||
raise ValueError(msg) from e
|
||||
if hasattr(self._inputs[name], "load_from_db"):
|
||||
self._inputs[name].load_from_db = False
|
||||
else:
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -254,7 +254,10 @@ async def astore_message(
|
|||
return []
|
||||
|
||||
if not message.session_id or not message.sender or not message.sender_name:
|
||||
msg = "All of session_id, sender, and sender_name must be provided."
|
||||
msg = (
|
||||
f"All of session_id, sender, and sender_name must be provided. Session ID: {message.session_id},"
|
||||
f" Sender: {message.sender}, Sender Name: {message.sender_name}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if hasattr(message, "id") and message.id:
|
||||
# if message has an id and exist in the database, update it
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
import inspect
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from tests.constants import SUPPORTED_VERSIONS
|
||||
|
|
@ -45,9 +51,20 @@ class ComponentTestBase:
|
|||
msg = f"{self.__class__.__name__} must implement the file_names_mapping fixture"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def component_setup(self, component_class: type[Any], default_kwargs: dict[str, Any]) -> Component:
|
||||
mock_vertex = Mock(spec=Vertex)
|
||||
mock_vertex.graph = Mock(spec=Graph)
|
||||
mock_vertex.graph.session_id = str(uuid4())
|
||||
mock_vertex.graph.flow_id = str(uuid4())
|
||||
source_code = inspect.getsource(component_class)
|
||||
component_instance = component_class(_code=source_code, **default_kwargs)
|
||||
component_instance._vertex = mock_vertex
|
||||
return component_instance
|
||||
|
||||
def test_latest_version(self, component_class: type[Any], default_kwargs: dict[str, Any]) -> None:
|
||||
"""Test that the component works with the latest version."""
|
||||
result = component_class(**default_kwargs)()
|
||||
component_instance = self.component_setup(component_class, default_kwargs)
|
||||
result = component_instance()
|
||||
assert result is not None, "Component returned None for the latest version."
|
||||
|
||||
def test_all_versions_have_a_file_name_defined(self, file_names_mapping: list[VersionComponentMapping]) -> None:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from aiofile import async_open
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
from langflow.api.v1.schemas import UpdateCustomComponentRequest
|
||||
from langflow.components.agents.agent import AgentComponent
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
|
||||
|
||||
async def test_get_version(client: AsyncClient):
|
||||
|
|
@ -46,3 +49,59 @@ async def test_update_component_outputs(client: AsyncClient, logged_in_headers:
|
|||
assert response.status_code == status.HTTP_200_OK
|
||||
output_names = [output["name"] for output in result["outputs"]]
|
||||
assert "tool_output" in output_names
|
||||
|
||||
|
||||
async def test_update_component_model_name_options(client: AsyncClient, logged_in_headers: dict):
|
||||
"""Test that model_name options are updated when selecting a provider."""
|
||||
component = AgentComponent()
|
||||
component_node, cc_instance = build_custom_component_template(
|
||||
component,
|
||||
)
|
||||
|
||||
# Initial template with OpenAI as the provider
|
||||
template = component_node["template"]
|
||||
current_model_names = template["model_name"]["options"]
|
||||
|
||||
# load the code from the file at langflow.components.agents.agent.py asynchronously
|
||||
# we are at str/backend/tests/unit/api/v1/test_endpoints.py
|
||||
# find the file by using the class AgentComponent
|
||||
agent_component_file = inspect.getsourcefile(AgentComponent)
|
||||
async with async_open(agent_component_file, encoding="utf-8") as f:
|
||||
code = await f.read()
|
||||
|
||||
# Create the request to update the component
|
||||
request = UpdateCustomComponentRequest(
|
||||
code=code,
|
||||
frontend_node=component_node,
|
||||
field="agent_llm",
|
||||
field_value="Anthropic",
|
||||
template=template,
|
||||
)
|
||||
|
||||
# Make the request to update the component
|
||||
response = await client.post("api/v1/custom_component/update", json=request.model_dump(), headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
# Verify the response
|
||||
assert response.status_code == status.HTTP_200_OK, f"Response: {response.json()}"
|
||||
assert "template" in result
|
||||
assert "model_name" in result["template"]
|
||||
assert isinstance(result["template"]["model_name"]["options"], list)
|
||||
assert len(result["template"]["model_name"]["options"]) > 0, (
|
||||
f"Model names: {result['template']['model_name']['options']}"
|
||||
)
|
||||
assert current_model_names != result["template"]["model_name"]["options"], (
|
||||
f"Current model names: {current_model_names}, New model names: {result['template']['model_name']['options']}"
|
||||
)
|
||||
# Now test with Custom provider
|
||||
template["agent_llm"]["value"] = "Custom"
|
||||
request.field_value = "Custom"
|
||||
request.template = template
|
||||
|
||||
response = await client.post("api/v1/custom_component/update", json=request.model_dump(), headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
# Verify that model_name is not present for Custom provider
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert "template" in result
|
||||
assert "model_name" not in result["template"]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,99 @@
|
|||
import os
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langflow.base.models.model_input_constants import MODEL_PROVIDERS_DICT
|
||||
from langflow.components.agents.agent import AgentComponent
|
||||
from langflow.components.tools.calculator import CalculatorToolComponent
|
||||
from langflow.custom import Component
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
from tests.unit.mock_language_model import MockLanguageModel
|
||||
|
||||
|
||||
class TestAgentComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
return AgentComponent
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self):
|
||||
return []
|
||||
|
||||
def component_setup(self, component_class: type[Any], default_kwargs: dict[str, Any]) -> Component:
|
||||
component_instance = super().component_setup(component_class, default_kwargs)
|
||||
# Mock _should_process_output method
|
||||
component_instance._should_process_output = lambda output: False # noqa: ARG005
|
||||
return component_instance
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self):
|
||||
return {
|
||||
"_type": "Agent",
|
||||
"add_current_date_tool": True,
|
||||
"agent_description": "A helpful agent",
|
||||
"agent_llm": MockLanguageModel(),
|
||||
"handle_parsing_errors": True,
|
||||
"input_value": "",
|
||||
"max_iterations": 10,
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"tools": [],
|
||||
"verbose": True,
|
||||
"session_id": str(uuid4()),
|
||||
"sender": MESSAGE_SENDER_AI,
|
||||
"sender_name": MESSAGE_SENDER_NAME_AI,
|
||||
}
|
||||
|
||||
async def test_build_config_update(self, component_class, default_kwargs):
|
||||
component = self.component_setup(component_class, default_kwargs)
|
||||
frontend_node = component.to_frontend_node()
|
||||
build_config = frontend_node["data"]["node"]["template"]
|
||||
# Test updating build config for OpenAI
|
||||
component.set(agent_llm="OpenAI")
|
||||
updated_config = await component.update_build_config(build_config, "OpenAI", "agent_llm")
|
||||
assert "agent_llm" in updated_config
|
||||
assert updated_config["agent_llm"]["value"] == "OpenAI"
|
||||
assert isinstance(updated_config["agent_llm"]["options"], list)
|
||||
assert len(updated_config["agent_llm"]["options"]) > 0
|
||||
assert all(provider in updated_config["agent_llm"]["options"] for provider in MODEL_PROVIDERS_DICT)
|
||||
assert "Custom" in updated_config["agent_llm"]["options"]
|
||||
|
||||
# Verify model_name field is populated for OpenAI
|
||||
|
||||
assert "model_name" in updated_config
|
||||
model_name_dict = updated_config["model_name"]
|
||||
assert isinstance(model_name_dict["options"], list)
|
||||
assert len(model_name_dict["options"]) > 0 # OpenAI should have available models
|
||||
assert "gpt-4o" in model_name_dict["options"]
|
||||
|
||||
# Test Anthropic
|
||||
component.set(agent_llm="Anthropic")
|
||||
updated_config = await component.update_build_config(build_config, "Anthropic", "agent_llm")
|
||||
assert "agent_llm" in updated_config
|
||||
assert updated_config["agent_llm"]["value"] == "Anthropic"
|
||||
assert isinstance(updated_config["agent_llm"]["options"], list)
|
||||
assert len(updated_config["agent_llm"]["options"]) > 0
|
||||
assert all(provider in updated_config["agent_llm"]["options"] for provider in MODEL_PROVIDERS_DICT)
|
||||
assert "Anthropic" in updated_config["agent_llm"]["options"]
|
||||
assert updated_config["agent_llm"]["input_types"] == []
|
||||
assert any("sonnet" in option.lower() for option in updated_config["model_name"]["options"]), (
|
||||
f"Options: {updated_config['model_name']['options']}"
|
||||
)
|
||||
|
||||
# Test updating build config for Custom
|
||||
updated_config = await component.update_build_config(build_config, "Custom", "agent_llm")
|
||||
assert "agent_llm" in updated_config
|
||||
assert updated_config["agent_llm"]["value"] == "Custom"
|
||||
assert isinstance(updated_config["agent_llm"]["options"], list)
|
||||
assert len(updated_config["agent_llm"]["options"]) > 0
|
||||
assert all(provider in updated_config["agent_llm"]["options"] for provider in MODEL_PROVIDERS_DICT)
|
||||
assert "Custom" in updated_config["agent_llm"]["options"]
|
||||
assert updated_config["agent_llm"]["input_types"] == ["LanguageModel"]
|
||||
|
||||
# Verify model_name field is cleared for Custom
|
||||
assert "model_name" not in updated_config
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
|
|||
|
|
@ -1,17 +1,21 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class MockLanguageModel(BaseLanguageModel):
|
||||
class MockLanguageModel(BaseLanguageModel, BaseModel):
|
||||
"""A mock language model for testing purposes."""
|
||||
|
||||
def __init__(self, response_generator=None):
|
||||
tools: list = Field(default_factory=list)
|
||||
response_generator: callable = Field(default_factory=lambda: lambda msg: f"Response for {msg}")
|
||||
|
||||
def __init__(self, response_generator=None, **kwargs):
|
||||
"""Initialize the mock model with an optional response generator function."""
|
||||
super().__init__()
|
||||
# Use object's __dict__ to bypass pydantic validation
|
||||
object.__setattr__(self, "_response_generator", response_generator or (lambda msg: f"Response for {msg}"))
|
||||
super().__init__(**kwargs)
|
||||
if response_generator:
|
||||
self.response_generator = response_generator
|
||||
|
||||
@override
|
||||
def with_config(self, *args, **kwargs):
|
||||
|
|
@ -30,7 +34,7 @@ class MockLanguageModel(BaseLanguageModel):
|
|||
for msg_list in messages:
|
||||
content = msg_list[-1]["content"] if isinstance(msg_list, list) else msg_list
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = self._response_generator(content)
|
||||
mock_response.content = self.response_generator(content)
|
||||
responses.append(mock_response)
|
||||
return responses
|
||||
|
||||
|
|
@ -61,3 +65,8 @@ class MockLanguageModel(BaseLanguageModel):
|
|||
@override
|
||||
async def apredict_messages(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def bind_tools(self, tools):
|
||||
"""Bind tools to the model for testing."""
|
||||
self.tools = tools
|
||||
return self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue