From 12f35a0edc29037bbe7776d455ee419a5410de70 Mon Sep 17 00:00:00 2001 From: Pedro Pacheco <3083335+pedrocassalpacheco@users.noreply.github.com> Date: Tue, 6 May 2025 11:10:32 -0600 Subject: [PATCH] fix: improvement to ollama component to allow for dynamic filtering based on model capabilities (#7696) * Updated model filtering to avoid hard coding of named based exclusions * Stylistic adjustments * Remove accidentally added package-lock.json from PR * revert removal of package lock * Modifications to the UT and changed component to be more async * [autofix.ci] apply automated fixes * Lint --------- Co-authored-by: Edwin Jose Co-authored-by: Gabriel Luiz Freitas Almeida Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../base/langflow/components/models/ollama.py | 78 ++++++++++++++----- .../models/test_chatollama_component.py | 61 +++++++++------ 2 files changed, 96 insertions(+), 43 deletions(-) diff --git a/src/backend/base/langflow/components/models/ollama.py b/src/backend/base/langflow/components/models/ollama.py index ddd48ef10..ec5c79d30 100644 --- a/src/backend/base/langflow/components/models/ollama.py +++ b/src/backend/base/langflow/components/models/ollama.py @@ -5,10 +5,11 @@ import httpx from langchain_ollama import ChatOllama from langflow.base.models.model import LCModelComponent -from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, OLLAMA_TOOL_MODELS_BASE, URL_LIST +from langflow.base.models.ollama_constants import OLLAMA_TOOL_MODELS_BASE, URL_LIST from langflow.field_typing import LanguageModel from langflow.field_typing.range_spec import RangeSpec from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput +from langflow.logging import logger HTTP_STATUS_OK = 200 @@ -19,6 +20,12 @@ class ChatOllamaComponent(LCModelComponent): icon = "Ollama" name = "OllamaModel" + # Define constants for JSON keys + JSON_MODELS_KEY = "models" + JSON_NAME_KEY = "name" + JSON_CAPABILITIES_KEY = "capabilities" + DESIRED_CAPABILITY = "completion" + inputs = [ MessageTextInput( name="base_url", @@ -229,10 +236,10 @@ class ChatOllamaComponent(LCModelComponent): if field_name in {"model_name", "base_url", "tool_model_enabled"}: if await self.is_valid_ollama_url(self.base_url): tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled - build_config["model_name"]["options"] = await self.get_model(self.base_url, tool_model_enabled) + build_config["model_name"]["options"] = await self.get_models(self.base_url, tool_model_enabled) elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")): tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled - build_config["model_name"]["options"] = await self.get_model( + build_config["model_name"]["options"] = await self.get_models( build_config["base_url"].get("value", ""), tool_model_enabled ) else: @@ -249,30 +256,59 @@ class ChatOllamaComponent(LCModelComponent): return build_config - async def get_model(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]: + async def get_models(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]: + """Fetches a list of models from the Ollama API that do not have the "embedding" capability. + + Args: + base_url_value (str): The base URL of the Ollama API. + tool_model_enabled (bool | None, optional): If True, filters the models further to include + only those that support tool calling. Defaults to None. + + Returns: + list[str]: A list of model names that do not have the "embedding" capability. If + `tool_model_enabled` is True, only models supporting tool calling are included. + + Raises: + ValueError: If there is an issue with the API request or response, or if the model + names cannot be retrieved. + """ try: - url = urljoin(base_url_value, "api/tags") + # Normalize the base URL to avoid the repeated "/" at the end + base_url = base_url_value.rstrip("/") + "/" + + # Ollama REST API to return models + tags_url = urljoin(base_url, "api/tags") + + # Ollama REST API to return model capabilities + show_url = urljoin(base_url, "api/show") + async with httpx.AsyncClient() as client: - response = await client.get(url) - response.raise_for_status() - data = response.json() + # Fetch available models + tags_response = await client.get(tags_url) + tags_response.raise_for_status() + models = await tags_response.json() + logger.debug(f"Available models: {models}") - model_ids = [model["name"] for model in data.get("models", [])] - # this to ensure that not embedding models are included. - # not even the base models since models can have 1b 2b etc - # handles cases when embeddings models have tags like :latest - etc. - model_ids = [ - model - for model in model_ids - if not any( - model == embedding_model or model.startswith(embedding_model.split("-")[0]) - for embedding_model in OLLAMA_EMBEDDING_MODELS - ) - ] + # Filter models that are NOT embedding models + model_ids = [] + for model in models[self.JSON_MODELS_KEY]: + model_name = model[self.JSON_NAME_KEY] + logger.debug(f"Checking model: {model_name}") - except (ImportError, ValueError, httpx.RequestError, Exception) as e: + payload = {"model": model_name} + show_response = await client.post(show_url, json=payload) + show_response.raise_for_status() + json_data = await show_response.json() + capabilities = json_data.get(self.JSON_CAPABILITIES_KEY, []) + logger.debug(f"Model: {model_name}, Capabilities: {capabilities}") + + if self.DESIRED_CAPABILITY in capabilities: + model_ids.append(model_name) + + except (httpx.RequestError, ValueError) as e: msg = "Could not get model names from Ollama." raise ValueError(msg) from e + return ( model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)] ) diff --git a/src/backend/tests/unit/components/models/test_chatollama_component.py b/src/backend/tests/unit/components/models/test_chatollama_component.py index f4519629a..c6b093025 100644 --- a/src/backend/tests/unit/components/models/test_chatollama_component.py +++ b/src/backend/tests/unit/components/models/test_chatollama_component.py @@ -1,5 +1,4 @@ -from unittest.mock import MagicMock, patch -from urllib.parse import urljoin +from unittest.mock import AsyncMock, MagicMock, patch import pytest from langchain_ollama import ChatOllama @@ -11,34 +10,52 @@ def component(): return ChatOllamaComponent() -@patch("httpx.AsyncClient.get") -async def test_get_model_success(mock_get, component): - mock_response = MagicMock() - mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]} - mock_response.raise_for_status.return_value = None - mock_get.return_value = mock_response +@pytest.mark.asyncio +@patch("langflow.components.models.ollama.httpx.AsyncClient.post") +@patch("langflow.components.models.ollama.httpx.AsyncClient.get") +async def test_get_models_success(mock_get, mock_post, component): + # The revised approach to get_models filters based on model capabilities. + # It requires one request to ollama to get the models and another to check + # the capabilities of each model. + mock_get_response = AsyncMock() + mock_get_response.raise_for_status.return_value = None + mock_get_response.json.return_value = { + component.JSON_MODELS_KEY: [{component.JSON_NAME_KEY: "model1"}, {component.JSON_NAME_KEY: "model2"}] + } + mock_get.return_value = mock_get_response + + # Mock the response for the HTTP POST request to check capabilities. + # Note that this is not exactly what happens if the Ollama server is running, + # but it is a good approximation. + # The first call checks the capabilities of model1, and the second call checks the capabilities of model2. + mock_post_response = AsyncMock() + mock_post_response.raise_for_status.return_value = None + mock_post_response.json.side_effect = [ + {component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]}, + {component.JSON_CAPABILITIES_KEY: []}, + ] + mock_post.return_value = mock_post_response base_url = "http://localhost:11434" + result = await component.get_models(base_url) - model_names = await component.get_model(base_url) - - expected_url = urljoin(base_url, "/api/tags") - - mock_get.assert_called_once_with(expected_url) - - assert model_names == ["model1", "model2"] + # Check that the correct URL was used for the GET request + assert result == ["model1"] + assert mock_get.call_count == 1 + assert mock_post.call_count == 2 -@patch("httpx.AsyncClient.get") -async def test_get_model_failure(mock_get, component): - # Mock the response for the HTTP GET request to raise an exception - mock_get.side_effect = Exception("HTTP request failed") +@pytest.mark.asyncio +@patch("langflow.components.models.ollama.httpx.AsyncClient.get") +async def test_get_models_failure(mock_get, component): + # Simulate a network error for /api/tags + import httpx - url = "http://localhost:11434/" + mock_get.side_effect = httpx.RequestError("Connection error", request=None) - # Assert that the ValueError is raised when an exception occurs + base_url = "http://localhost:11434" with pytest.raises(ValueError, match="Could not get model names from Ollama."): - await component.get_model(base_url_value=url) + await component.get_models(base_url) async def test_update_build_config_mirostat_disabled(component):