A better implementation of the Ollama component (#1701)
* Update OllamaModel.py A draft to synchronize the model using the latest Langflow architecture and to improve it according to the latest Langchain specifications. * Update OllamaModel.py Checkout Models from api * Update OllamaModel.py * Update OllamaModel.py --------- Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
7d67f36000
commit
72f09d65b6
1 changed files with 138 additions and 50 deletions
|
|
@ -1,16 +1,21 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
# from langchain_community.chat_models import ChatOllama
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
|
||||
from langchain_community.chat_models.ollama import ChatOllama
|
||||
from langflow.base.constants import STREAM_INFO_TEXT
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langchain_core.caches import BaseCache
|
||||
|
||||
# from langchain.chat_models import ChatOllama
|
||||
from langflow.field_typing import Text
|
||||
|
||||
# whe When a callback component is added to Langflow, the comment must be uncommented.
|
||||
# from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
|
||||
class ChatOllamaComponent(LCModelComponent):
|
||||
|
|
@ -20,11 +25,19 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
|
||||
field_order = [
|
||||
"base_url",
|
||||
"headers",
|
||||
|
||||
"keep_alive_flag",
|
||||
"keep_alive",
|
||||
|
||||
"metadata",
|
||||
"model",
|
||||
|
||||
|
||||
"temperature",
|
||||
"cache",
|
||||
"callback_manager",
|
||||
"callbacks",
|
||||
|
||||
|
||||
"format",
|
||||
"metadata",
|
||||
"mirostat",
|
||||
|
|
@ -54,12 +67,41 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"base_url": {
|
||||
"display_name": "Base URL",
|
||||
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
|
||||
|
||||
},
|
||||
|
||||
|
||||
"format": {
|
||||
"display_name": "Format",
|
||||
"info": "Specify the format of the output (e.g., json)",
|
||||
"advanced": True,
|
||||
},
|
||||
"headers": {
|
||||
"display_name": "Headers",
|
||||
"advanced": True,
|
||||
|
||||
|
||||
},
|
||||
|
||||
"keep_alive_flag": {
|
||||
"display_name": "Unload interval",
|
||||
"options": ["Keep", "Immediately","Minute", "Hour", "sec" ],
|
||||
"real_time_refresh": True,
|
||||
"refresh_button": True,
|
||||
},
|
||||
"keep_alive": {
|
||||
"display_name": "interval",
|
||||
"info": "How long the model will stay loaded into memory.",
|
||||
},
|
||||
|
||||
|
||||
|
||||
"model": {
|
||||
"display_name": "Model Name",
|
||||
"value": "llama2",
|
||||
"options": [],
|
||||
"info": "Refer to https://ollama.ai/library for more models.",
|
||||
"real_time_refresh": True,
|
||||
"refresh_button": True,
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
|
|
@ -67,25 +109,8 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"value": 0.8,
|
||||
"info": "Controls the creativity of model responses.",
|
||||
},
|
||||
"cache": {
|
||||
"display_name": "Cache",
|
||||
"field_type": "bool",
|
||||
"info": "Enable or disable caching.",
|
||||
"advanced": True,
|
||||
"value": False,
|
||||
},
|
||||
### When a callback component is added to Langflow, the comment must be uncommented. ###
|
||||
# "callback_manager": {
|
||||
# "display_name": "Callback Manager",
|
||||
# "info": "Optional callback manager for additional functionality.",
|
||||
# "advanced": True,
|
||||
# },
|
||||
# "callbacks": {
|
||||
# "display_name": "Callbacks",
|
||||
# "info": "Callbacks to execute during model runtime.",
|
||||
# "advanced": True,
|
||||
# },
|
||||
########################################################################################
|
||||
|
||||
|
||||
"format": {
|
||||
"display_name": "Format",
|
||||
"field_type": "str",
|
||||
|
|
@ -101,20 +126,24 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"display_name": "Mirostat",
|
||||
"options": ["Disabled", "Mirostat", "Mirostat 2.0"],
|
||||
"info": "Enable/disable Mirostat sampling for controlling perplexity.",
|
||||
"value": "Disabled",
|
||||
"advanced": True,
|
||||
"advanced": False,
|
||||
"real_time_refresh": True,
|
||||
"refresh_button": True,
|
||||
|
||||
},
|
||||
"mirostat_eta": {
|
||||
"display_name": "Mirostat Eta",
|
||||
"field_type": "float",
|
||||
"info": "Learning rate for Mirostat algorithm. (Default: 0.1)",
|
||||
"advanced": True,
|
||||
"real_time_refresh": True,
|
||||
},
|
||||
"mirostat_tau": {
|
||||
"display_name": "Mirostat Tau",
|
||||
"field_type": "float",
|
||||
"info": "Controls the balance between coherence and diversity of the output. (Default: 5.0)",
|
||||
"advanced": True,
|
||||
"real_time_refresh": True,
|
||||
},
|
||||
"num_ctx": {
|
||||
"display_name": "Context Window Size",
|
||||
|
|
@ -211,21 +240,74 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
},
|
||||
}
|
||||
|
||||
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "mirostat":
|
||||
if field_value == "Disabled":
|
||||
build_config["mirostat_eta"]["advanced"] = True
|
||||
build_config["mirostat_tau"]["advanced"] = True
|
||||
build_config["mirostat_eta"]["value"] = None
|
||||
build_config["mirostat_tau"]["value"] = None
|
||||
|
||||
else:
|
||||
build_config["mirostat_eta"]["advanced"] = False
|
||||
build_config["mirostat_tau"]["advanced"] = False
|
||||
|
||||
if field_value == "Mirostat 2.0":
|
||||
build_config["mirostat_eta"]["value"] = 0.2
|
||||
build_config["mirostat_tau"]["value"] = 10
|
||||
else:
|
||||
build_config["mirostat_eta"]["value"] = 0.1
|
||||
build_config["mirostat_tau"]["value"] = 5
|
||||
|
||||
if field_name == "model":
|
||||
base_url = build_config.get("base_url", {}).get(
|
||||
"value", "http://localhost:11434")
|
||||
build_config["model"]["options"] = self.get_model(
|
||||
base_url + "/api/tags")
|
||||
|
||||
if field_name == "keep_alive_flag":
|
||||
if field_value == "Keep":
|
||||
build_config["keep_alive"]["value"] = "-1"
|
||||
build_config["keep_alive"]["advanced"] = True
|
||||
elif field_value == "Immediately":
|
||||
build_config["keep_alive"]["value"] = "0"
|
||||
build_config["keep_alive"]["advanced"] = True
|
||||
else:
|
||||
build_config["keep_alive"]["advanced"] = False
|
||||
|
||||
return build_config
|
||||
|
||||
|
||||
|
||||
|
||||
def get_model(self, url: str) -> List[str]:
|
||||
try:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
model_names = [model['name']
|
||||
for model in data.get("models", [])]
|
||||
return model_names
|
||||
except Exception as e:
|
||||
raise ValueError("Could not retrieve models") from e
|
||||
return [""]
|
||||
|
||||
def build(
|
||||
self,
|
||||
base_url: Optional[str],
|
||||
model: str,
|
||||
input_value: Text,
|
||||
mirostat: Optional[str],
|
||||
|
||||
mirostat: Optional[str],
|
||||
mirostat_eta: Optional[float] = None,
|
||||
mirostat_tau: Optional[float] = None,
|
||||
### When a callback component is added to Langflow, the comment must be uncommented.###
|
||||
# callback_manager: Optional[CallbackManager] = None,
|
||||
# callbacks: Optional[List[Callbacks]] = None,
|
||||
#######################################################################################
|
||||
|
||||
repeat_last_n: Optional[int] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
cache: Optional[bool] = None,
|
||||
keep_alive: Optional[int] = None,
|
||||
keep_alive_flag: Optional[str] = None,
|
||||
num_ctx: Optional[int] = None,
|
||||
num_gpu: Optional[int] = None,
|
||||
format: Optional[str] = None,
|
||||
|
|
@ -244,33 +326,39 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
stream: bool = False,
|
||||
system_message: Optional[str] = None,
|
||||
) -> Text:
|
||||
|
||||
if not base_url:
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
# Mapping mirostat settings to their corresponding values
|
||||
mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
|
||||
|
||||
# Default to 0 for 'Disabled'
|
||||
mirostat_value = mirostat_options.get(mirostat, 0) # type: ignore
|
||||
|
||||
# Set mirostat_eta and mirostat_tau to None if mirostat is disabled
|
||||
if mirostat_value == 0:
|
||||
mirostat_eta = None
|
||||
mirostat_tau = None
|
||||
if keep_alive_flag == "Minute":
|
||||
keep_alive_instance = f"{keep_alive}m"
|
||||
elif keep_alive_flag == "Hour":
|
||||
keep_alive_instance = f"{keep_alive}h"
|
||||
elif keep_alive_flag == "sec":
|
||||
keep_alive_instance = f"{keep_alive}s"
|
||||
elif keep_alive_flag == "Keep":
|
||||
keep_alive_instance = "-1"
|
||||
elif keep_alive_flag == "Immediately":
|
||||
keep_alive_instance = "0"
|
||||
else:
|
||||
keep_alive_instance = "Invalid option"
|
||||
|
||||
mirostat_instance = 0
|
||||
|
||||
if mirostat == "disable":
|
||||
mirostat_instance = 0
|
||||
|
||||
# Mapping system settings to their corresponding values
|
||||
llm_params = {
|
||||
"base_url": base_url,
|
||||
"cache": cache,
|
||||
"model": model,
|
||||
"mirostat": mirostat_value,
|
||||
"mirostat": mirostat_instance,
|
||||
"keep_alive": keep_alive_instance,
|
||||
"format": format,
|
||||
"metadata": metadata,
|
||||
"tags": tags,
|
||||
## When a callback component is added to Langflow, the comment must be uncommented.##
|
||||
# "callback_manager": callback_manager,
|
||||
# "callbacks": callbacks,
|
||||
#####################################################################################
|
||||
"mirostat_eta": mirostat_eta,
|
||||
"mirostat_tau": mirostat_tau,
|
||||
"num_ctx": num_ctx,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue