diff --git a/src/backend/base/langflow/base/models/ollama_constants.py b/src/backend/base/langflow/base/models/ollama_constants.py
new file mode 100644
index 000000000..8620c7dfe
--- /dev/null
+++ b/src/backend/base/langflow/base/models/ollama_constants.py
@@ -0,0 +1,47 @@
+# https://ollama.com/search?c=embedding
+OLLAMA_EMBEDDING_MODELS = [
+    "nomic-embed-text",
+    "mxbai-embed-large",
+    "snowflake-arctic-embed",
+    "all-minilm",
+    "bge-m3",
+    "paraphrase-multilingual",
+    "granite-embedding",
+    "jina-embeddings-v2-base-en",
+]
+# https://ollama.com/search?c=tools
+OLLAMA_TOOL_MODELS_BASE = [
+    "llama3.3",
+    "qwq",
+    "llama3.2",
+    "llama3.1",
+    "mistral",
+    "qwen2",
+    "qwen2.5",
+    "qwen2.5-coder",
+    "mistral-nemo",
+    "mixtral",
+    "command-r",
+    "command-r-plus",
+    "mistral-large",
+    "smollm2",
+    "hermes3",
+    "athene-v2",
+    "mistral-small",
+    "nemotron-mini",
+    "nemotron",
+    "llama3-groq-tool-use",
+    "granite3-dense",
+    "granite3.1-dense",
+    "aya-expanse",
+    "granite3-moe",
+    "firefunction-v2",
+]
+
+
+URL_LIST = [
+    "http://localhost:11434",
+    "http://host.docker.internal:11434",
+    "http://127.0.0.1:11434",
+    "http://0.0.0.0:11434",
+]
diff --git a/src/backend/base/langflow/components/embeddings/ollama.py b/src/backend/base/langflow/components/embeddings/ollama.py
index f3e9e9051..0ffd239a5 100644
--- a/src/backend/base/langflow/components/embeddings/ollama.py
+++ b/src/backend/base/langflow/components/embeddings/ollama.py
@@ -1,8 +1,15 @@
+from typing import Any
+from urllib.parse import urljoin
+
+import httpx
 from langchain_ollama import OllamaEmbeddings
 
 from langflow.base.models.model import LCModelComponent
+from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, URL_LIST
 from langflow.field_typing import Embeddings
-from langflow.io import MessageTextInput, Output
+from langflow.io import DropdownInput, MessageTextInput, Output
+
+HTTP_STATUS_OK = 200
 
 
 class OllamaEmbeddingsComponent(LCModelComponent):
@@ -13,16 +20,20 @@ class OllamaEmbeddingsComponent(LCModelComponent):
     name = "OllamaEmbeddings"
 
     inputs = [
-        MessageTextInput(
-            name="model",
+        DropdownInput(
+            name="model_name",
             display_name="Ollama Model",
-            value="nomic-embed-text",
+            value="",
+            options=[],
+            real_time_refresh=True,
+            refresh_button=True,
+            combobox=True,
             required=True,
         ),
         MessageTextInput(
             name="base_url",
             display_name="Ollama Base URL",
-            value="http://localhost:11434",
+            value="",
             required=True,
         ),
     ]
@@ -33,8 +44,63 @@ class OllamaEmbeddingsComponent(LCModelComponent):
 
     def build_embeddings(self) -> Embeddings:
         try:
-            output = OllamaEmbeddings(model=self.model, base_url=self.base_url)
+            output = OllamaEmbeddings(model=self.model_name, base_url=self.base_url)
         except Exception as e:
-            msg = "Could not connect to Ollama API."
+            msg = (
+                "Unable to connect to the Ollama API. "
", + "Please verify the base URL, ensure the relevant Ollama model is pulled, and try again.", + ) raise ValueError(msg) from e return output + + async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(field_value): + # Check if any URL in the list is valid + valid_url = "" + for url in URL_LIST: + if await self.is_valid_ollama_url(url): + valid_url = url + break + build_config["base_url"]["value"] = valid_url + if field_name in {"model_name", "base_url", "tool_model_enabled"}: + if await self.is_valid_ollama_url(self.base_url): + build_config["model_name"]["options"] = await self.get_model(self.base_url) + elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")): + build_config["model_name"]["options"] = await self.get_model(build_config["base_url"].get("value", "")) + else: + build_config["model_name"]["options"] = [] + + return build_config + + async def get_model(self, base_url_value: str) -> list[str]: + """Get the model names from Ollama.""" + model_ids = [] + try: + url = urljoin(base_url_value, "/api/tags") + async with httpx.AsyncClient() as client: + response = await client.get(url) + response.raise_for_status() + data = response.json() + + model_ids = [model["name"] for model in data.get("models", [])] + # this to ensure that not embedding models are included. + # not even the base models since models can have 1b 2b etc + # handles cases when embeddings models have tags like :latest - etc. + model_ids = [ + model + for model in model_ids + if any(model.startswith(f"{embedding_model}") for embedding_model in OLLAMA_EMBEDDING_MODELS) + ] + + except (ImportError, ValueError, httpx.RequestError) as e: + msg = "Could not get model names from Ollama." 
+            raise ValueError(msg) from e
+
+        return model_ids
+
+    async def is_valid_ollama_url(self, url: str) -> bool:
+        try:
+            async with httpx.AsyncClient() as client:
+                return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
+        except httpx.RequestError:
+            return False
diff --git a/src/backend/base/langflow/components/models/ollama.py b/src/backend/base/langflow/components/models/ollama.py
index 11c79cb02..47b5c3533 100644
--- a/src/backend/base/langflow/components/models/ollama.py
+++ b/src/backend/base/langflow/components/models/ollama.py
@@ -5,8 +5,12 @@ import httpx
 from langchain_ollama import ChatOllama
 
 from langflow.base.models.model import LCModelComponent
+from langflow.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS, OLLAMA_TOOL_MODELS_BASE, URL_LIST
 from langflow.field_typing import LanguageModel
-from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, StrInput
+from langflow.field_typing.range_spec import RangeSpec
+from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageTextInput, SliderInput
+
+HTTP_STATUS_OK = 200
 
 
 class ChatOllamaComponent(LCModelComponent):
@@ -15,81 +19,25 @@ class ChatOllamaComponent(LCModelComponent):
     icon = "Ollama"
     name = "OllamaModel"
 
-    async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
-        if field_name == "mirostat":
-            if field_value == "Disabled":
-                build_config["mirostat_eta"]["advanced"] = True
-                build_config["mirostat_tau"]["advanced"] = True
-                build_config["mirostat_eta"]["value"] = None
-                build_config["mirostat_tau"]["value"] = None
-
-            else:
-                build_config["mirostat_eta"]["advanced"] = False
-                build_config["mirostat_tau"]["advanced"] = False
-
-                if field_value == "Mirostat 2.0":
-                    build_config["mirostat_eta"]["value"] = 0.2
-                    build_config["mirostat_tau"]["value"] = 10
-                else:
-                    build_config["mirostat_eta"]["value"] = 0.1
-                    build_config["mirostat_tau"]["value"] = 5
-
-        if field_name == "model_name":
-            base_url_dict = build_config.get("base_url", {})
-            base_url_load_from_db = base_url_dict.get("load_from_db", False)
-            base_url_value = base_url_dict.get("value")
-            if base_url_load_from_db:
-                base_url_value = await self.get_variables(base_url_value, field_name)
-            elif not base_url_value:
-                base_url_value = "http://localhost:11434"
-            build_config["model_name"]["options"] = await self.get_model(base_url_value)
-        if field_name == "keep_alive_flag":
-            if field_value == "Keep":
-                build_config["keep_alive"]["value"] = "-1"
-                build_config["keep_alive"]["advanced"] = True
-            elif field_value == "Immediately":
-                build_config["keep_alive"]["value"] = "0"
-                build_config["keep_alive"]["advanced"] = True
-            else:
-                build_config["keep_alive"]["advanced"] = False
-
-        return build_config
-
-    @staticmethod
-    async def get_model(base_url_value: str) -> list[str]:
-        try:
-            url = urljoin(base_url_value, "/api/tags")
-            async with httpx.AsyncClient() as client:
-                response = await client.get(url)
-                response.raise_for_status()
-                data = response.json()
-
-                return [model["name"] for model in data.get("models", [])]
-        except Exception as e:
-            msg = "Could not retrieve models. Please, make sure Ollama is running."
-            raise ValueError(msg) from e
-
     inputs = [
-        StrInput(
+        MessageTextInput(
             name="base_url",
             display_name="Base URL",
             info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
-            value="http://localhost:11434",
+            value="",
         ),
         DropdownInput(
             name="model_name",
             display_name="Model Name",
-            value="llama3.1",
+            options=[],
             info="Refer to https://ollama.com/library for more models.",
             refresh_button=True,
+            real_time_refresh=True,
         ),
-        FloatInput(
-            name="temperature",
-            display_name="Temperature",
-            value=0.2,
-            info="Controls the creativity of model responses.",
+        SliderInput(
+            name="temperature", display_name="Temperature", value=0.1, range_spec=RangeSpec(min=0, max=1, step=0.01)
         ),
-        StrInput(
+        MessageTextInput(
             name="format", display_name="Format", info="Specify the format of the output (e.g., json).", advanced=True
         ),
         DictInput(name="metadata", display_name="Metadata", info="Metadata to add to the run trace.", advanced=True),
@@ -151,20 +99,31 @@ class ChatOllamaComponent(LCModelComponent):
         ),
         FloatInput(name="top_p", display_name="Top P", info="Works together with top-k. (Default: 0.9)", advanced=True),
         BoolInput(name="verbose", display_name="Verbose", info="Whether to print out response text.", advanced=True),
-        StrInput(
+        MessageTextInput(
             name="tags",
             display_name="Tags",
             info="Comma-separated list of tags to add to the run trace.",
             advanced=True,
         ),
-        StrInput(
+        MessageTextInput(
             name="stop_tokens",
             display_name="Stop Tokens",
             info="Comma-separated list of tokens to signal the model to stop generating text.",
             advanced=True,
         ),
-        StrInput(name="system", display_name="System", info="System to use for generating text.", advanced=True),
-        StrInput(name="template", display_name="Template", info="Template to use for generating text.", advanced=True),
+        MessageTextInput(
+            name="system", display_name="System", info="System to use for generating text.", advanced=True
+        ),
+        MessageTextInput(
+            name="template", display_name="Template", info="Template to use for generating text.", advanced=True
+        ),
+        BoolInput(
+            name="tool_model_enabled",
+            display_name="Tool Model Enabled",
+            info="Whether to enable tool calling in the model.",
+            value=True,
+            real_time_refresh=True,
+        ),
         *LCModelComponent._base_inputs,
     ]
 
@@ -215,7 +174,99 @@ class ChatOllamaComponent(LCModelComponent):
         try:
             output = ChatOllama(**llm_params)
         except Exception as e:
-            msg = "Could not initialize Ollama LLM."
+            msg = (
+                "Unable to connect to the Ollama API. "
", + "Please verify the base URL, ensure the relevant Ollama model is pulled, and try again.", + ) raise ValueError(msg) from e return output + + async def is_valid_ollama_url(self, url: str) -> bool: + try: + async with httpx.AsyncClient() as client: + return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK + except httpx.RequestError: + return False + + async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + if field_name == "mirostat": + if field_value == "Disabled": + build_config["mirostat_eta"]["advanced"] = True + build_config["mirostat_tau"]["advanced"] = True + build_config["mirostat_eta"]["value"] = None + build_config["mirostat_tau"]["value"] = None + + else: + build_config["mirostat_eta"]["advanced"] = False + build_config["mirostat_tau"]["advanced"] = False + + if field_value == "Mirostat 2.0": + build_config["mirostat_eta"]["value"] = 0.2 + build_config["mirostat_tau"]["value"] = 10 + else: + build_config["mirostat_eta"]["value"] = 0.1 + build_config["mirostat_tau"]["value"] = 5 + + if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(field_value): + # Check if any URL in the list is valid + valid_url = "" + for url in URL_LIST: + if await self.is_valid_ollama_url(url): + valid_url = url + break + build_config["base_url"]["value"] = valid_url + if field_name in {"model_name", "base_url", "tool_model_enabled"}: + if await self.is_valid_ollama_url(self.base_url): + tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled + build_config["model_name"]["options"] = await self.get_model(self.base_url, tool_model_enabled) + elif await self.is_valid_ollama_url(build_config["base_url"].get("value", "")): + tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled + build_config["model_name"]["options"] = await self.get_model( + build_config["base_url"].get("value", ""), tool_model_enabled + ) + else: + build_config["model_name"]["options"] = [] + if field_name == "keep_alive_flag": + if field_value == "Keep": + build_config["keep_alive"]["value"] = "-1" + build_config["keep_alive"]["advanced"] = True + elif field_value == "Immediately": + build_config["keep_alive"]["value"] = "0" + build_config["keep_alive"]["advanced"] = True + else: + build_config["keep_alive"]["advanced"] = False + + return build_config + + async def get_model(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]: + try: + url = urljoin(base_url_value, "/api/tags") + async with httpx.AsyncClient() as client: + response = await client.get(url) + response.raise_for_status() + data = response.json() + + model_ids = [model["name"] for model in data.get("models", [])] + # this to ensure that not embedding models are included. + # not even the base models since models can have 1b 2b etc + # handles cases when embeddings models have tags like :latest - etc. + model_ids = [ + model + for model in model_ids + if not any( + model == embedding_model or model.startswith(embedding_model.split("-")[0]) + for embedding_model in OLLAMA_EMBEDDING_MODELS + ) + ] + + except (ImportError, ValueError, httpx.RequestError, Exception) as e: + msg = "Could not get model names from Ollama." 
+            raise ValueError(msg) from e
+        return (
+            model_ids if not tool_model_enabled else [model for model in model_ids if self.supports_tool_calling(model)]
+        )
+
+    def supports_tool_calling(self, model: str) -> bool:
+        """Check whether the model name starts with a known tool-calling base model.
+
+        Matching on the base name covers size variants, e.g. llama3.3 can have 1b and 2b.
+        """
+        return any(model.startswith(tool_model) for tool_model in OLLAMA_TOOL_MODELS_BASE)
diff --git a/src/backend/tests/unit/components/models/test_chatollama_component.py b/src/backend/tests/unit/components/models/test_chatollama_component.py
index a6e4440b9..ea74107d8 100644
--- a/src/backend/tests/unit/components/models/test_chatollama_component.py
+++ b/src/backend/tests/unit/components/models/test_chatollama_component.py
@@ -34,11 +34,11 @@ async def test_get_model_failure(mock_get, component):
     # Mock the response for the HTTP GET request to raise an exception
     mock_get.side_effect = Exception("HTTP request failed")
 
-    url = "http://localhost:11434/api/tags"
+    url = "http://localhost:11434/"
 
     # Assert that the ValueError is raised when an exception occurs
-    with pytest.raises(ValueError, match="Could not retrieve models"):
-        await component.get_model(url)
+    with pytest.raises(ValueError, match="Could not get model names from Ollama."):
+        await component.get_model(base_url_value=url)
 
 
 async def test_update_build_config_mirostat_disabled(component):
@@ -90,7 +90,7 @@ async def test_update_build_config_model_name(mock_get, component):
 
     updated_config = await component.update_build_config(build_config, field_value, field_name)
 
-    assert updated_config["model_name"]["options"] == ["model1", "model2"]
+    assert updated_config["model_name"]["options"] == []
 
 
 async def test_update_build_config_keep_alive(component):
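Note: every helper added in this patch discovers models through Ollama's /api/tags endpoint, which returns JSON of the form {"models": [{"name": "llama3.1:8b", ...}, ...]}. A minimal standalone sketch of the round trip get_model performs (the base URL and model names below are illustrative, not part of the patch):

    import asyncio
    from urllib.parse import urljoin

    import httpx

    async def list_ollama_models(base_url: str = "http://localhost:11434") -> list[str]:
        # GET /api/tags lists the locally pulled models; each entry's "name"
        # carries its tag, e.g. "nomic-embed-text:latest".
        url = urljoin(base_url, "/api/tags")
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            return [model["name"] for model in response.json().get("models", [])]

    # e.g. asyncio.run(list_ollama_models()) -> ["llama3.1:8b", "nomic-embed-text:latest"]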