From 72f09d65b64f0428ef0863de0ee97b16109c28d3 Mon Sep 17 00:00:00 2001 From: yamonbt Date: Wed, 29 May 2024 00:20:58 +0900 Subject: [PATCH] A better implementation of the Ollama component (#1701) * Update OllamaModel.py A draft to synchronize the model using the latest Langflow architecture and to improve it according to the latest Langchain specifications. * Update OllamaModel.py Checkout Models from api * Update OllamaModel.py * Update OllamaModel.py --------- Co-authored-by: Gabriel Luiz Freitas Almeida --- .../langflow/components/models/OllamaModel.py | 188 +++++++++++++----- 1 file changed, 138 insertions(+), 50 deletions(-) diff --git a/src/backend/base/langflow/components/models/OllamaModel.py b/src/backend/base/langflow/components/models/OllamaModel.py index d9ba48501..5806ad770 100644 --- a/src/backend/base/langflow/components/models/OllamaModel.py +++ b/src/backend/base/langflow/components/models/OllamaModel.py @@ -1,16 +1,21 @@ -from typing import Dict, List, Optional -# from langchain_community.chat_models import ChatOllama -from langchain_community.chat_models import ChatOllama +from typing import Any, Dict, List, Optional, Union + + +from langchain_community.chat_models.ollama import ChatOllama from langflow.base.constants import STREAM_INFO_TEXT from langflow.base.models.model import LCModelComponent +from langchain_core.caches import BaseCache -# from langchain.chat_models import ChatOllama from langflow.field_typing import Text -# whe When a callback component is added to Langflow, the comment must be uncommented. -# from langchain.callbacks.manager import CallbackManager + +import asyncio +import json + +import httpx + class ChatOllamaComponent(LCModelComponent): @@ -20,11 +25,19 @@ class ChatOllamaComponent(LCModelComponent): field_order = [ "base_url", + "headers", + + "keep_alive_flag", + "keep_alive", + + "metadata", "model", + + "temperature", "cache", - "callback_manager", - "callbacks", + + "format", "metadata", "mirostat", @@ -54,12 +67,41 @@ class ChatOllamaComponent(LCModelComponent): "base_url": { "display_name": "Base URL", "info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.", + + }, + + + "format": { + "display_name": "Format", + "info": "Specify the format of the output (e.g., json)", "advanced": True, }, + "headers": { + "display_name": "Headers", + "advanced": True, + + + }, + + "keep_alive_flag": { + "display_name": "Unload interval", + "options": ["Keep", "Immediately","Minute", "Hour", "sec" ], + "real_time_refresh": True, + "refresh_button": True, + }, + "keep_alive": { + "display_name": "interval", + "info": "How long the model will stay loaded into memory.", + }, + + + "model": { "display_name": "Model Name", - "value": "llama2", + "options": [], "info": "Refer to https://ollama.ai/library for more models.", + "real_time_refresh": True, + "refresh_button": True, }, "temperature": { "display_name": "Temperature", @@ -67,25 +109,8 @@ class ChatOllamaComponent(LCModelComponent): "value": 0.8, "info": "Controls the creativity of model responses.", }, - "cache": { - "display_name": "Cache", - "field_type": "bool", - "info": "Enable or disable caching.", - "advanced": True, - "value": False, - }, - ### When a callback component is added to Langflow, the comment must be uncommented. ### - # "callback_manager": { - # "display_name": "Callback Manager", - # "info": "Optional callback manager for additional functionality.", - # "advanced": True, - # }, - # "callbacks": { - # "display_name": "Callbacks", - # "info": "Callbacks to execute during model runtime.", - # "advanced": True, - # }, - ######################################################################################## + + "format": { "display_name": "Format", "field_type": "str", @@ -101,20 +126,24 @@ class ChatOllamaComponent(LCModelComponent): "display_name": "Mirostat", "options": ["Disabled", "Mirostat", "Mirostat 2.0"], "info": "Enable/disable Mirostat sampling for controlling perplexity.", - "value": "Disabled", - "advanced": True, + "advanced": False, + "real_time_refresh": True, + "refresh_button": True, + }, "mirostat_eta": { "display_name": "Mirostat Eta", "field_type": "float", "info": "Learning rate for Mirostat algorithm. (Default: 0.1)", "advanced": True, + "real_time_refresh": True, }, "mirostat_tau": { "display_name": "Mirostat Tau", "field_type": "float", "info": "Controls the balance between coherence and diversity of the output. (Default: 5.0)", "advanced": True, + "real_time_refresh": True, }, "num_ctx": { "display_name": "Context Window Size", @@ -211,21 +240,74 @@ class ChatOllamaComponent(LCModelComponent): }, } + def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + if field_name == "mirostat": + if field_value == "Disabled": + build_config["mirostat_eta"]["advanced"] = True + build_config["mirostat_tau"]["advanced"] = True + build_config["mirostat_eta"]["value"] = None + build_config["mirostat_tau"]["value"] = None + + else: + build_config["mirostat_eta"]["advanced"] = False + build_config["mirostat_tau"]["advanced"] = False + + if field_value == "Mirostat 2.0": + build_config["mirostat_eta"]["value"] = 0.2 + build_config["mirostat_tau"]["value"] = 10 + else: + build_config["mirostat_eta"]["value"] = 0.1 + build_config["mirostat_tau"]["value"] = 5 + + if field_name == "model": + base_url = build_config.get("base_url", {}).get( + "value", "http://localhost:11434") + build_config["model"]["options"] = self.get_model( + base_url + "/api/tags") + + if field_name == "keep_alive_flag": + if field_value == "Keep": + build_config["keep_alive"]["value"] = "-1" + build_config["keep_alive"]["advanced"] = True + elif field_value == "Immediately": + build_config["keep_alive"]["value"] = "0" + build_config["keep_alive"]["advanced"] = True + else: + build_config["keep_alive"]["advanced"] = False + + return build_config + + + + + def get_model(self, url: str) -> List[str]: + try: + with httpx.Client() as client: + response = client.get(url) + response.raise_for_status() + data = response.json() + + model_names = [model['name'] + for model in data.get("models", [])] + return model_names + except Exception as e: + raise ValueError("Could not retrieve models") from e + return [""] + def build( self, base_url: Optional[str], model: str, input_value: Text, - mirostat: Optional[str], + + mirostat: Optional[str], mirostat_eta: Optional[float] = None, mirostat_tau: Optional[float] = None, - ### When a callback component is added to Langflow, the comment must be uncommented.### - # callback_manager: Optional[CallbackManager] = None, - # callbacks: Optional[List[Callbacks]] = None, - ####################################################################################### + repeat_last_n: Optional[int] = None, verbose: Optional[bool] = None, - cache: Optional[bool] = None, + keep_alive: Optional[int] = None, + keep_alive_flag: Optional[str] = None, num_ctx: Optional[int] = None, num_gpu: Optional[int] = None, format: Optional[str] = None, @@ -244,33 +326,39 @@ class ChatOllamaComponent(LCModelComponent): stream: bool = False, system_message: Optional[str] = None, ) -> Text: + if not base_url: base_url = "http://localhost:11434" - # Mapping mirostat settings to their corresponding values - mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2} - # Default to 0 for 'Disabled' - mirostat_value = mirostat_options.get(mirostat, 0) # type: ignore - # Set mirostat_eta and mirostat_tau to None if mirostat is disabled - if mirostat_value == 0: - mirostat_eta = None - mirostat_tau = None + if keep_alive_flag == "Minute": + keep_alive_instance = f"{keep_alive}m" + elif keep_alive_flag == "Hour": + keep_alive_instance = f"{keep_alive}h" + elif keep_alive_flag == "sec": + keep_alive_instance = f"{keep_alive}s" + elif keep_alive_flag == "Keep": + keep_alive_instance = "-1" + elif keep_alive_flag == "Immediately": + keep_alive_instance = "0" + else: + keep_alive_instance = "Invalid option" + + mirostat_instance = 0 + + if mirostat == "disable": + mirostat_instance = 0 # Mapping system settings to their corresponding values llm_params = { "base_url": base_url, - "cache": cache, "model": model, - "mirostat": mirostat_value, + "mirostat": mirostat_instance, + "keep_alive": keep_alive_instance, "format": format, "metadata": metadata, "tags": tags, - ## When a callback component is added to Langflow, the comment must be uncommented.## - # "callback_manager": callback_manager, - # "callbacks": callbacks, - ##################################################################################### "mirostat_eta": mirostat_eta, "mirostat_tau": mirostat_tau, "num_ctx": num_ctx,