feat: Add model filtering support for Ollama Component, improving stability (#5748)
* Update ollama.py
* ollama models refactor
* ollama embeddings support for model filters
* formatting
* Update src/backend/base/langflow/components/embeddings/ollama.py

  Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

* [autofix.ci] apply automated fixes
* fix test
* reverting test
* refactor: Update Ollama components to use async URL validation

  - Changed the `is_valid_ollama_url` method to be asynchronous in both `ollama.py` files.
  - Updated calls to `is_valid_ollama_url` to use `await` for proper async handling.
  - Modified the URL validation logic to ensure compatibility with async operations.
  - Improved the overall responsiveness of the Ollama components by leveraging async HTTP requests.

* reverting to empty list!

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
parent 3b8578ca8c
commit bbb3987cda

4 changed files with 244 additions and 80 deletions
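For context, a minimal sketch of the sync-to-async validation change the message describes; the synchronous variant below only illustrates the old blocking pattern and is not the exact pre-commit code:

    import httpx

    HTTP_STATUS_OK = 200

    def is_valid_ollama_url_sync(url: str) -> bool:
        # Old pattern: a blocking GET that stalls the event loop while waiting.
        try:
            return httpx.get(f"{url}/api/tags").status_code == HTTP_STATUS_OK
        except httpx.RequestError:
            return False

    async def is_valid_ollama_url(url: str) -> bool:
        # New pattern (as committed): an awaitable GET via httpx.AsyncClient.
        try:
            async with httpx.AsyncClient() as client:
                return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
        except httpx.RequestError:
            return False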
src/backend/base/langflow/base/models/ollama_constants.py (new file)
@@ -0,0 +1,47 @@
+# https://ollama.com/search?c=embedding
+OLLAMA_EMBEDDING_MODELS = [
+    "nomic-embed-text",
+    "mxbai-embed-large",
+    "snowflake-arctic-embed",
+    "all-minilm",
+    "bge-m3",
+    "paraphrase-multilingual",
+    "granite-embedding",
+    "jina-embeddings-v2-base-en",
+]
+# https://ollama.com/search?c=tools
+OLLAMA_TOOL_MODELS_BASE = [
+    "llama3.3",
+    "qwq",
+    "llama3.2",
+    "llama3.1",
+    "mistral",
+    "qwen2",
+    "qwen2.5",
+    "qwen2.5-coder",
+    "mistral-nemo",
+    "mixtral",
+    "command-r",
+    "command-r-plus",
+    "mistral-large",
+    "smollm2",
+    "hermes3",
+    "athene-v2",
+    "mistral-small",
+    "nemotron-mini",
+    "nemotron",
+    "llama3-groq-tool-use",
+    "granite3-dense",
+    "granite3.1-dense",
+    "aya-expanse",
+    "granite3-moe",
+    "firefunction-v2",
+]
+
+
+URL_LIST = [
+    "http://localhost:11434",
+    "http://host.docker.internal:11434",
+    "http://127.0.0.1:11434",
+    "http://0.0.0.0:11434",
+]
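As a quick illustration (not part of the commit), prefix matching over these constants is how both components classify tagged model names; the sample tags below are made up:

    from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, OLLAMA_TOOL_MODELS_BASE

    tags = ["nomic-embed-text:latest", "llama3.3:70b", "mxbai-embed-large", "gemma2:2b"]

    # The embeddings component keeps names that start with a known embedding model.
    assert [t for t in tags if any(t.startswith(m) for m in OLLAMA_EMBEDDING_MODELS)] == [
        "nomic-embed-text:latest",
        "mxbai-embed-large",
    ]

    # The chat component uses the same prefix test to flag tool-calling support.
    assert [t for t in tags if any(t.startswith(m) for m in OLLAMA_TOOL_MODELS_BASE)] == ["llama3.3:70b"]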
src/backend/base/langflow/components/embeddings/ollama.py

@@ -1,8 +1,15 @@
+from typing import Any
 from urllib.parse import urljoin

 import httpx
 from langchain_ollama import OllamaEmbeddings

 from langflow.base.models.model import LCModelComponent
+from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, URL_LIST
 from langflow.field_typing import Embeddings
-from langflow.io import MessageTextInput, Output
+from langflow.io import DropdownInput, MessageTextInput, Output
+
+HTTP_STATUS_OK = 200


 class OllamaEmbeddingsComponent(LCModelComponent):
@@ -13,16 +20,20 @@ class OllamaEmbeddingsComponent(LCModelComponent):
     name = "OllamaEmbeddings"

     inputs = [
-        MessageTextInput(
-            name="model",
+        DropdownInput(
+            name="model_name",
             display_name="Ollama Model",
-            value="nomic-embed-text",
+            value="",
+            options=[],
+            real_time_refresh=True,
+            refresh_button=True,
+            combobox=True,
             required=True,
         ),
         MessageTextInput(
             name="base_url",
             display_name="Ollama Base URL",
-            value="http://localhost:11434",
+            value="",
             required=True,
         ),
     ]
@@ -33,8 +44,63 @@ class OllamaEmbeddingsComponent(LCModelComponent):

     def build_embeddings(self) -> Embeddings:
         try:
-            output = OllamaEmbeddings(model=self.model, base_url=self.base_url)
+            output = OllamaEmbeddings(model=self.model_name, base_url=self.base_url)
         except Exception as e:
-            msg = "Could not connect to Ollama API."
+            msg = (
+                "Unable to connect to the Ollama API. "
+                "Please verify the base URL, ensure the relevant Ollama model is pulled, and try again."
+            )
             raise ValueError(msg) from e
         return output

+    async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
+        if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(field_value):
+            # Fall back to the first reachable URL in the known-defaults list.
+            valid_url = ""
+            for url in URL_LIST:
+                if await self.is_valid_ollama_url(url):
+                    valid_url = url
+                    break
+            build_config["base_url"]["value"] = valid_url
+        if field_name in {"model_name", "base_url", "tool_model_enabled"}:
+            if await self.is_valid_ollama_url(self.base_url):
+                build_config["model_name"]["options"] = await self.get_model(self.base_url)
+            elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")):
+                build_config["model_name"]["options"] = await self.get_model(build_config["base_url"].get("value", ""))
+            else:
+                build_config["model_name"]["options"] = []
+
+        return build_config
+
+    async def get_model(self, base_url_value: str) -> list[str]:
+        """Get the model names from Ollama."""
+        model_ids = []
+        try:
+            url = urljoin(base_url_value, "/api/tags")
+            async with httpx.AsyncClient() as client:
+                response = await client.get(url)
+                response.raise_for_status()
+                data = response.json()
+
+                model_ids = [model["name"] for model in data.get("models", [])]
+                # Keep only embedding models; matching by name prefix also catches
+                # size variants (1b, 2b, ...) and tag variants such as ":latest".
+                model_ids = [
+                    model
+                    for model in model_ids
+                    if any(model.startswith(f"{embedding_model}") for embedding_model in OLLAMA_EMBEDDING_MODELS)
+                ]
+
+        except (ImportError, ValueError, httpx.RequestError) as e:
+            msg = "Could not get model names from Ollama."
+            raise ValueError(msg) from e
+
+        return model_ids
+
+    async def is_valid_ollama_url(self, url: str) -> bool:
+        try:
+            async with httpx.AsyncClient() as client:
+                return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
+        except httpx.RequestError:
+            return False
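A standalone sketch of the fallback flow added in update_build_config above: when the configured URL fails validation, the component probes the well-known defaults and keeps the first one that answers. The helper name first_reachable is hypothetical, used here only for illustration:

    import asyncio

    import httpx

    HTTP_STATUS_OK = 200
    URL_LIST = [  # mirrors ollama_constants.URL_LIST
        "http://localhost:11434",
        "http://host.docker.internal:11434",
        "http://127.0.0.1:11434",
        "http://0.0.0.0:11434",
    ]

    async def first_reachable(urls: list[str]) -> str:
        """Return the first base URL whose /api/tags endpoint answers 200, else an empty string."""
        for url in urls:
            try:
                async with httpx.AsyncClient() as client:
                    if (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK:
                        return url
            except httpx.RequestError:
                continue
        return ""

    print(asyncio.run(first_reachable(URL_LIST)))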
src/backend/base/langflow/components/models/ollama.py

@@ -5,8 +5,12 @@ import httpx
 from langchain_ollama import ChatOllama

 from langflow.base.models.model import LCModelComponent
+from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, OLLAMA_TOOL_MODELS_BASE, URL_LIST
 from langflow.field_typing import LanguageModel
-from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, StrInput
+from langflow.field_typing.range_spec import RangeSpec
+from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput
+
+HTTP_STATUS_OK = 200


 class ChatOllamaComponent(LCModelComponent):
@@ -15,81 +19,25 @@ class ChatOllamaComponent(LCModelComponent):
     icon = "Ollama"
     name = "OllamaModel"

-    async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
-        if field_name == "mirostat":
-            if field_value == "Disabled":
-                build_config["mirostat_eta"]["advanced"] = True
-                build_config["mirostat_tau"]["advanced"] = True
-                build_config["mirostat_eta"]["value"] = None
-                build_config["mirostat_tau"]["value"] = None
-
-            else:
-                build_config["mirostat_eta"]["advanced"] = False
-                build_config["mirostat_tau"]["advanced"] = False
-
-                if field_value == "Mirostat 2.0":
-                    build_config["mirostat_eta"]["value"] = 0.2
-                    build_config["mirostat_tau"]["value"] = 10
-                else:
-                    build_config["mirostat_eta"]["value"] = 0.1
-                    build_config["mirostat_tau"]["value"] = 5
-
-        if field_name == "model_name":
-            base_url_dict = build_config.get("base_url", {})
-            base_url_load_from_db = base_url_dict.get("load_from_db", False)
-            base_url_value = base_url_dict.get("value")
-            if base_url_load_from_db:
-                base_url_value = await self.get_variables(base_url_value, field_name)
-            elif not base_url_value:
-                base_url_value = "http://localhost:11434"
-            build_config["model_name"]["options"] = await self.get_model(base_url_value)
-        if field_name == "keep_alive_flag":
-            if field_value == "Keep":
-                build_config["keep_alive"]["value"] = "-1"
-                build_config["keep_alive"]["advanced"] = True
-            elif field_value == "Immediately":
-                build_config["keep_alive"]["value"] = "0"
-                build_config["keep_alive"]["advanced"] = True
-            else:
-                build_config["keep_alive"]["advanced"] = False
-
-        return build_config
-
-    @staticmethod
-    async def get_model(base_url_value: str) -> list[str]:
-        try:
-            url = urljoin(base_url_value, "/api/tags")
-            async with httpx.AsyncClient() as client:
-                response = await client.get(url)
-                response.raise_for_status()
-                data = response.json()
-
-                return [model["name"] for model in data.get("models", [])]
-        except Exception as e:
-            msg = "Could not retrieve models. Please, make sure Ollama is running."
-            raise ValueError(msg) from e
-
     inputs = [
-        StrInput(
+        MessageTextInput(
             name="base_url",
             display_name="Base URL",
             info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
-            value="http://localhost:11434",
+            value="",
         ),
         DropdownInput(
             name="model_name",
             display_name="Model Name",
-            value="llama3.1",
+            options=[],
             info="Refer to https://ollama.com/library for more models.",
             refresh_button=True,
+            real_time_refresh=True,
         ),
-        FloatInput(
-            name="temperature",
-            display_name="Temperature",
-            value=0.2,
-            info="Controls the creativity of model responses.",
+        SliderInput(
+            name="temperature", display_name="Temperature", value=0.1, range_spec=RangeSpec(min=0, max=1, step=0.01)
         ),
-        StrInput(
+        MessageTextInput(
             name="format", display_name="Format", info="Specify the format of the output (e.g., json).", advanced=True
         ),
         DictInput(name="metadata", display_name="Metadata", info="Metadata to add to the run trace.", advanced=True),
@@ -151,20 +99,31 @@ class ChatOllamaComponent(LCModelComponent):
         ),
         FloatInput(name="top_p", display_name="Top P", info="Works together with top-k. (Default: 0.9)", advanced=True),
         BoolInput(name="verbose", display_name="Verbose", info="Whether to print out response text.", advanced=True),
-        StrInput(
+        MessageTextInput(
             name="tags",
             display_name="Tags",
             info="Comma-separated list of tags to add to the run trace.",
             advanced=True,
         ),
-        StrInput(
+        MessageTextInput(
             name="stop_tokens",
             display_name="Stop Tokens",
             info="Comma-separated list of tokens to signal the model to stop generating text.",
             advanced=True,
         ),
-        StrInput(name="system", display_name="System", info="System to use for generating text.", advanced=True),
-        StrInput(name="template", display_name="Template", info="Template to use for generating text.", advanced=True),
+        MessageTextInput(
+            name="system", display_name="System", info="System to use for generating text.", advanced=True
+        ),
+        MessageTextInput(
+            name="template", display_name="Template", info="Template to use for generating text.", advanced=True
+        ),
+        BoolInput(
+            name="tool_model_enabled",
+            display_name="Tool Model Enabled",
+            info="Whether to enable tool calling in the model.",
+            value=True,
+            real_time_refresh=True,
+        ),
         *LCModelComponent._base_inputs,
     ]
@@ -215,7 +174,99 @@ class ChatOllamaComponent(LCModelComponent):
         try:
             output = ChatOllama(**llm_params)
         except Exception as e:
-            msg = "Could not initialize Ollama LLM."
+            msg = (
+                "Unable to connect to the Ollama API. "
+                "Please verify the base URL, ensure the relevant Ollama model is pulled, and try again."
+            )
             raise ValueError(msg) from e

         return output

+    async def is_valid_ollama_url(self, url: str) -> bool:
+        try:
+            async with httpx.AsyncClient() as client:
+                return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
+        except httpx.RequestError:
+            return False
+
+    async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
+        if field_name == "mirostat":
+            if field_value == "Disabled":
+                build_config["mirostat_eta"]["advanced"] = True
+                build_config["mirostat_tau"]["advanced"] = True
+                build_config["mirostat_eta"]["value"] = None
+                build_config["mirostat_tau"]["value"] = None
+
+            else:
+                build_config["mirostat_eta"]["advanced"] = False
+                build_config["mirostat_tau"]["advanced"] = False
+
+                if field_value == "Mirostat 2.0":
+                    build_config["mirostat_eta"]["value"] = 0.2
+                    build_config["mirostat_tau"]["value"] = 10
+                else:
+                    build_config["mirostat_eta"]["value"] = 0.1
+                    build_config["mirostat_tau"]["value"] = 5
+
+        if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(field_value):
+            # Fall back to the first reachable URL in the known-defaults list.
+            valid_url = ""
+            for url in URL_LIST:
+                if await self.is_valid_ollama_url(url):
+                    valid_url = url
+                    break
+            build_config["base_url"]["value"] = valid_url
+        if field_name in {"model_name", "base_url", "tool_model_enabled"}:
+            if await self.is_valid_ollama_url(self.base_url):
+                tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
+                build_config["model_name"]["options"] = await self.get_model(self.base_url, tool_model_enabled)
+            elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")):
+                tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
+                build_config["model_name"]["options"] = await self.get_model(
+                    build_config["base_url"].get("value", ""), tool_model_enabled
+                )
+            else:
+                build_config["model_name"]["options"] = []
+        if field_name == "keep_alive_flag":
+            if field_value == "Keep":
+                build_config["keep_alive"]["value"] = "-1"
+                build_config["keep_alive"]["advanced"] = True
+            elif field_value == "Immediately":
+                build_config["keep_alive"]["value"] = "0"
+                build_config["keep_alive"]["advanced"] = True
+            else:
+                build_config["keep_alive"]["advanced"] = False
+
+        return build_config
+
+    async def get_model(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]:
+        try:
+            url = urljoin(base_url_value, "/api/tags")
+            async with httpx.AsyncClient() as client:
+                response = await client.get(url)
+                response.raise_for_status()
+                data = response.json()
+
+                model_ids = [model["name"] for model in data.get("models", [])]
+                # Exclude embedding models entirely, matching on the base name as
+                # well, since models come in size variants (1b, 2b, ...) and tag
+                # variants such as ":latest".
+                model_ids = [
+                    model
+                    for model in model_ids
+                    if not any(
+                        model == embedding_model or model.startswith(embedding_model.split("-")[0])
+                        for embedding_model in OLLAMA_EMBEDDING_MODELS
+                    )
+                ]
+
+        except (ImportError, ValueError, httpx.RequestError, Exception) as e:
+            msg = "Could not get model names from Ollama."
+            raise ValueError(msg) from e
+        return (
+            model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)]
+        )
+
+    def supports_tool_calling(self, model: str) -> bool:
+        """Check whether the model name starts with a known tool-calling base model, e.g. 'llama3.3:70b' matches 'llama3.3'."""
+        return any(model.startswith(f"{tool_model}") for tool_model in OLLAMA_TOOL_MODELS_BASE)
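Note that the chat component's exclusion rule is stricter than the embeddings component's prefix test: splitting each embedding name on "-" also drops bare base-name variants. A small sketch with made-up model names:

    OLLAMA_EMBEDDING_MODELS = ["nomic-embed-text", "mxbai-embed-large"]  # abbreviated

    model_ids = ["llama3.1:8b", "nomic-embed-text:latest", "nomic:2b", "mxbai-embed-large"]
    chat_models = [
        m
        for m in model_ids
        if not any(m == e or m.startswith(e.split("-")[0]) for e in OLLAMA_EMBEDDING_MODELS)
    ]
    # "nomic" is the base of "nomic-embed-text", so "nomic:2b" is dropped as well.
    assert chat_models == ["llama3.1:8b"]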
Unit tests for ChatOllamaComponent

@@ -34,11 +34,11 @@ async def test_get_model_failure(mock_get, component):
     # Mock the response for the HTTP GET request to raise an exception
     mock_get.side_effect = Exception("HTTP request failed")

-    url = "http://localhost:11434/api/tags"
+    url = "http://localhost:11434/"

     # Assert that the ValueError is raised when an exception occurs
-    with pytest.raises(ValueError, match="Could not retrieve models"):
-        await component.get_model(url)
+    with pytest.raises(ValueError, match="Could not get model names from Ollama."):
+        await component.get_model(base_url_value=url)


 async def test_update_build_config_mirostat_disabled(component):
@@ -90,7 +90,7 @@ async def test_update_build_config_model_name(mock_get, component):

     updated_config = await component.update_build_config(build_config, field_value, field_name)

-    assert updated_config["model_name"]["options"] == ["model1", "model2"]
+    assert updated_config["model_name"]["options"] == []


 async def test_update_build_config_keep_alive(component):