fix: improvement to ollama component to allow for dynamic filtering based on model capabilities (#7696)
* Updated model filtering to avoid hard-coding of name-based exclusions * Stylistic adjustments * Remove accidentally added package-lock.json from PR * Revert removal of package-lock.json * Modifications to the unit tests and changed component to be more async * [autofix.ci] apply automated fixes * Lint --------- Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
744df31f18
commit
12f35a0edc
2 changed files with 96 additions and 43 deletions
|
|
@ -5,10 +5,11 @@ import httpx
|
|||
from langchain_ollama import ChatOllama
|
||||
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, OLLAMA_TOOL_MODELS_BASE, URL_LIST
|
||||
from langflow.base.models.ollama_constants import OLLAMA_TOOL_MODELS_BASE, URL_LIST
|
||||
from langflow.field_typing import LanguageModel
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput
|
||||
from langflow.logging import logger
|
||||
|
||||
HTTP_STATUS_OK = 200
|
||||
|
||||
|
|
@ -19,6 +20,12 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
icon = "Ollama"
|
||||
name = "OllamaModel"
|
||||
|
||||
# Define constants for JSON keys
|
||||
JSON_MODELS_KEY = "models"
|
||||
JSON_NAME_KEY = "name"
|
||||
JSON_CAPABILITIES_KEY = "capabilities"
|
||||
DESIRED_CAPABILITY = "completion"
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
name="base_url",
|
||||
|
|
@ -229,10 +236,10 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
if field_name in {"model_name", "base_url", "tool_model_enabled"}:
|
||||
if await self.is_valid_ollama_url(self.base_url):
|
||||
tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
|
||||
build_config["model_name"]["options"] = await self.get_model(self.base_url, tool_model_enabled)
|
||||
build_config["model_name"]["options"] = await self.get_models(self.base_url, tool_model_enabled)
|
||||
elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")):
|
||||
tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
|
||||
build_config["model_name"]["options"] = await self.get_model(
|
||||
build_config["model_name"]["options"] = await self.get_models(
|
||||
build_config["base_url"].get("value", ""), tool_model_enabled
|
||||
)
|
||||
else:
|
||||
|
|
@ -249,30 +256,59 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
|
||||
return build_config
|
||||
|
||||
async def get_models(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]:
    """Fetch the names of Ollama models that advertise the "completion" capability.

    Every model listed by ``GET <base>/api/tags`` is looked up with
    ``POST <base>/api/show`` and kept only when ``DESIRED_CAPABILITY`` appears in
    the capabilities it reports. Filtering on capabilities replaces the earlier
    hard-coded, name-based exclusion list and is intended to drop
    embedding-only models.

    Args:
        base_url_value (str): The base URL of the Ollama API. A trailing "/" is
            optional.
        tool_model_enabled (bool | None, optional): If True, the result is further
            restricted to the models for which ``self.supports_tool_calling(model)``
            is truthy. Defaults to None.

    Returns:
        list[str]: Names of the models that report ``DESIRED_CAPABILITY``, in the
            order the server listed them. If `tool_model_enabled` is True, only
            models supporting tool calling are included.

    Raises:
        ValueError: If a request cannot be sent, the server answers with an error
            status, or a response body cannot be decoded as JSON.
    """
    # Local import so this method does not depend on the module's import block.
    import inspect

    async def _read_json(response):
        """Return the decoded JSON body of ``response``.

        httpx's ``Response.json()`` is a plain synchronous method, so its result
        must not be awaited unconditionally. Test doubles built from ``AsyncMock``
        return an awaitable instead, so both shapes are accepted.
        """
        body = response.json()
        if inspect.isawaitable(body):
            body = await body
        return body

    try:
        # Normalize the base URL so urljoin keeps any path prefix and never
        # produces a doubled "/".
        base_url = base_url_value.rstrip("/") + "/"

        # Ollama REST API to return models
        tags_url = urljoin(base_url, "api/tags")

        # Ollama REST API to return model capabilities
        show_url = urljoin(base_url, "api/show")

        async with httpx.AsyncClient() as client:
            # Fetch available models
            tags_response = await client.get(tags_url)
            tags_response.raise_for_status()
            models = await _read_json(tags_response)
            logger.debug(f"Available models: {models}")

            # Keep only the models that report the desired capability
            model_ids = []
            for model in models.get(self.JSON_MODELS_KEY, []):
                model_name = model[self.JSON_NAME_KEY]
                logger.debug(f"Checking model: {model_name}")

                payload = {"model": model_name}
                show_response = await client.post(show_url, json=payload)
                show_response.raise_for_status()
                json_data = await _read_json(show_response)
                # A missing or null capabilities entry is treated as "no capabilities".
                capabilities = json_data.get(self.JSON_CAPABILITIES_KEY) or []
                logger.debug(f"Model: {model_name}, Capabilities: {capabilities}")

                if self.DESIRED_CAPABILITY in capabilities:
                    model_ids.append(model_name)

    # httpx.HTTPError is the common base of RequestError (transport failures) and
    # HTTPStatusError (raised by raise_for_status), so both surface as ValueError.
    except (httpx.HTTPError, ValueError) as e:
        msg = "Could not get model names from Ollama."
        raise ValueError(msg) from e

    if not tool_model_enabled:
        return model_ids
    return [model for model in model_ids if self.supports_tool_calling(model)]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import urljoin
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_ollama import ChatOllama
|
||||
|
|
@ -11,34 +10,52 @@ def component():
|
|||
return ChatOllamaComponent()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.post")
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_success(mock_get, mock_post, component):
    """get_models returns only the models whose capabilities include DESIRED_CAPABILITY."""
    # The revised approach to get_models filters based on model capabilities.
    # It requires one request to Ollama to list the models and one further
    # request per model to check that model's capabilities.
    mock_get_response = AsyncMock()
    mock_get_response.raise_for_status.return_value = None
    mock_get_response.json.return_value = {
        component.JSON_MODELS_KEY: [{component.JSON_NAME_KEY: "model1"}, {component.JSON_NAME_KEY: "model2"}]
    }
    mock_get.return_value = mock_get_response

    # Mock the response for the HTTP POST request to check capabilities.
    # Note that this is not exactly what happens if the Ollama server is running,
    # but it is a good approximation.
    # The first call checks the capabilities of model1, and the second call checks the capabilities of model2.
    mock_post_response = AsyncMock()
    mock_post_response.raise_for_status.return_value = None
    mock_post_response.json.side_effect = [
        {component.JSON_CAPABILITIES_KEY: [component.DESIRED_CAPABILITY]},
        {component.JSON_CAPABILITIES_KEY: []},
    ]
    mock_post.return_value = mock_post_response

    base_url = "http://localhost:11434"
    result = await component.get_models(base_url)

    # Only model1 reports the desired capability.
    assert result == ["model1"]

    # Check that the correct URL was used for the GET request (exactly one call).
    mock_get.assert_called_once_with(f"{base_url}/api/tags")

    # One capability lookup per listed model.
    assert mock_post.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("langflow.components.models.ollama.httpx.AsyncClient.get")
async def test_get_models_failure(mock_get, component):
    """get_models wraps a transport error from /api/tags in a ValueError."""
    import httpx

    # Simulate a network error for /api/tags
    mock_get.side_effect = httpx.RequestError("Connection error", request=None)

    # Assert that the ValueError is raised when an exception occurs
    base_url = "http://localhost:11434"
    with pytest.raises(ValueError, match="Could not get model names from Ollama."):
        await component.get_models(base_url)
|
||||
|
||||
|
||||
async def test_update_build_config_mirostat_disabled(component):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue