fix: improvement to ollama component to allow for dynamic filtering based on model capabilities (#7696)

* Updated model filtering to avoid hard-coding of name-based exclusions

* Stylistic adjustments

* Remove accidentally added package-lock.json from PR

* revert removal of package lock

* Modifications to the UT and changed component to be more async

* [autofix.ci] apply automated fixes

* Lint

---------

Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Pedro Pacheco 2025-05-06 11:10:32 -06:00 committed by GitHub
commit 12f35a0edc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 96 additions and 43 deletions

View file

@ -5,10 +5,11 @@ import httpx
from langchain_ollama import ChatOllama
from langflow.base.models.model import LCModelComponent
from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, OLLAMA_TOOL_MODELS_BASE, URL_LIST
from langflow.base.models.ollama_constants import OLLAMA_TOOL_MODELS_BASE, URL_LIST
from langflow.field_typing import LanguageModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput
from langflow.logging import logger
HTTP_STATUS_OK = 200
@ -19,6 +20,12 @@ class ChatOllamaComponent(LCModelComponent):
icon = "Ollama"
name = "OllamaModel"
# Define constants for JSON keys
JSON_MODELS_KEY = "models"
JSON_NAME_KEY = "name"
JSON_CAPABILITIES_KEY = "capabilities"
DESIRED_CAPABILITY = "completion"
inputs = [
MessageTextInput(
name="base_url",
@ -229,10 +236,10 @@ class ChatOllamaComponent(LCModelComponent):
if field_name in {"model_name", "base_url", "tool_model_enabled"}:
if await self.is_valid_ollama_url(self.base_url):
tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
build_config["model_name"]["options"] = await self.get_model(self.base_url, tool_model_enabled)
build_config["model_name"]["options"] = await self.get_models(self.base_url, tool_model_enabled)
elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")):
tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
build_config["model_name"]["options"] = await self.get_model(
build_config["model_name"]["options"] = await self.get_models(
build_config["base_url"].get("value", ""), tool_model_enabled
)
else:
@ -249,30 +256,59 @@ class ChatOllamaComponent(LCModelComponent):
return build_config
async def get_model(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]:
async def get_models(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]:
"""Fetches a list of models from the Ollama API, keeping only those whose capabilities include "completion".
Args:
base_url_value (str): The base URL of the Ollama API.
tool_model_enabled (bool | None, optional): If True, filters the models further to include
only those that support tool calling. Defaults to None.
Returns:
list[str]: A list of model names whose capabilities include "completion" (i.e. chat/generation
models, excluding embedding-only models). If `tool_model_enabled` is True, only models
supporting tool calling are included.
Raises:
ValueError: If there is an issue with the API request or response, or if the model
names cannot be retrieved.
"""
try:
url = urljoin(base_url_value, "api/tags")
# Normalize the base URL to avoid the repeated "/" at the end
base_url = base_url_value.rstrip("/") + "/"
# Ollama REST API to return models
tags_url = urljoin(base_url, "api/tags")
# Ollama REST API to return model capabilities
show_url = urljoin(base_url, "api/show")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()
# Fetch available models
tags_response = await client.get(tags_url)
tags_response.raise_for_status()
models = await tags_response.json()
logger.debug(f"Available models: {models}")
model_ids = [model["name"] for model in data.get("models", [])]
# this to ensure that not embedding models are included.
# not even the base models since models can have 1b 2b etc
# handles cases when embeddings models have tags like :latest - etc.
model_ids = [
model
for model in model_ids
if not any(
model == embedding_model or model.startswith(embedding_model.split("-")[0])
for embedding_model in OLLAMA_EMBEDDING_MODELS
)
]
# Keep only models whose reported capabilities include the desired "completion" capability
model_ids = []
for model in models[self.JSON_MODELS_KEY]:
model_name = model[self.JSON_NAME_KEY]
logger.debug(f"Checking model: {model_name}")
except (ImportError, ValueError, httpx.RequestError, Exception) as e:
payload = {"model": model_name}
show_response = await client.post(show_url, json=payload)
show_response.raise_for_status()
json_data = await show_response.json()
capabilities = json_data.get(self.JSON_CAPABILITIES_KEY, [])
logger.debug(f"Model: {model_name}, Capabilities: {capabilities}")
if self.DESIRED_CAPABILITY in capabilities:
model_ids.append(model_name)
except (httpx.RequestError, ValueError) as e:
msg = "Could not get model names from Ollama."
raise ValueError(msg) from e
return (
model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)]
)

View file

@ -1,5 +1,4 @@
from unittest.mock import MagicMock, patch
from urllib.parse import urljoin
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langchain_ollama import ChatOllama
@ -11,34 +10,52 @@ def component():
return ChatOllamaComponent()
@patch("httpx.AsyncClient.get")
async def test_get_model_success(mock_get, component):
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.post")
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_success(mock_get, mock_post, component):
# The revised approach to get_models filters based on model capabilities.
# It requires one request to ollama to get the models and another to check
# the capabilities of each model.
mock_get_response = AsyncMock()
mock_get_response.raise_for_status.return_value = None
mock_get_response.json.return_value = {
component.JSON_MODELS_KEY: [{component.JSON_NAME_KEY: "model1"}, {component.JSON_NAME_KEY: "model2"}]
}
mock_get.return_value = mock_get_response
# Mock the response for the HTTP POST request to check capabilities.
# Note that this is not exactly what happens if the Ollama server is running,
# but it is a good approximation.
# The first call checks the capabilities of model1, and the second call checks the capabilities of model2.
mock_post_response = AsyncMock()
mock_post_response.raise_for_status.return_value = None
mock_post_response.json.side_effect = [
{component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]},
{component.JSON_CAPABILITIES_KEY: []},
]
mock_post.return_value = mock_post_response
base_url = "http://localhost:11434"
result = await component.get_models(base_url)
model_names = await component.get_model(base_url)
expected_url = urljoin(base_url, "/api/tags")
mock_get.assert_called_once_with(expected_url)
assert model_names == ["model1", "model2"]
# Check that the correct URL was used for the GET request
assert result == ["model1"]
assert mock_get.call_count == 1
assert mock_post.call_count == 2
@patch("httpx.AsyncClient.get")
async def test_get_model_failure(mock_get, component):
# Mock the response for the HTTP GET request to raise an exception
mock_get.side_effect = Exception("HTTP request failed")
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_failure(mock_get, component):
# Simulate a network error for /api/tags
import httpx
url = "http://localhost:11434/"
mock_get.side_effect = httpx.RequestError("Connection error", request=None)
# Assert that the ValueError is raised when an exception occurs
base_url = "http://localhost:11434"
with pytest.raises(ValueError, match="Could not get model names from Ollama."):
await component.get_model(base_url_value=url)
await component.get_models(base_url)
async def test_update_build_config_mirostat_disabled(component):