From 43629b21ade6043852727f4cfdbca25557a16430 Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Wed, 21 May 2025 12:27:07 -0400 Subject: [PATCH] fix: improve tool calling filter in ollama model component (#8056) * improve tool calling filter * [autofix.ci] apply automated fixes * Update ollama.py * update tests * [autofix.ci] apply automated fixes * fix: correct variable reference for tool model capability check in ChatOllamaComponent * Update test_chatollama_component.py * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida Co-authored-by: Carlos Coelho <80289056+carlosrcoelho@users.noreply.github.com> --- .../base/langflow/components/models/ollama.py | 17 +- .../models/test_chatollama_component.py | 299 +++++++++++------- 2 files changed, 184 insertions(+), 132 deletions(-) diff --git a/src/backend/base/langflow/components/models/ollama.py b/src/backend/base/langflow/components/models/ollama.py index 0374e63da..b31996d65 100644 --- a/src/backend/base/langflow/components/models/ollama.py +++ b/src/backend/base/langflow/components/models/ollama.py @@ -6,7 +6,7 @@ import httpx from langchain_ollama import ChatOllama from langflow.base.models.model import LCModelComponent -from langflow.base.models.ollama_constants import OLLAMA_TOOL_MODELS_BASE, URL_LIST +from langflow.base.models.ollama_constants import URL_LIST from langflow.field_typing import LanguageModel from langflow.field_typing.range_spec import RangeSpec from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput @@ -26,6 +26,7 @@ class ChatOllamaComponent(LCModelComponent): JSON_NAME_KEY = "name" JSON_CAPABILITIES_KEY = "capabilities" DESIRED_CAPABILITY = "completion" + TOOL_CALLING_CAPABILITY = "tools" inputs = [ MessageTextInput( @@ -130,7 +131,7 @@ class ChatOllamaComponent(LCModelComponent): name="tool_model_enabled", display_name="Tool Model Enabled", info="Whether to enable tool calling in the model.", - value=False, + value=True, real_time_refresh=True, ), MessageTextInput( @@ -314,17 +315,13 @@ class ChatOllamaComponent(LCModelComponent): capabilities = json_data.get(self.JSON_CAPABILITIES_KEY, []) logger.debug(f"Model: {model_name}, Capabilities: {capabilities}") - if self.DESIRED_CAPABILITY in capabilities: + if self.DESIRED_CAPABILITY in capabilities and ( + not tool_model_enabled or self.TOOL_CALLING_CAPABILITY in capabilities + ): model_ids.append(model_name) except (httpx.RequestError, ValueError) as e: msg = "Could not get model names from Ollama." raise ValueError(msg) from e - return ( - model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)] - ) - - def supports_tool_calling(self, model: str) -> bool: - """Check if model name is in the base of any models example llama3.3 can have 1b and 2b.""" - return any(model.startswith(f"{tool_model}") for tool_model in OLLAMA_TOOL_MODELS_BASE) + return model_ids diff --git a/src/backend/tests/unit/components/models/test_chatollama_component.py b/src/backend/tests/unit/components/models/test_chatollama_component.py index c6b093025..892fc1a38 100644 --- a/src/backend/tests/unit/components/models/test_chatollama_component.py +++ b/src/backend/tests/unit/components/models/test_chatollama_component.py @@ -2,141 +2,196 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from langchain_ollama import ChatOllama -from langflow.components.models import ChatOllamaComponent +from langflow.components.models.ollama import ChatOllamaComponent + +from tests.base import ComponentTestBaseWithoutClient -@pytest.fixture -def component(): - return ChatOllamaComponent() +class TestChatOllamaComponent(ComponentTestBaseWithoutClient): + @pytest.fixture + def component_class(self): + return ChatOllamaComponent + @pytest.fixture + def default_kwargs(self): + return { + "base_url": "http://localhost:8000", + "model_name": "ollama-model", + "temperature": 0.1, + "format": "json", + "metadata": {}, + "tags": "", + "mirostat": "Disabled", + "num_ctx": 2048, + "num_gpu": 1, + "num_thread": 4, + "repeat_last_n": 64, + "repeat_penalty": 1.1, + "tfs_z": 1.0, + "timeout": 30, + "top_k": 40, + "top_p": 0.9, + "verbose": False, + "tool_model_enabled": True, + } -@pytest.mark.asyncio -@patch("langflow.components.models.ollama.httpx.AsyncClient.post") -@patch("langflow.components.models.ollama.httpx.AsyncClient.get") -async def test_get_models_success(mock_get, mock_post, component): - # The revised approach to get_models filters based on model capabilities. - # It requires one request to ollama to get the models and another to check - # the capabilities of each model. - mock_get_response = AsyncMock() - mock_get_response.raise_for_status.return_value = None - mock_get_response.json.return_value = { - component.JSON_MODELS_KEY: [{component.JSON_NAME_KEY: "model1"}, {component.JSON_NAME_KEY: "model2"}] - } - mock_get.return_value = mock_get_response + @pytest.fixture + def file_names_mapping(self): + # Provide an empty list or the actual mapping if versioned files exist + return [] - # Mock the response for the HTTP POST request to check capabilities. - # Note that this is not exactly what happens if the Ollama server is running, - # but it is a good approximation. - # The first call checks the capabilities of model1, and the second call checks the capabilities of model2. - mock_post_response = AsyncMock() - mock_post_response.raise_for_status.return_value = None - mock_post_response.json.side_effect = [ - {component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]}, - {component.JSON_CAPABILITIES_KEY: []}, - ] - mock_post.return_value = mock_post_response + @patch("langflow.components.models.ollama.ChatOllama") + async def test_build_model(self, mock_chat_ollama, component_class, default_kwargs): + mock_instance = MagicMock() + mock_chat_ollama.return_value = mock_instance + component = component_class(**default_kwargs) + model = component.build_model() + mock_chat_ollama.assert_called_once_with( + base_url="http://localhost:8000", + model="ollama-model", + mirostat=0, + format="json", + metadata={}, + num_ctx=2048, + num_gpu=1, + num_thread=4, + repeat_last_n=64, + repeat_penalty=1.1, + temperature=0.1, + system="", + tfs_z=1.0, + timeout=30, + top_k=40, + top_p=0.9, + verbose=False, + template="", + ) + assert model == mock_instance - base_url = "http://localhost:11434" - result = await component.get_models(base_url) + @patch("langflow.components.models.ollama.ChatOllama") + async def test_build_model_missing_base_url(self, mock_chat_ollama, component_class, default_kwargs): + # Make the mock raise an exception to simulate connection failure + mock_chat_ollama.side_effect = Exception("connection error") + component = component_class(**default_kwargs) + component.base_url = None + with pytest.raises(ValueError, match="Unable to connect to the Ollama API."): + component.build_model() - # Check that the correct URL was used for the GET request - assert result == ["model1"] - assert mock_get.call_count == 1 - assert mock_post.call_count == 2 + @pytest.mark.asyncio + @patch("langflow.components.models.ollama.httpx.AsyncClient.post") + @patch("langflow.components.models.ollama.httpx.AsyncClient.get") + async def test_get_models_success(self, mock_get, mock_post): + component = ChatOllamaComponent() + mock_get_response = AsyncMock() + mock_get_response.raise_for_status.return_value = None + mock_get_response.json.return_value = { + component.JSON_MODELS_KEY: [ + {component.JSON_NAME_KEY: "model1"}, + {component.JSON_NAME_KEY: "model2"}, + ] + } + mock_get.return_value = mock_get_response + mock_post_response = AsyncMock() + mock_post_response.raise_for_status.return_value = None + mock_post_response.json.side_effect = [ + {component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]}, + {component.JSON_CAPABILITIES_KEY: []}, + ] + mock_post.return_value = mock_post_response -@pytest.mark.asyncio -@patch("langflow.components.models.ollama.httpx.AsyncClient.get") -async def test_get_models_failure(mock_get, component): - # Simulate a network error for /api/tags - import httpx + base_url = "http://localhost:11434" + result = await component.get_models(base_url) + assert result == ["model1"] + assert mock_get.call_count == 1 + assert mock_post.call_count == 2 - mock_get.side_effect = httpx.RequestError("Connection error", request=None) + @pytest.mark.asyncio + @patch("langflow.components.models.ollama.httpx.AsyncClient.get") + async def test_get_models_failure(self, mock_get): + import httpx - base_url = "http://localhost:11434" - with pytest.raises(ValueError, match="Could not get model names from Ollama."): - await component.get_models(base_url) + component = ChatOllamaComponent() + mock_get.side_effect = httpx.RequestError("Connection error", request=None) + base_url = "http://localhost:11434" + with pytest.raises(ValueError, match="Could not get model names from Ollama."): + await component.get_models(base_url) + @pytest.mark.asyncio + async def test_update_build_config_mirostat_disabled(self): + component = ChatOllamaComponent() + build_config = { + "mirostat_eta": {"advanced": False, "value": 0.1}, + "mirostat_tau": {"advanced": False, "value": 5}, + } + field_value = "Disabled" + field_name = "mirostat" + updated_config = await component.update_build_config(build_config, field_value, field_name) + assert updated_config["mirostat_eta"]["advanced"] is True + assert updated_config["mirostat_tau"]["advanced"] is True + assert updated_config["mirostat_eta"]["value"] is None + assert updated_config["mirostat_tau"]["value"] is None -async def test_update_build_config_mirostat_disabled(component): - build_config = { - "mirostat_eta": {"advanced": False, "value": 0.1}, - "mirostat_tau": {"advanced": False, "value": 5}, - } - field_value = "Disabled" - field_name = "mirostat" + @pytest.mark.asyncio + async def test_update_build_config_mirostat_enabled(self): + component = ChatOllamaComponent() + build_config = { + "mirostat_eta": {"advanced": False, "value": None}, + "mirostat_tau": {"advanced": False, "value": None}, + } + field_value = "Mirostat 2.0" + field_name = "mirostat" + updated_config = await component.update_build_config(build_config, field_value, field_name) + assert updated_config["mirostat_eta"]["advanced"] is False + assert updated_config["mirostat_tau"]["advanced"] is False + assert updated_config["mirostat_eta"]["value"] == 0.2 + assert updated_config["mirostat_tau"]["value"] == 10 - updated_config = await component.update_build_config(build_config, field_value, field_name) + @patch("langflow.components.models.ollama.httpx.AsyncClient.get") + @pytest.mark.asyncio + async def test_update_build_config_model_name(self, mock_get): + component = ChatOllamaComponent() + mock_response = MagicMock() + mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]} + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + build_config = { + "base_url": {"load_from_db": False, "value": None}, + "model_name": {"options": []}, + } + field_value = None + field_name = "model_name" + with pytest.raises(ValueError, match="No valid Ollama URL found"): + await component.update_build_config(build_config, field_value, field_name) - assert updated_config["mirostat_eta"]["advanced"] is True - assert updated_config["mirostat_tau"]["advanced"] is True - assert updated_config["mirostat_eta"]["value"] is None - assert updated_config["mirostat_tau"]["value"] is None + @pytest.mark.asyncio + async def test_update_build_config_keep_alive(self): + component = ChatOllamaComponent() + build_config = {"keep_alive": {"value": None, "advanced": False}} + field_value = "Keep" + field_name = "keep_alive_flag" + updated_config = await component.update_build_config(build_config, field_value, field_name) + assert updated_config["keep_alive"]["value"] == "-1" + assert updated_config["keep_alive"]["advanced"] is True + field_value = "Immediately" + updated_config = await component.update_build_config(build_config, field_value, field_name) + assert updated_config["keep_alive"]["value"] == "0" + assert updated_config["keep_alive"]["advanced"] is True - -async def test_update_build_config_mirostat_enabled(component): - build_config = { - "mirostat_eta": {"advanced": False, "value": None}, - "mirostat_tau": {"advanced": False, "value": None}, - } - field_value = "Mirostat 2.0" - field_name = "mirostat" - - updated_config = await component.update_build_config(build_config, field_value, field_name) - - assert updated_config["mirostat_eta"]["advanced"] is False - assert updated_config["mirostat_tau"]["advanced"] is False - assert updated_config["mirostat_eta"]["value"] == 0.2 - assert updated_config["mirostat_tau"]["value"] == 10 - - -@patch("httpx.AsyncClient.get") -async def test_update_build_config_model_name(mock_get, component): - # Mock the response for the HTTP GET request - mock_response = MagicMock() - mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]} - mock_response.raise_for_status.return_value = None - mock_get.return_value = mock_response - - build_config = { - "base_url": {"load_from_db": False, "value": None}, - "model_name": {"options": []}, - } - field_value = None - field_name = "model_name" - - with pytest.raises(ValueError, match="No valid Ollama URL found"): - await component.update_build_config(build_config, field_value, field_name) - - -async def test_update_build_config_keep_alive(component): - build_config = {"keep_alive": {"value": None, "advanced": False}} - field_value = "Keep" - field_name = "keep_alive_flag" - - updated_config = await component.update_build_config(build_config, field_value, field_name) - assert updated_config["keep_alive"]["value"] == "-1" - assert updated_config["keep_alive"]["advanced"] is True - - field_value = "Immediately" - updated_config = await component.update_build_config(build_config, field_value, field_name) - assert updated_config["keep_alive"]["value"] == "0" - assert updated_config["keep_alive"]["advanced"] is True - - -@patch( - "langchain_community.chat_models.ChatOllama", - return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"), -) -def test_build_model(_mock_chat_ollama, component): # noqa: PT019 - component.base_url = "http://localhost:11434" - component.model_name = "llama3.1" - component.mirostat = "Mirostat 2.0" - component.mirostat_eta = 0.2 # Ensure this is set as a float - component.mirostat_tau = 10.0 # Ensure this is set as a float - component.temperature = 0.2 - component.verbose = True - model = component.build_model() - assert isinstance(model, ChatOllama) - assert model.base_url == "http://localhost:11434" - assert model.model == "llama3.1" + @patch( + "langchain_ollama.ChatOllama", + return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"), + ) + def test_build_model_integration(self, _mock_chat_ollama): # noqa: PT019 + component = ChatOllamaComponent() + component.base_url = "http://localhost:11434" + component.model_name = "llama3.1" + component.mirostat = "Mirostat 2.0" + component.mirostat_eta = 0.2 + component.mirostat_tau = 10.0 + component.temperature = 0.2 + component.verbose = True + model = component.build_model() + assert isinstance(model, ChatOllama) + assert model.base_url == "http://localhost:11434" + assert model.model == "llama3.1"