feat: Add model filtering support for Ollama Component, improving stability (#5748)

* Update ollama.py

* ollama models refactor

* ollama embeddings support for model filters

* formatting

* Update src/backend/base/langflow/components/embeddings/ollama.py

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* [autofix.ci] apply automated fixes

* fix test

* reverting test

* refactor: Update Ollama components to use async URL validation

- Changed `is_valid_ollama_url` method to be asynchronous in both `ollama.py` files.
- Updated calls to `is_valid_ollama_url` to use `await` for proper async handling.
- Modified URL validation logic to ensure compatibility with async operations.
- Improved overall responsiveness of the Ollama components by leveraging async HTTP requests.

* reverting to empty list!

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
commit bbb3987cda
Edwin Jose, 2025-01-20 12:00:40 -05:00, committed by GitHub
4 changed files with 244 additions and 80 deletions
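
The async validation change called out in the commit message reduces to swapping a blocking HTTP check for an awaitable one. A minimal sketch, assuming the previous implementation used a synchronous httpx call (the "after" form matches `is_valid_ollama_url` in the diffs below):

import httpx

HTTP_STATUS_OK = 200

# Before (assumed): a blocking check that would stall the event loop.
def is_valid_ollama_url_sync(url: str) -> bool:
    try:
        return httpx.get(f"{url}/api/tags").status_code == HTTP_STATUS_OK
    except httpx.RequestError:
        return False

# After: the awaitable form awaited by update_build_config in this commit.
async def is_valid_ollama_url(url: str) -> bool:
    try:
        async with httpx.AsyncClient() as client:
            return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
    except httpx.RequestError:
        return False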

src/backend/base/langflow/base/models/ollama_constants.py (new file):

@@ -0,0 +1,47 @@
# https://ollama.com/search?c=embedding
OLLAMA_EMBEDDING_MODELS = [
"nomic-embed-text",
"mxbai-embed-large",
"snowflake-arctic-embed",
"all-minilm",
"bge-m3",
"paraphrase-multilingual",
"granite-embedding",
"jina-embeddings-v2-base-en",
]
# https://ollama.com/search?c=tools
OLLAMA_TOOL_MODELS_BASE = [
"llama3.3",
"qwq",
"llama3.2",
"llama3.1",
"mistral",
"qwen2",
"qwen2.5",
"qwen2.5-coder",
"mistral-nemo",
"mixtral",
"command-r",
"command-r-plus",
"mistral-large",
"smollm2",
"hermes3",
"athene-v2",
"mistral-small",
"nemotron-mini",
"nemotron",
"llama3-groq-tool-use",
"granite3-dense",
"granite3.1-dense",
"aya-expanse",
"granite3-moe",
"firefunction-v2",
]
URL_LIST = [
"http://localhost:11434",
"http://host.docker.internal:11434",
"http://127.0.0.1:11434",
"http://0.0.0.0:11434",
]
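
Both components consume these lists with simple prefix matches (see `get_model` and `supports_tool_calling` in the diffs below). A self-contained sketch of that matching; `is_embedding_model` is a hypothetical helper name and the lists are excerpted:

OLLAMA_EMBEDDING_MODELS = ["nomic-embed-text", "mxbai-embed-large", "all-minilm"]  # excerpt
OLLAMA_TOOL_MODELS_BASE = ["llama3.3", "llama3.1", "qwen2.5"]  # excerpt

def is_embedding_model(name: str) -> bool:
    # Prefix match, so size variants ("1b", "8b") and tags (":latest") still hit.
    return any(name.startswith(base) for base in OLLAMA_EMBEDDING_MODELS)

def supports_tool_calling(name: str) -> bool:
    return any(name.startswith(base) for base in OLLAMA_TOOL_MODELS_BASE)

print(is_embedding_model("nomic-embed-text:latest"))  # True
print(supports_tool_calling("llama3.1:8b"))           # True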

src/backend/base/langflow/components/embeddings/ollama.py:

@@ -1,8 +1,15 @@
from typing import Any
from urllib.parse import urljoin
import httpx
from langchain_ollama import OllamaEmbeddings
from langflow.base.models.model import LCModelComponent
from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, URL_LIST
from langflow.field_typing import Embeddings
from langflow.io import MessageTextInput, Output
from langflow.io import DropdownInput, MessageTextInput, Output
HTTP_STATUS_OK = 200
class OllamaEmbeddingsComponent(LCModelComponent):
@@ -13,16 +20,20 @@ class OllamaEmbeddingsComponent(LCModelComponent):
name = "OllamaEmbeddings"
inputs = [
MessageTextInput(
name="model",
DropdownInput(
name="model_name",
display_name="Ollama Model",
value="nomic-embed-text",
value="",
options=[],
real_time_refresh=True,
refresh_button=True,
combobox=True,
required=True,
),
MessageTextInput(
name="base_url",
display_name="Ollama Base URL",
value="http://localhost:11434",
value="",
required=True,
),
]
@@ -33,8 +44,63 @@
def build_embeddings(self) -> Embeddings:
try:
output = OllamaEmbeddings(model=self.model, base_url=self.base_url)
output = OllamaEmbeddings(model=self.model_name, base_url=self.base_url)
except Exception as e:
msg = "Could not connect to Ollama API."
msg = (
"Unable to connect to the Ollama API. "
"Please verify the base URL, ensure the relevant Ollama model is pulled, and try again."
)
raise ValueError(msg) from e
return output
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(field_value):
# Check if any URL in the list is valid
valid_url = ""
for url in URL_LIST:
if await self.is_valid_ollama_url(url):
valid_url = url
break
build_config["base_url"]["value"] = valid_url
if field_name in {"model_name", "base_url", "tool_model_enabled"}:
if await self.is_valid_ollama_url(self.base_url):
build_config["model_name"]["options"] = await self.get_model(self.base_url)
elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")):
build_config["model_name"]["options"] = await self.get_model(build_config["base_url"].get("value", ""))
else:
build_config["model_name"]["options"] = []
return build_config
async def get_model(self, base_url_value: str) -> list[str]:
"""Get the model names from Ollama."""
model_ids = []
try:
url = urljoin(base_url_value, "/api/tags")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()
model_ids = [model["name"] for model in data.get("models", [])]
# Keep only embedding models. Prefix matching is used because models
# ship with size variants (1b, 2b, ...) and tags such as ":latest".
model_ids = [
model
for model in model_ids
if any(model.startswith(f"{embedding_model}") for embedding_model in OLLAMA_EMBEDDING_MODELS)
]
except (ImportError, ValueError, httpx.RequestError) as e:
msg = "Could not get model names from Ollama."
raise ValueError(msg) from e
return model_ids
async def is_valid_ollama_url(self, url: str) -> bool:
try:
async with httpx.AsyncClient() as client:
return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
except httpx.RequestError:
return False
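
A standalone sketch of the fallback scan in `update_build_config` above: when the configured base URL fails validation, the first reachable endpoint from URL_LIST is used instead. `first_reachable` is a hypothetical name; only httpx is assumed:

import asyncio
import httpx

HTTP_STATUS_OK = 200
URL_LIST = [
    "http://localhost:11434",
    "http://host.docker.internal:11434",
    "http://127.0.0.1:11434",
    "http://0.0.0.0:11434",
]

async def first_reachable(urls: list[str]) -> str:
    for url in urls:
        try:
            async with httpx.AsyncClient() as client:
                if (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK:
                    return url
        except httpx.RequestError:
            continue
    return ""  # mirrors the component: empty string when nothing responds

print(asyncio.run(first_reachable(URL_LIST)))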

src/backend/base/langflow/components/models/ollama.py:

@@ -5,8 +5,12 @@ import httpx
from langchain_ollama import ChatOllama
from langflow.base.models.model import LCModelComponent
from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, OLLAMA_TOOL_MODELS_BASE, URL_LIST
from langflow.field_typing import LanguageModel
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, StrInput
from langflow.field_typing.range_spec import RangeSpec
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput
HTTP_STATUS_OK = 200
class ChatOllamaComponent(LCModelComponent):
@@ -15,81 +19,25 @@ class ChatOllamaComponent(LCModelComponent):
icon = "Ollama"
name = "OllamaModel"
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "mirostat":
if field_value == "Disabled":
build_config["mirostat_eta"]["advanced"] = True
build_config["mirostat_tau"]["advanced"] = True
build_config["mirostat_eta"]["value"] = None
build_config["mirostat_tau"]["value"] = None
else:
build_config["mirostat_eta"]["advanced"] = False
build_config["mirostat_tau"]["advanced"] = False
if field_value == "Mirostat 2.0":
build_config["mirostat_eta"]["value"] = 0.2
build_config["mirostat_tau"]["value"] = 10
else:
build_config["mirostat_eta"]["value"] = 0.1
build_config["mirostat_tau"]["value"] = 5
if field_name == "model_name":
base_url_dict = build_config.get("base_url", {})
base_url_load_from_db = base_url_dict.get("load_from_db", False)
base_url_value = base_url_dict.get("value")
if base_url_load_from_db:
base_url_value = await self.get_variables(base_url_value, field_name)
elif not base_url_value:
base_url_value = "http://localhost:11434"
build_config["model_name"]["options"] = await self.get_model(base_url_value)
if field_name == "keep_alive_flag":
if field_value == "Keep":
build_config["keep_alive"]["value"] = "-1"
build_config["keep_alive"]["advanced"] = True
elif field_value == "Immediately":
build_config["keep_alive"]["value"] = "0"
build_config["keep_alive"]["advanced"] = True
else:
build_config["keep_alive"]["advanced"] = False
return build_config
@staticmethod
async def get_model(base_url_value: str) -> list[str]:
try:
url = urljoin(base_url_value, "/api/tags")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()
return [model["name"] for model in data.get("models", [])]
except Exception as e:
msg = "Could not retrieve models. Please, make sure Ollama is running."
raise ValueError(msg) from e
inputs = [
StrInput(
MessageTextInput(
name="base_url",
display_name="Base URL",
info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
value="http://localhost:11434",
value="",
),
DropdownInput(
name="model_name",
display_name="Model Name",
value="llama3.1",
options=[],
info="Refer to https://ollama.com/library for more models.",
refresh_button=True,
real_time_refresh=True,
),
FloatInput(
name="temperature",
display_name="Temperature",
value=0.2,
info="Controls the creativity of model responses.",
SliderInput(
name="temperature", display_name="Temperature", value=0.1, range_spec=RangeSpec(min=0, max=1, step=0.01)
),
StrInput(
MessageTextInput(
name="format", display_name="Format", info="Specify the format of the output (e.g., json).", advanced=True
),
DictInput(name="metadata", display_name="Metadata", info="Metadata to add to the run trace.", advanced=True),
@@ -151,20 +99,31 @@ class ChatOllamaComponent(LCModelComponent):
),
FloatInput(name="top_p", display_name="Top P", info="Works together with top-k. (Default: 0.9)", advanced=True),
BoolInput(name="verbose", display_name="Verbose", info="Whether to print out response text.", advanced=True),
StrInput(
MessageTextInput(
name="tags",
display_name="Tags",
info="Comma-separated list of tags to add to the run trace.",
advanced=True,
),
StrInput(
MessageTextInput(
name="stop_tokens",
display_name="Stop Tokens",
info="Comma-separated list of tokens to signal the model to stop generating text.",
advanced=True,
),
StrInput(name="system", display_name="System", info="System to use for generating text.", advanced=True),
StrInput(name="template", display_name="Template", info="Template to use for generating text.", advanced=True),
MessageTextInput(
name="system", display_name="System", info="System to use for generating text.", advanced=True
),
MessageTextInput(
name="template", display_name="Template", info="Template to use for generating text.", advanced=True
),
BoolInput(
name="tool_model_enabled",
display_name="Tool Model Enabled",
info="Whether to enable tool calling in the model.",
value=True,
real_time_refresh=True,
),
*LCModelComponent._base_inputs,
]
@@ -215,7 +174,99 @@ class ChatOllamaComponent(LCModelComponent):
try:
output = ChatOllama(**llm_params)
except Exception as e:
msg = "Could not initialize Ollama LLM."
msg = (
"Unable to connect to the Ollama API. "
"Please verify the base URL, ensure the relevant Ollama model is pulled, and try again."
)
raise ValueError(msg) from e
return output
async def is_valid_ollama_url(self, url: str) -> bool:
try:
async with httpx.AsyncClient() as client:
return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
except httpx.RequestError:
return False
async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "mirostat":
if field_value == "Disabled":
build_config["mirostat_eta"]["advanced"] = True
build_config["mirostat_tau"]["advanced"] = True
build_config["mirostat_eta"]["value"] = None
build_config["mirostat_tau"]["value"] = None
else:
build_config["mirostat_eta"]["advanced"] = False
build_config["mirostat_tau"]["advanced"] = False
if field_value == "Mirostat 2.0":
build_config["mirostat_eta"]["value"] = 0.2
build_config["mirostat_tau"]["value"] = 10
else:
build_config["mirostat_eta"]["value"] = 0.1
build_config["mirostat_tau"]["value"] = 5
if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(field_value):
# Check if any URL in the list is valid
valid_url = ""
for url in URL_LIST:
if await self.is_valid_ollama_url(url):
valid_url = url
break
build_config["base_url"]["value"] = valid_url
if field_name in {"model_name", "base_url", "tool_model_enabled"}:
if await self.is_valid_ollama_url(self.base_url):
tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
build_config["model_name"]["options"] = await self.get_model(self.base_url, tool_model_enabled)
elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")):
tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
build_config["model_name"]["options"] = await self.get_model(
build_config["base_url"].get("value", ""), tool_model_enabled
)
else:
build_config["model_name"]["options"] = []
if field_name == "keep_alive_flag":
if field_value == "Keep":
build_config["keep_alive"]["value"] = "-1"
build_config["keep_alive"]["advanced"] = True
elif field_value == "Immediately":
build_config["keep_alive"]["value"] = "0"
build_config["keep_alive"]["advanced"] = True
else:
build_config["keep_alive"]["advanced"] = False
return build_config
async def get_model(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]:
try:
url = urljoin(base_url_value, "/api/tags")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()
model_ids = [model["name"] for model in data.get("models", [])]
# Exclude embedding models so only chat-capable models are listed.
# Prefix matching also catches size variants (1b, 2b, ...) and tags
# such as ":latest".
model_ids = [
model
for model in model_ids
if not any(
model == embedding_model or model.startswith(embedding_model.split("-")[0])
for embedding_model in OLLAMA_EMBEDDING_MODELS
)
]
except (ImportError, ValueError, httpx.RequestError, Exception) as e:
msg = "Could not get model names from Ollama."
raise ValueError(msg) from e
return (
model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)]
)
def supports_tool_calling(self, model: str) -> bool:
"""Check if model name is in the base of any models example llama3.3 can have 1b and 2b."""
return any(model.startswith(f"{tool_model}") for tool_model in OLLAMA_TOOL_MODELS_BASE)
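
The embedding-exclusion filter from `get_model` above, extracted to show its reach. `is_excluded` is a hypothetical name and the list is excerpted; note the broad prefix: "all-minilm".split("-")[0] == "all", so any model whose name starts with "all" is filtered out, not just all-minilm variants:

OLLAMA_EMBEDDING_MODELS = ["nomic-embed-text", "all-minilm", "bge-m3"]  # excerpt

def is_excluded(model: str) -> bool:
    return any(
        model == emb or model.startswith(emb.split("-")[0])
        for emb in OLLAMA_EMBEDDING_MODELS
    )

print(is_excluded("nomic-embed-text:latest"))  # True  (prefix "nomic")
print(is_excluded("llama3.1:8b"))              # False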

Unit tests for the Ollama model component:

@@ -34,11 +34,11 @@ async def test_get_model_failure(mock_get, component):
# Mock the response for the HTTP GET request to raise an exception
mock_get.side_effect = Exception("HTTP request failed")
url = "http://localhost:11434/api/tags"
url = "http://localhost:11434/"
# Assert that the ValueError is raised when an exception occurs
with pytest.raises(ValueError, match="Could not retrieve models"):
await component.get_model(url)
with pytest.raises(ValueError, match="Could not get model names from Ollama."):
await component.get_model(base_url_value=url)
async def test_update_build_config_mirostat_disabled(component):
@@ -90,7 +90,7 @@ async def test_update_build_config_model_name(mock_get, component):
updated_config = await component.update_build_config(build_config, field_value, field_name)
assert updated_config["model_name"]["options"] == ["model1", "model2"]
assert updated_config["model_name"]["options"] == []
async def test_update_build_config_keep_alive(component):