fix: improve tool calling filter in ollama model component (#8056)

* improve tool calling filter

* [autofix.ci] apply automated fixes

* Update ollama.py

* update tests

* [autofix.ci] apply automated fixes

* fix: correct variable reference for tool model capability check in ChatOllamaComponent

* Update test_chatollama_component.py

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: Carlos Coelho <80289056+carlosrcoelho@users.noreply.github.com>
This commit is contained in:
Edwin Jose 2025-05-21 12:27:07 -04:00 committed by GitHub
commit 43629b21ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 184 additions and 132 deletions

View file

@ -6,7 +6,7 @@ import httpx
from langchain_ollama import ChatOllama
from langflow.base.models.model import LCModelComponent
from langflow.base.models.ollama_constants import OLLAMA_TOOL_MODELS_BASE, URL_LIST
from langflow.base.models.ollama_constants import URL_LIST
from langflow.field_typing import LanguageModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput
@ -26,6 +26,7 @@ class ChatOllamaComponent(LCModelComponent):
JSON_NAME_KEY = "name"
JSON_CAPABILITIES_KEY = "capabilities"
DESIRED_CAPABILITY = "completion"
TOOL_CALLING_CAPABILITY = "tools"
inputs = [
MessageTextInput(
@ -130,7 +131,7 @@ class ChatOllamaComponent(LCModelComponent):
name="tool_model_enabled",
display_name="Tool Model Enabled",
info="Whether to enable tool calling in the model.",
value=False,
value=True,
real_time_refresh=True,
),
MessageTextInput(
@ -314,17 +315,13 @@ class ChatOllamaComponent(LCModelComponent):
capabilities = json_data.get(self.JSON_CAPABILITIES_KEY, [])
logger.debug(f"Model: {model_name}, Capabilities: {capabilities}")
if self.DESIRED_CAPABILITY in capabilities:
if self.DESIRED_CAPABILITY in capabilities and (
not tool_model_enabled or self.TOOL_CALLING_CAPABILITY in capabilities
):
model_ids.append(model_name)
except (httpx.RequestError, ValueError) as e:
msg = "Could not get model names from Ollama."
raise ValueError(msg) from e
return (
model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)]
)
def supports_tool_calling(self, model: str) -> bool:
"""Check if model name is in the base of any models example llama3.3 can have 1b and 2b."""
return any(model.startswith(f"{tool_model}") for tool_model in OLLAMA_TOOL_MODELS_BASE)
return model_ids

View file

@ -2,141 +2,196 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langchain_ollama import ChatOllama
from langflow.components.models import ChatOllamaComponent
from langflow.components.models.ollama import ChatOllamaComponent
from tests.base import ComponentTestBaseWithoutClient
@pytest.fixture
def component():
return ChatOllamaComponent()
class TestChatOllamaComponent(ComponentTestBaseWithoutClient):
@pytest.fixture
def component_class(self):
return ChatOllamaComponent
@pytest.fixture
def default_kwargs(self):
return {
"base_url": "http://localhost:8000",
"model_name": "ollama-model",
"temperature": 0.1,
"format": "json",
"metadata": {},
"tags": "",
"mirostat": "Disabled",
"num_ctx": 2048,
"num_gpu": 1,
"num_thread": 4,
"repeat_last_n": 64,
"repeat_penalty": 1.1,
"tfs_z": 1.0,
"timeout": 30,
"top_k": 40,
"top_p": 0.9,
"verbose": False,
"tool_model_enabled": True,
}
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.post")
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_success(mock_get, mock_post, component):
# The revised approach to get_models filters based on model capabilities.
# It requires one request to ollama to get the models and another to check
# the capabilities of each model.
mock_get_response = AsyncMock()
mock_get_response.raise_for_status.return_value = None
mock_get_response.json.return_value = {
component.JSON_MODELS_KEY: [{component.JSON_NAME_KEY: "model1"}, {component.JSON_NAME_KEY: "model2"}]
}
mock_get.return_value = mock_get_response
@pytest.fixture
def file_names_mapping(self):
# Provide an empty list or the actual mapping if versioned files exist
return []
# Mock the response for the HTTP POST request to check capabilities.
# Note that this is not exactly what happens if the Ollama server is running,
# but it is a good approximation.
# The first call checks the capabilities of model1, and the second call checks the capabilities of model2.
mock_post_response = AsyncMock()
mock_post_response.raise_for_status.return_value = None
mock_post_response.json.side_effect = [
{component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]},
{component.JSON_CAPABILITIES_KEY: []},
]
mock_post.return_value = mock_post_response
@patch("langflow.components.models.ollama.ChatOllama")
async def test_build_model(self, mock_chat_ollama, component_class, default_kwargs):
mock_instance = MagicMock()
mock_chat_ollama.return_value = mock_instance
component = component_class(**default_kwargs)
model = component.build_model()
mock_chat_ollama.assert_called_once_with(
base_url="http://localhost:8000",
model="ollama-model",
mirostat=0,
format="json",
metadata={},
num_ctx=2048,
num_gpu=1,
num_thread=4,
repeat_last_n=64,
repeat_penalty=1.1,
temperature=0.1,
system="",
tfs_z=1.0,
timeout=30,
top_k=40,
top_p=0.9,
verbose=False,
template="",
)
assert model == mock_instance
base_url = "http://localhost:11434"
result = await component.get_models(base_url)
@patch("langflow.components.models.ollama.ChatOllama")
async def test_build_model_missing_base_url(self, mock_chat_ollama, component_class, default_kwargs):
# Make the mock raise an exception to simulate connection failure
mock_chat_ollama.side_effect = Exception("connection error")
component = component_class(**default_kwargs)
component.base_url = None
with pytest.raises(ValueError, match="Unable to connect to the Ollama API."):
component.build_model()
# Check that the correct URL was used for the GET request
assert result == ["model1"]
assert mock_get.call_count == 1
assert mock_post.call_count == 2
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.post")
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_success(self, mock_get, mock_post):
component = ChatOllamaComponent()
mock_get_response = AsyncMock()
mock_get_response.raise_for_status.return_value = None
mock_get_response.json.return_value = {
component.JSON_MODELS_KEY: [
{component.JSON_NAME_KEY: "model1"},
{component.JSON_NAME_KEY: "model2"},
]
}
mock_get.return_value = mock_get_response
mock_post_response = AsyncMock()
mock_post_response.raise_for_status.return_value = None
mock_post_response.json.side_effect = [
{component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]},
{component.JSON_CAPABILITIES_KEY: []},
]
mock_post.return_value = mock_post_response
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_failure(mock_get, component):
# Simulate a network error for /api/tags
import httpx
base_url = "http://localhost:11434"
result = await component.get_models(base_url)
assert result == ["model1"]
assert mock_get.call_count == 1
assert mock_post.call_count == 2
mock_get.side_effect = httpx.RequestError("Connection error", request=None)
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_failure(self, mock_get):
import httpx
base_url = "http://localhost:11434"
with pytest.raises(ValueError, match="Could not get model names from Ollama."):
await component.get_models(base_url)
component = ChatOllamaComponent()
mock_get.side_effect = httpx.RequestError("Connection error", request=None)
base_url = "http://localhost:11434"
with pytest.raises(ValueError, match="Could not get model names from Ollama."):
await component.get_models(base_url)
@pytest.mark.asyncio
async def test_update_build_config_mirostat_disabled(self):
component = ChatOllamaComponent()
build_config = {
"mirostat_eta": {"advanced": False, "value": 0.1},
"mirostat_tau": {"advanced": False, "value": 5},
}
field_value = "Disabled"
field_name = "mirostat"
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is True
assert updated_config["mirostat_tau"]["advanced"] is True
assert updated_config["mirostat_eta"]["value"] is None
assert updated_config["mirostat_tau"]["value"] is None
async def test_update_build_config_mirostat_disabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": 0.1},
"mirostat_tau": {"advanced": False, "value": 5},
}
field_value = "Disabled"
field_name = "mirostat"
@pytest.mark.asyncio
async def test_update_build_config_mirostat_enabled(self):
component = ChatOllamaComponent()
build_config = {
"mirostat_eta": {"advanced": False, "value": None},
"mirostat_tau": {"advanced": False, "value": None},
}
field_value = "Mirostat 2.0"
field_name = "mirostat"
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is False
assert updated_config["mirostat_tau"]["advanced"] is False
assert updated_config["mirostat_eta"]["value"] == 0.2
assert updated_config["mirostat_tau"]["value"] == 10
updated_config = await component.update_build_config(build_config, field_value, field_name)
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_update_build_config_model_name(self, mock_get):
component = ChatOllamaComponent()
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
build_config = {
"base_url": {"load_from_db": False, "value": None},
"model_name": {"options": []},
}
field_value = None
field_name = "model_name"
with pytest.raises(ValueError, match="No valid Ollama URL found"):
await component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is True
assert updated_config["mirostat_tau"]["advanced"] is True
assert updated_config["mirostat_eta"]["value"] is None
assert updated_config["mirostat_tau"]["value"] is None
@pytest.mark.asyncio
async def test_update_build_config_keep_alive(self):
component = ChatOllamaComponent()
build_config = {"keep_alive": {"value": None, "advanced": False}}
field_value = "Keep"
field_name = "keep_alive_flag"
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "-1"
assert updated_config["keep_alive"]["advanced"] is True
field_value = "Immediately"
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "0"
assert updated_config["keep_alive"]["advanced"] is True
async def test_update_build_config_mirostat_enabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": None},
"mirostat_tau": {"advanced": False, "value": None},
}
field_value = "Mirostat 2.0"
field_name = "mirostat"
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is False
assert updated_config["mirostat_tau"]["advanced"] is False
assert updated_config["mirostat_eta"]["value"] == 0.2
assert updated_config["mirostat_tau"]["value"] == 10
@patch("httpx.AsyncClient.get")
async def test_update_build_config_model_name(mock_get, component):
# Mock the response for the HTTP GET request
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
build_config = {
"base_url": {"load_from_db": False, "value": None},
"model_name": {"options": []},
}
field_value = None
field_name = "model_name"
with pytest.raises(ValueError, match="No valid Ollama URL found"):
await component.update_build_config(build_config, field_value, field_name)
async def test_update_build_config_keep_alive(component):
build_config = {"keep_alive": {"value": None, "advanced": False}}
field_value = "Keep"
field_name = "keep_alive_flag"
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "-1"
assert updated_config["keep_alive"]["advanced"] is True
field_value = "Immediately"
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "0"
assert updated_config["keep_alive"]["advanced"] is True
@patch(
"langchain_community.chat_models.ChatOllama",
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
)
def test_build_model(_mock_chat_ollama, component): # noqa: PT019
component.base_url = "http://localhost:11434"
component.model_name = "llama3.1"
component.mirostat = "Mirostat 2.0"
component.mirostat_eta = 0.2 # Ensure this is set as a float
component.mirostat_tau = 10.0 # Ensure this is set as a float
component.temperature = 0.2
component.verbose = True
model = component.build_model()
assert isinstance(model, ChatOllama)
assert model.base_url == "http://localhost:11434"
assert model.model == "llama3.1"
@patch(
"langchain_ollama.ChatOllama",
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
)
def test_build_model_integration(self, _mock_chat_ollama): # noqa: PT019
component = ChatOllamaComponent()
component.base_url = "http://localhost:11434"
component.model_name = "llama3.1"
component.mirostat = "Mirostat 2.0"
component.mirostat_eta = 0.2
component.mirostat_tau = 10.0
component.temperature = 0.2
component.verbose = True
model = component.build_model()
assert isinstance(model, ChatOllama)
assert model.base_url == "http://localhost:11434"
assert model.model == "llama3.1"