fix: improve tool calling filter in ollama model component (#8056)
* improve tool calling filter
* [autofix.ci] apply automated fixes
* Update ollama.py
* update tests
* [autofix.ci] apply automated fixes
* fix: correct variable reference for tool model capability check in ChatOllamaComponent
* Update test_chatollama_component.py
* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: Carlos Coelho <80289056+carlosrcoelho@users.noreply.github.com>
This commit is contained in:
parent 073659d5f4
commit 43629b21ad

2 changed files with 184 additions and 132 deletions
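In short, get_models no longer post-filters model names against the hard-coded OLLAMA_TOOL_MODELS_BASE prefix list; it now asks the Ollama API for each model's capabilities and keeps a model only if it supports "completion" and, when tool calling is enabled, also advertises the "tools" capability. Below is a minimal standalone sketch of that logic, assuming the component hits Ollama's /api/tags (model list) and /api/show (per-model detail) endpoints, as the mocked GET/POST calls in the tests suggest; the get_models helper here is illustrative, not the component's exact code.

import asyncio

import httpx


async def get_models(base_url: str, *, tool_model_enabled: bool) -> list[str]:
    """Illustrative re-implementation of the capability-based filter."""
    model_ids: list[str] = []
    try:
        async with httpx.AsyncClient() as client:
            # One GET to list the installed models...
            tags_response = await client.get(f"{base_url}/api/tags")
            tags_response.raise_for_status()
            for model in tags_response.json().get("models", []):
                model_name = model.get("name")
                # ...and one POST per model to read its capabilities.
                show_response = await client.post(f"{base_url}/api/show", json={"model": model_name})
                show_response.raise_for_status()
                capabilities = show_response.json().get("capabilities", [])
                # Keep "completion" models; when tool calling is enabled,
                # the model must also advertise the "tools" capability.
                if "completion" in capabilities and (not tool_model_enabled or "tools" in capabilities):
                    model_ids.append(model_name)
    except (httpx.RequestError, ValueError) as e:
        msg = "Could not get model names from Ollama."
        raise ValueError(msg) from e
    return model_ids


if __name__ == "__main__":
    # Example: list tool-capable models on a local Ollama instance.
    print(asyncio.run(get_models("http://localhost:11434", tool_model_enabled=True)))

Note that the commit also flips the tool_model_enabled input's default from False to True, so a fresh component lists only tool-capable models unless the toggle is turned off.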
ollama.py:
@@ -6,7 +6,7 @@ import httpx
 from langchain_ollama import ChatOllama
 from langflow.base.models.model import LCModelComponent
-from langflow.base.models.ollama_constants import OLLAMA_TOOL_MODELS_BASE, URL_LIST
+from langflow.base.models.ollama_constants import URL_LIST
 from langflow.field_typing import LanguageModel
 from langflow.field_typing.range_spec import RangeSpec
 from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput
@@ -26,6 +26,7 @@ class ChatOllamaComponent(LCModelComponent):
     JSON_NAME_KEY = "name"
     JSON_CAPABILITIES_KEY = "capabilities"
     DESIRED_CAPABILITY = "completion"
+    TOOL_CALLING_CAPABILITY = "tools"
 
     inputs = [
         MessageTextInput(
@@ -130,7 +131,7 @@ class ChatOllamaComponent(LCModelComponent):
             name="tool_model_enabled",
             display_name="Tool Model Enabled",
             info="Whether to enable tool calling in the model.",
-            value=False,
+            value=True,
             real_time_refresh=True,
         ),
         MessageTextInput(
@@ -314,17 +315,13 @@
                     capabilities = json_data.get(self.JSON_CAPABILITIES_KEY, [])
                     logger.debug(f"Model: {model_name}, Capabilities: {capabilities}")
 
-                    if self.DESIRED_CAPABILITY in capabilities:
+                    if self.DESIRED_CAPABILITY in capabilities and (
+                        not tool_model_enabled or self.TOOL_CALLING_CAPABILITY in capabilities
+                    ):
                         model_ids.append(model_name)
 
         except (httpx.RequestError, ValueError) as e:
             msg = "Could not get model names from Ollama."
             raise ValueError(msg) from e
 
-        return (
-            model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)]
-        )
-
-    def supports_tool_calling(self, model: str) -> bool:
-        """Check if model name is in the base of any models example llama3.3 can have 1b and 2b."""
-        return any(model.startswith(f"{tool_model}") for tool_model in OLLAMA_TOOL_MODELS_BASE)
+        return model_ids
test_chatollama_component.py:
@@ -2,141 +2,196 @@ from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from langchain_ollama import ChatOllama
-from langflow.components.models import ChatOllamaComponent
+from langflow.components.models.ollama import ChatOllamaComponent
+
+from tests.base import ComponentTestBaseWithoutClient
 
 
-@pytest.fixture
-def component():
-    return ChatOllamaComponent()
-
-
-@pytest.mark.asyncio
-@patch("langflow.components.models.ollama.httpx.AsyncClient.post")
-@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
-async def test_get_models_success(mock_get, mock_post, component):
-    # The revised approach to get_models filters based on model capabilities.
-    # It requires one request to ollama to get the models and another to check
-    # the capabilities of each model.
-    mock_get_response = AsyncMock()
-    mock_get_response.raise_for_status.return_value = None
-    mock_get_response.json.return_value = {
-        component.JSON_MODELS_KEY: [{component.JSON_NAME_KEY: "model1"}, {component.JSON_NAME_KEY: "model2"}]
-    }
-    mock_get.return_value = mock_get_response
-
-    # Mock the response for the HTTP POST request to check capabilities.
-    # Note that this is not exactly what happens if the Ollama server is running,
-    # but it is a good approximation.
-    # The first call checks the capabilities of model1, and the second call checks the capabilities of model2.
-    mock_post_response = AsyncMock()
-    mock_post_response.raise_for_status.return_value = None
-    mock_post_response.json.side_effect = [
-        {component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]},
-        {component.JSON_CAPABILITIES_KEY: []},
-    ]
-    mock_post.return_value = mock_post_response
-
-    base_url = "http://localhost:11434"
-    result = await component.get_models(base_url)
-
-    # Check that the correct URL was used for the GET request
-    assert result == ["model1"]
-    assert mock_get.call_count == 1
-    assert mock_post.call_count == 2
-
-
-@pytest.mark.asyncio
-@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
-async def test_get_models_failure(mock_get, component):
-    # Simulate a network error for /api/tags
-    import httpx
-
-    mock_get.side_effect = httpx.RequestError("Connection error", request=None)
-
-    base_url = "http://localhost:11434"
-    with pytest.raises(ValueError, match="Could not get model names from Ollama."):
-        await component.get_models(base_url)
-
-
-async def test_update_build_config_mirostat_disabled(component):
-    build_config = {
-        "mirostat_eta": {"advanced": False, "value": 0.1},
-        "mirostat_tau": {"advanced": False, "value": 5},
-    }
-    field_value = "Disabled"
-    field_name = "mirostat"
-
-    updated_config = await component.update_build_config(build_config, field_value, field_name)
-
-    assert updated_config["mirostat_eta"]["advanced"] is True
-    assert updated_config["mirostat_tau"]["advanced"] is True
-    assert updated_config["mirostat_eta"]["value"] is None
-    assert updated_config["mirostat_tau"]["value"] is None
-
-
-async def test_update_build_config_mirostat_enabled(component):
-    build_config = {
-        "mirostat_eta": {"advanced": False, "value": None},
-        "mirostat_tau": {"advanced": False, "value": None},
-    }
-    field_value = "Mirostat 2.0"
-    field_name = "mirostat"
-
-    updated_config = await component.update_build_config(build_config, field_value, field_name)
-
-    assert updated_config["mirostat_eta"]["advanced"] is False
-    assert updated_config["mirostat_tau"]["advanced"] is False
-    assert updated_config["mirostat_eta"]["value"] == 0.2
-    assert updated_config["mirostat_tau"]["value"] == 10
-
-
-@patch("httpx.AsyncClient.get")
-async def test_update_build_config_model_name(mock_get, component):
-    # Mock the response for the HTTP GET request
-    mock_response = MagicMock()
-    mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
-    mock_response.raise_for_status.return_value = None
-    mock_get.return_value = mock_response
-
-    build_config = {
-        "base_url": {"load_from_db": False, "value": None},
-        "model_name": {"options": []},
-    }
-    field_value = None
-    field_name = "model_name"
-
-    with pytest.raises(ValueError, match="No valid Ollama URL found"):
-        await component.update_build_config(build_config, field_value, field_name)
-
-
-async def test_update_build_config_keep_alive(component):
-    build_config = {"keep_alive": {"value": None, "advanced": False}}
-    field_value = "Keep"
-    field_name = "keep_alive_flag"
-
-    updated_config = await component.update_build_config(build_config, field_value, field_name)
-    assert updated_config["keep_alive"]["value"] == "-1"
-    assert updated_config["keep_alive"]["advanced"] is True
-
-    field_value = "Immediately"
-    updated_config = await component.update_build_config(build_config, field_value, field_name)
-    assert updated_config["keep_alive"]["value"] == "0"
-    assert updated_config["keep_alive"]["advanced"] is True
-
-
-@patch(
-    "langchain_community.chat_models.ChatOllama",
-    return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
-)
-def test_build_model(_mock_chat_ollama, component):  # noqa: PT019
-    component.base_url = "http://localhost:11434"
-    component.model_name = "llama3.1"
-    component.mirostat = "Mirostat 2.0"
-    component.mirostat_eta = 0.2  # Ensure this is set as a float
-    component.mirostat_tau = 10.0  # Ensure this is set as a float
-    component.temperature = 0.2
-    component.verbose = True
-    model = component.build_model()
-    assert isinstance(model, ChatOllama)
-    assert model.base_url == "http://localhost:11434"
-    assert model.model == "llama3.1"
+class TestChatOllamaComponent(ComponentTestBaseWithoutClient):
+    @pytest.fixture
+    def component_class(self):
+        return ChatOllamaComponent
+
+    @pytest.fixture
+    def default_kwargs(self):
+        return {
+            "base_url": "http://localhost:8000",
+            "model_name": "ollama-model",
+            "temperature": 0.1,
+            "format": "json",
+            "metadata": {},
+            "tags": "",
+            "mirostat": "Disabled",
+            "num_ctx": 2048,
+            "num_gpu": 1,
+            "num_thread": 4,
+            "repeat_last_n": 64,
+            "repeat_penalty": 1.1,
+            "tfs_z": 1.0,
+            "timeout": 30,
+            "top_k": 40,
+            "top_p": 0.9,
+            "verbose": False,
+            "tool_model_enabled": True,
+        }
+
+    @pytest.fixture
+    def file_names_mapping(self):
+        # Provide an empty list or the actual mapping if versioned files exist
+        return []
+
+    @patch("langflow.components.models.ollama.ChatOllama")
+    async def test_build_model(self, mock_chat_ollama, component_class, default_kwargs):
+        mock_instance = MagicMock()
+        mock_chat_ollama.return_value = mock_instance
+        component = component_class(**default_kwargs)
+        model = component.build_model()
+        mock_chat_ollama.assert_called_once_with(
+            base_url="http://localhost:8000",
+            model="ollama-model",
+            mirostat=0,
+            format="json",
+            metadata={},
+            num_ctx=2048,
+            num_gpu=1,
+            num_thread=4,
+            repeat_last_n=64,
+            repeat_penalty=1.1,
+            temperature=0.1,
+            system="",
+            tfs_z=1.0,
+            timeout=30,
+            top_k=40,
+            top_p=0.9,
+            verbose=False,
+            template="",
+        )
+        assert model == mock_instance
+
+    @patch("langflow.components.models.ollama.ChatOllama")
+    async def test_build_model_missing_base_url(self, mock_chat_ollama, component_class, default_kwargs):
+        # Make the mock raise an exception to simulate connection failure
+        mock_chat_ollama.side_effect = Exception("connection error")
+        component = component_class(**default_kwargs)
+        component.base_url = None
+        with pytest.raises(ValueError, match="Unable to connect to the Ollama API."):
+            component.build_model()
+
+    @pytest.mark.asyncio
+    @patch("langflow.components.models.ollama.httpx.AsyncClient.post")
+    @patch("langflow.components.models.ollama.httpx.AsyncClient.get")
+    async def test_get_models_success(self, mock_get, mock_post):
+        component = ChatOllamaComponent()
+        mock_get_response = AsyncMock()
+        mock_get_response.raise_for_status.return_value = None
+        mock_get_response.json.return_value = {
+            component.JSON_MODELS_KEY: [
+                {component.JSON_NAME_KEY: "model1"},
+                {component.JSON_NAME_KEY: "model2"},
+            ]
+        }
+        mock_get.return_value = mock_get_response
+
+        mock_post_response = AsyncMock()
+        mock_post_response.raise_for_status.return_value = None
+        mock_post_response.json.side_effect = [
+            {component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]},
+            {component.JSON_CAPABILITIES_KEY: []},
+        ]
+        mock_post.return_value = mock_post_response
+
+        base_url = "http://localhost:11434"
+        result = await component.get_models(base_url)
+        assert result == ["model1"]
+        assert mock_get.call_count == 1
+        assert mock_post.call_count == 2
+
+    @pytest.mark.asyncio
+    @patch("langflow.components.models.ollama.httpx.AsyncClient.get")
+    async def test_get_models_failure(self, mock_get):
+        import httpx
+
+        component = ChatOllamaComponent()
+        mock_get.side_effect = httpx.RequestError("Connection error", request=None)
+        base_url = "http://localhost:11434"
+        with pytest.raises(ValueError, match="Could not get model names from Ollama."):
+            await component.get_models(base_url)
+
+    @pytest.mark.asyncio
+    async def test_update_build_config_mirostat_disabled(self):
+        component = ChatOllamaComponent()
+        build_config = {
+            "mirostat_eta": {"advanced": False, "value": 0.1},
+            "mirostat_tau": {"advanced": False, "value": 5},
+        }
+        field_value = "Disabled"
+        field_name = "mirostat"
+        updated_config = await component.update_build_config(build_config, field_value, field_name)
+        assert updated_config["mirostat_eta"]["advanced"] is True
+        assert updated_config["mirostat_tau"]["advanced"] is True
+        assert updated_config["mirostat_eta"]["value"] is None
+        assert updated_config["mirostat_tau"]["value"] is None
+
+    @pytest.mark.asyncio
+    async def test_update_build_config_mirostat_enabled(self):
+        component = ChatOllamaComponent()
+        build_config = {
+            "mirostat_eta": {"advanced": False, "value": None},
+            "mirostat_tau": {"advanced": False, "value": None},
+        }
+        field_value = "Mirostat 2.0"
+        field_name = "mirostat"
+        updated_config = await component.update_build_config(build_config, field_value, field_name)
+        assert updated_config["mirostat_eta"]["advanced"] is False
+        assert updated_config["mirostat_tau"]["advanced"] is False
+        assert updated_config["mirostat_eta"]["value"] == 0.2
+        assert updated_config["mirostat_tau"]["value"] == 10
+
+    @patch("langflow.components.models.ollama.httpx.AsyncClient.get")
+    @pytest.mark.asyncio
+    async def test_update_build_config_model_name(self, mock_get):
+        component = ChatOllamaComponent()
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
+        mock_response.raise_for_status.return_value = None
+        mock_get.return_value = mock_response
+        build_config = {
+            "base_url": {"load_from_db": False, "value": None},
+            "model_name": {"options": []},
+        }
+        field_value = None
+        field_name = "model_name"
+        with pytest.raises(ValueError, match="No valid Ollama URL found"):
+            await component.update_build_config(build_config, field_value, field_name)
+
+    @pytest.mark.asyncio
+    async def test_update_build_config_keep_alive(self):
+        component = ChatOllamaComponent()
+        build_config = {"keep_alive": {"value": None, "advanced": False}}
+        field_value = "Keep"
+        field_name = "keep_alive_flag"
+        updated_config = await component.update_build_config(build_config, field_value, field_name)
+        assert updated_config["keep_alive"]["value"] == "-1"
+        assert updated_config["keep_alive"]["advanced"] is True
+        field_value = "Immediately"
+        updated_config = await component.update_build_config(build_config, field_value, field_name)
+        assert updated_config["keep_alive"]["value"] == "0"
+        assert updated_config["keep_alive"]["advanced"] is True
+
+    @patch(
+        "langchain_ollama.ChatOllama",
+        return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
+    )
+    def test_build_model_integration(self, _mock_chat_ollama):  # noqa: PT019
+        component = ChatOllamaComponent()
+        component.base_url = "http://localhost:11434"
+        component.model_name = "llama3.1"
+        component.mirostat = "Mirostat 2.0"
+        component.mirostat_eta = 0.2
+        component.mirostat_tau = 10.0
+        component.temperature = 0.2
+        component.verbose = True
+        model = component.build_model()
+        assert isinstance(model, ChatOllama)
+        assert model.base_url == "http://localhost:11434"
+        assert model.model == "llama3.1"