diff --git a/src/backend/base/langflow/components/models/OllamaModel.py b/src/backend/base/langflow/components/models/OllamaModel.py index 5ef47c8ef..1fd9abcc7 100644 --- a/src/backend/base/langflow/components/models/OllamaModel.py +++ b/src/backend/base/langflow/components/models/OllamaModel.py @@ -1,9 +1,13 @@ -from langchain_community.chat_models import ChatOllama -from langchain_core.language_models.chat_models import BaseChatModel +from typing import Any +import httpx +from langchain_community.chat_models import ChatOllama + +from langflow.base.constants import STREAM_INFO_TEXT from langflow.base.models.model import LCModelComponent -from langflow.field_typing import LanguageModel, Text -from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, Output, StrInput +from langflow.field_typing import LanguageModel +from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageInput, Output, StrInput +from langflow.schema.message import Message class ChatOllamaComponent(LCModelComponent): @@ -11,24 +15,77 @@ class ChatOllamaComponent(LCModelComponent): description = "Generate text using Ollama Local LLMs." icon = "Ollama" + def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + if field_name == "mirostat": + if field_value == "Disabled": + build_config["mirostat_eta"]["advanced"] = True + build_config["mirostat_tau"]["advanced"] = True + build_config["mirostat_eta"]["value"] = None + build_config["mirostat_tau"]["value"] = None + + else: + build_config["mirostat_eta"]["advanced"] = False + build_config["mirostat_tau"]["advanced"] = False + + if field_value == "Mirostat 2.0": + build_config["mirostat_eta"]["value"] = 0.2 + build_config["mirostat_tau"]["value"] = 10 + else: + build_config["mirostat_eta"]["value"] = 0.1 + build_config["mirostat_tau"]["value"] = 5 + + if field_name == "model": + base_url_dict = build_config.get("base_url", {}) + base_url_load_from_db = base_url_dict.get("load_from_db", False) + base_url_value = base_url_dict.get("value") + if base_url_load_from_db: + base_url_value = self.variables(base_url_value) + elif not base_url_value: + base_url_value = "http://localhost:11434" + build_config["model"]["options"] = self.get_model(base_url_value + "/api/tags") + + if field_name == "keep_alive_flag": + if field_value == "Keep": + build_config["keep_alive"]["value"] = "-1" + build_config["keep_alive"]["advanced"] = True + elif field_value == "Immediately": + build_config["keep_alive"]["value"] = "0" + build_config["keep_alive"]["advanced"] = True + else: + build_config["keep_alive"]["advanced"] = False + + return build_config + + def get_model(self, url: str) -> list[str]: + try: + with httpx.Client() as client: + response = client.get(url) + response.raise_for_status() + data = response.json() + + model_names = [model["name"] for model in data.get("models", [])] + return model_names + except Exception as e: + raise ValueError("Could not retrieve models. Please, make sure Ollama is running.") from e + inputs = [ StrInput( name="base_url", display_name="Base URL", info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.", - advanced=True, value="http://localhost:11434", ), - StrInput( + DropdownInput( name="model", display_name="Model Name", value="llama2", info="Refer to https://ollama.ai/library for more models.", + refresh_button=True, ), FloatInput( name="temperature", display_name="Temperature", - value=0.8, + value=0.2, info="Controls the creativity of model responses.", ), StrInput( @@ -146,10 +203,9 @@ class ChatOllamaComponent(LCModelComponent): info="Template to use for generating text.", advanced=True, ), - StrInput( + MessageInput( name="input_value", display_name="Input", - input_types=["Text", "Data", "Prompt"], ), BoolInput( name="stream", @@ -168,7 +224,7 @@ class ChatOllamaComponent(LCModelComponent): Output(display_name="Language Model", name="model_output", method="build_model"), ] - def text_response(self) -> Text: + def text_response(self) -> Message: input_value = self.input_value stream = self.stream system_message = self.system_message @@ -177,7 +233,7 @@ class ChatOllamaComponent(LCModelComponent): self.status = result return result - def build_model(self) -> LanguageModel | BaseChatModel: + def build_model(self) -> LanguageModel: # Mapping mirostat settings to their corresponding values mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}