diff --git a/src/backend/base/langflow/components/models/OllamaModel.py b/src/backend/base/langflow/components/models/OllamaModel.py
index cca2a0f48..12db058c2 100644
--- a/src/backend/base/langflow/components/models/OllamaModel.py
+++ b/src/backend/base/langflow/components/models/OllamaModel.py
@@ -1,11 +1,12 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
 
 import httpx
 from langchain_community.chat_models.ollama import ChatOllama
 
 from langflow.base.constants import STREAM_INFO_TEXT
 from langflow.base.models.model import LCModelComponent
-from langflow.field_typing import Text
+from langflow.field_typing import BaseLanguageModel, Text
+from langflow.template.field.base import Input, Output
 
 
 class ChatOllamaComponent(LCModelComponent):
@@ -13,199 +14,6 @@ class ChatOllamaComponent(LCModelComponent):
     description = "Generate text using Ollama Local LLMs."
     icon = "Ollama"
 
-    field_order = [
-        "base_url",
-        "headers",
-        "keep_alive_flag",
-        "keep_alive",
-        "metadata",
-        "model",
-        "temperature",
-        "cache",
-        "format",
-        "metadata",
-        "mirostat",
-        "mirostat_eta",
-        "mirostat_tau",
-        "num_ctx",
-        "num_gpu",
-        "num_thread",
-        "repeat_last_n",
-        "repeat_penalty",
-        "tfs_z",
-        "timeout",
-        "top_k",
-        "top_p",
-        "verbose",
-        "tags",
-        "stop",
-        "system",
-        "template",
-        "input_value",
-        "system_message",
-        "stream",
-    ]
-
-    def build_config(self) -> dict:
-        return {
-            "base_url": {
-                "display_name": "Base URL",
-                "info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
-            },
-            "format": {
-                "display_name": "Format",
-                "info": "Specify the format of the output (e.g., json)",
-                "advanced": True,
-            },
-            "headers": {
-                "display_name": "Headers",
-                "advanced": True,
-            },
-            "keep_alive_flag": {
-                "display_name": "Unload interval",
-                "options": ["Keep", "Immediately", "Minute", "Hour", "sec"],
-                "real_time_refresh": True,
-                "refresh_button": True,
-            },
-            "keep_alive": {
-                "display_name": "interval",
-                "info": "How long the model will stay loaded into memory.",
-            },
-            "model": {
-                "display_name": "Model Name",
-                "options": [],
-                "info": "Refer to https://ollama.ai/library for more models.",
-                "real_time_refresh": True,
-                "refresh_button": True,
-            },
-            "temperature": {
-                "display_name": "Temperature",
-                "field_type": "float",
-                "value": 0.8,
-                "info": "Controls the creativity of model responses.",
-            },
-            "metadata": {
-                "display_name": "Metadata",
-                "info": "Metadata to add to the run trace.",
-                "advanced": True,
-            },
-            "mirostat": {
-                "display_name": "Mirostat",
-                "options": ["Disabled", "Mirostat", "Mirostat 2.0"],
-                "info": "Enable/disable Mirostat sampling for controlling perplexity.",
-                "advanced": False,
-                "real_time_refresh": True,
-                "refresh_button": True,
-            },
-            "mirostat_eta": {
-                "display_name": "Mirostat Eta",
-                "field_type": "float",
-                "info": "Learning rate for Mirostat algorithm. (Default: 0.1)",
-                "advanced": True,
-                "real_time_refresh": True,
-            },
-            "mirostat_tau": {
-                "display_name": "Mirostat Tau",
-                "field_type": "float",
-                "info": "Controls the balance between coherence and diversity of the output. (Default: 5.0)",
-                "advanced": True,
-                "real_time_refresh": True,
-            },
-            "num_ctx": {
-                "display_name": "Context Window Size",
-                "field_type": "int",
-                "info": "Size of the context window for generating tokens. (Default: 2048)",
-                "advanced": True,
-            },
-            "num_gpu": {
-                "display_name": "Number of GPUs",
-                "field_type": "int",
-                "info": "Number of GPUs to use for computation. (Default: 1 on macOS, 0 to disable)",
-                "advanced": True,
-            },
-            "num_thread": {
-                "display_name": "Number of Threads",
-                "field_type": "int",
-                "info": "Number of threads to use during computation. (Default: detected for optimal performance)",
-                "advanced": True,
-            },
-            "repeat_last_n": {
-                "display_name": "Repeat Last N",
-                "field_type": "int",
-                "info": "How far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
-                "advanced": True,
-            },
-            "repeat_penalty": {
-                "display_name": "Repeat Penalty",
-                "field_type": "float",
-                "info": "Penalty for repetitions in generated text. (Default: 1.1)",
-                "advanced": True,
-            },
-            "tfs_z": {
-                "display_name": "TFS Z",
-                "field_type": "float",
-                "info": "Tail free sampling value. (Default: 1)",
-                "advanced": True,
-            },
-            "timeout": {
-                "display_name": "Timeout",
-                "field_type": "int",
-                "info": "Timeout for the request stream.",
-                "advanced": True,
-            },
-            "top_k": {
-                "display_name": "Top K",
-                "field_type": "int",
-                "info": "Limits token selection to top K. (Default: 40)",
-                "advanced": True,
-            },
-            "top_p": {
-                "display_name": "Top P",
-                "field_type": "float",
-                "info": "Works together with top-k. (Default: 0.9)",
-                "advanced": True,
-            },
-            "verbose": {
-                "display_name": "Verbose",
-                "field_type": "bool",
-                "info": "Whether to print out response text.",
-            },
-            "tags": {
-                "display_name": "Tags",
-                "field_type": "list",
-                "info": "Tags to add to the run trace.",
-                "advanced": True,
-            },
-            "stop": {
-                "display_name": "Stop Tokens",
-                "field_type": "list",
-                "info": "List of tokens to signal the model to stop generating text.",
-                "advanced": True,
-            },
-            "system": {
-                "display_name": "System",
-                "field_type": "str",
-                "info": "System to use for generating text.",
-                "advanced": True,
-            },
-            "template": {
-                "display_name": "Template",
-                "field_type": "str",
-                "info": "Template to use for generating text.",
-                "advanced": True,
-            },
-            "input_value": {"display_name": "Input", "input_types": ["Text", "Record", "Prompt"]},
-            "stream": {
-                "display_name": "Stream",
-                "info": STREAM_INFO_TEXT,
-            },
-            "system_message": {
-                "display_name": "System Message",
-                "info": "System message to pass to the model.",
-                "advanced": True,
-            },
-        }
-
     def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
         if field_name == "mirostat":
             if field_value == "Disabled":
@@ -258,40 +66,134 @@ class ChatOllamaComponent(LCModelComponent):
                 return model_names
             except Exception as e:
                 raise ValueError("Could not retrieve models") from e
-        return [""]
 
-    def build(
-        self,
-        base_url: Optional[str],
-        model: str,
-        input_value: Text,
-        mirostat: Optional[str] = "Disabled",
-        mirostat_eta: Optional[float] = None,
-        mirostat_tau: Optional[float] = None,
-        repeat_last_n: Optional[int] = None,
-        verbose: Optional[bool] = None,
-        keep_alive: Optional[int] = None,
-        keep_alive_flag: Optional[str] = "Keep",
-        num_ctx: Optional[int] = None,
-        num_gpu: Optional[int] = None,
-        format: Optional[str] = None,
-        metadata: Optional[Dict] = None,
-        num_thread: Optional[int] = None,
-        repeat_penalty: Optional[float] = None,
-        stop: Optional[List[str]] = None,
-        system: Optional[str] = None,
-        tags: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        template: Optional[str] = None,
-        tfs_z: Optional[float] = None,
-        timeout: Optional[int] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[int] = None,
-        stream: bool = False,
-        system_message: Optional[str] = None,
-    ) -> Text:
-        if not base_url:
-            base_url = "http://localhost:11434"
+    inputs = [
+        Input(
+            name="base_url",
+            type=Optional[str],
+            display_name="Base URL",
+            info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
+            value="http://localhost:11434",
+        ),
+        Input(
+            name="model",
+            type=str,
+            display_name="Model Name",
+            options=[],  # This should be dynamically loaded if possible
+            info="Refer to https://ollama.ai/library for more models.",
+            real_time_refresh=True,
+            refresh_button=True,
+        ),
+        Input(
+            name="mirostat",
+            type=str,
+            display_name="Mirostat",
+            options=["Disabled", "Mirostat", "Mirostat 2.0"],
+            info="Enable/disable Mirostat sampling for controlling perplexity.",
+            advanced=False,
+            real_time_refresh=True,
+            refresh_button=True,
+            value="Disabled",
+        ),
+        Input(
+            name="mirostat_eta",
+            type=Optional[float],
+            display_name="Mirostat Eta",
+            info="Learning rate for Mirostat algorithm.",
+            advanced=True,
+            real_time_refresh=True,
+            value=None,  # Default can vary based on mirostat status
+        ),
+        Input(
+            name="mirostat_tau",
+            type=Optional[float],
+            display_name="Mirostat Tau",
+            info="Controls the balance between coherence and diversity of the output.",
+            advanced=True,
+            real_time_refresh=True,
+            value=None,  # Default can vary based on mirostat status
+        ),
+        Input(
+            name="temperature",
+            type=float,
+            display_name="Temperature",
+            info="Controls the creativity of model responses.",
+            value=0.8,
+        ),
+        Input(name="input_value", type=str, display_name="Input", input_types=["Text", "Record", "Prompt"]),
+        Input(name="stream", type=bool, display_name="Stream", info=STREAM_INFO_TEXT, value=False),
+        Input(
+            name="system_message",
+            type=Optional[str],
+            display_name="System Message",
+            info="System message to pass to the model.",
+            advanced=True,
+            value=None,
+        ),
+        Input(
+            name="headers",
+            type=dict,
+            display_name="Headers",
+            info="Additional headers to send with the request.",
+            advanced=True,
+        ),
+        Input(
+            name="keep_alive_flag",
+            type=str,
+            display_name="Unload interval",
+            options=["Keep", "Immediately", "Minute", "Hour", "sec"],
+            info="Controls how the model unload interval is managed.",
+            real_time_refresh=True,
+            refresh_button=True,
+        ),
+        Input(
+            name="keep_alive",
+            type=int,
+            display_name="Interval",
+            info="How long the model will stay loaded into memory.",
+            value=None,
+        ),
+    ]
+
+    outputs = [
+        Output(display_name="Text", name="text_output", method="text_response"),
+        Output(display_name="Language Model", name="model_output", method="model_response"),
+    ]
+
+    def text_response(self) -> Text:
+        input_value = self.input_value
+        stream = self.stream
+        system_message = self.system_message
+        output = self.model_response()
+        result = self.get_chat_result(output, stream, input_value, system_message)
+        self.status = result
+        return result
+
+    def model_response(self) -> BaseLanguageModel:
+        base_url = self.base_url or "http://localhost:11434"
+        model = self.model
+        mirostat = self.mirostat or "Disabled"
+        mirostat_eta = self.mirostat_eta
+        mirostat_tau = self.mirostat_tau
+        repeat_last_n = self.repeat_last_n
+        verbose = self.verbose
+        keep_alive = self.keep_alive
+        keep_alive_flag = self.keep_alive_flag or "Keep"
+        num_ctx = self.num_ctx
+        num_gpu = self.num_gpu
+        _format = self.format
+        metadata = self.metadata
+        num_thread = self.num_thread
+        repeat_penalty = self.repeat_penalty
+        stop = self.stop
+        system = self.system
+        tags = self.tags
+        temperature = self.temperature
+        template = self.template
+        tfs_z = self.tfs_z
+        timeout = self.timeout
+        top_k = self.top_k
+        top_p = self.top_p
+        headers = self.headers
 
         if keep_alive_flag == "Minute":
             keep_alive_instance = f"{keep_alive}m"
@@ -307,17 +209,15 @@ class ChatOllamaComponent(LCModelComponent):
             keep_alive_instance = "Invalid option"
 
         mirostat_instance = 0
-
        if mirostat == "disable":
             mirostat_instance = 0
-        # Mapping system settings to their corresponding values
 
         llm_params = {
             "base_url": base_url,
             "model": model,
             "mirostat": mirostat_instance,
             "keep_alive": keep_alive_instance,
-            "format": format,
+            "format": _format,
             "metadata": metadata,
             "tags": tags,
             "mirostat_eta": mirostat_eta,
@@ -336,14 +236,14 @@ class ChatOllamaComponent(LCModelComponent):
             "top_k": top_k,
             "top_p": top_p,
             "verbose": verbose,
+            "headers": headers,
         }
-        # None Value remove
+
        llm_params = {k: v for k, v in llm_params.items() if v is not None}
 
         try:
-            output = ChatOllama(**llm_params)  # type: ignore
+            output = ChatOllama(**llm_params)
         except Exception as e:
             raise ValueError("Could not initialize Ollama LLM.") from e
 
-        return self.get_chat_result(output, stream, input_value, system_message)
+        return output
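For review purposes, here is a minimal sketch of how the refactored component would be driven under the new declarative `inputs`/`outputs` pattern. It assumes `LCModelComponent` materializes each declared `Input` as an instance attribute and that the fields `model_response` reads but this diff does not declare (e.g. `top_k`, `stop`, `verbose`) default to `None` on the base class; neither detail is visible here, so treat the snippet as illustrative rather than authoritative.

```python
# Illustrative sketch only: attribute materialization is assumed to be
# handled by LCModelComponent, as the methods in this diff imply.
component = ChatOllamaComponent()
component.base_url = "http://localhost:11434"
component.model = "llama2"  # any model already pulled into local Ollama
component.temperature = 0.8
component.mirostat = "Disabled"
component.stream = False
component.system_message = None
component.input_value = "Why is the sky blue?"

# "Language Model" output: returns a configured ChatOllama instance,
# with all None-valued parameters stripped before construction.
llm = component.model_response()

# "Text" output: builds the same model, then runs input_value through
# get_chat_result() and returns the generated text.
print(component.text_response())
```

The split into `text_response` and `model_response` is what lets the component serve both roles: downstream nodes can consume either the generated text or the underlying `ChatOllama` object from a single component.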