♻️ (OllamaModel.py): Refactor OllamaModel.py to new Component standard
parent 17065dd083
commit 9e6ede9fb0
1 changed file with 134 additions and 234 deletions
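For orientation, here is a minimal sketch of the declarative Component pattern this refactor adopts, inferred from the Input/Output usage in the diff below. The EchoComponent name and its fields are illustrative only, not part of this commit: inputs and outputs are declared as class attributes, each Output names the method that produces it, and input values are read back as self.<input_name>.

from typing import Optional

from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Text
from langflow.template.field.base import Input, Output


class EchoComponent(LCModelComponent):  # hypothetical example, not in this commit
    description = "Echoes its input with an optional prefix."

    inputs = [
        Input(name="input_value", type=str, display_name="Input"),
        Input(name="prefix", type=Optional[str], display_name="Prefix", advanced=True, value=None),
    ]
    outputs = [
        Output(display_name="Text", name="text_output", method="text_response"),
    ]

    def text_response(self) -> Text:
        # Declared inputs surface as attributes of the same name.
        result = f"{self.prefix or ''}{self.input_value}"
        self.status = result  # surfaces the value in the UI, as in the refactored component
        return result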
@@ -1,11 +1,12 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
 
 import httpx
 from langchain_community.chat_models.ollama import ChatOllama
 
 from langflow.base.constants import STREAM_INFO_TEXT
 from langflow.base.models.model import LCModelComponent
-from langflow.field_typing import Text
+from langflow.field_typing import BaseLanguageModel, Text
+from langflow.template.field.base import Input, Output
 
 
 class ChatOllamaComponent(LCModelComponent):
@@ -13,199 +14,6 @@ class ChatOllamaComponent(LCModelComponent):
     description = "Generate text using Ollama Local LLMs."
     icon = "Ollama"
 
-    field_order = [
-        "base_url",
-        "headers",
-        "keep_alive_flag",
-        "keep_alive",
-        "metadata",
-        "model",
-        "temperature",
-        "cache",
-        "format",
-        "metadata",
-        "mirostat",
-        "mirostat_eta",
-        "mirostat_tau",
-        "num_ctx",
-        "num_gpu",
-        "num_thread",
-        "repeat_last_n",
-        "repeat_penalty",
-        "tfs_z",
-        "timeout",
-        "top_k",
-        "top_p",
-        "verbose",
-        "tags",
-        "stop",
-        "system",
-        "template",
-        "input_value",
-        "system_message",
-        "stream",
-    ]
-
-    def build_config(self) -> dict:
-        return {
-            "base_url": {
-                "display_name": "Base URL",
-                "info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
-            },
-            "format": {
-                "display_name": "Format",
-                "info": "Specify the format of the output (e.g., json)",
-                "advanced": True,
-            },
-            "headers": {
-                "display_name": "Headers",
-                "advanced": True,
-            },
-            "keep_alive_flag": {
-                "display_name": "Unload interval",
-                "options": ["Keep", "Immediately", "Minute", "Hour", "sec"],
-                "real_time_refresh": True,
-                "refresh_button": True,
-            },
-            "keep_alive": {
-                "display_name": "interval",
-                "info": "How long the model will stay loaded into memory.",
-            },
-            "model": {
-                "display_name": "Model Name",
-                "options": [],
-                "info": "Refer to https://ollama.ai/library for more models.",
-                "real_time_refresh": True,
-                "refresh_button": True,
-            },
-            "temperature": {
-                "display_name": "Temperature",
-                "field_type": "float",
-                "value": 0.8,
-                "info": "Controls the creativity of model responses.",
-            },
-            "metadata": {
-                "display_name": "Metadata",
-                "info": "Metadata to add to the run trace.",
-                "advanced": True,
-            },
-            "mirostat": {
-                "display_name": "Mirostat",
-                "options": ["Disabled", "Mirostat", "Mirostat 2.0"],
-                "info": "Enable/disable Mirostat sampling for controlling perplexity.",
-                "advanced": False,
-                "real_time_refresh": True,
-                "refresh_button": True,
-            },
-            "mirostat_eta": {
-                "display_name": "Mirostat Eta",
-                "field_type": "float",
-                "info": "Learning rate for Mirostat algorithm. (Default: 0.1)",
-                "advanced": True,
-                "real_time_refresh": True,
-            },
-            "mirostat_tau": {
-                "display_name": "Mirostat Tau",
-                "field_type": "float",
-                "info": "Controls the balance between coherence and diversity of the output. (Default: 5.0)",
-                "advanced": True,
-                "real_time_refresh": True,
-            },
-            "num_ctx": {
-                "display_name": "Context Window Size",
-                "field_type": "int",
-                "info": "Size of the context window for generating tokens. (Default: 2048)",
-                "advanced": True,
-            },
-            "num_gpu": {
-                "display_name": "Number of GPUs",
-                "field_type": "int",
-                "info": "Number of GPUs to use for computation. (Default: 1 on macOS, 0 to disable)",
-                "advanced": True,
-            },
-            "num_thread": {
-                "display_name": "Number of Threads",
-                "field_type": "int",
-                "info": "Number of threads to use during computation. (Default: detected for optimal performance)",
-                "advanced": True,
-            },
-            "repeat_last_n": {
-                "display_name": "Repeat Last N",
-                "field_type": "int",
-                "info": "How far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
-                "advanced": True,
-            },
-            "repeat_penalty": {
-                "display_name": "Repeat Penalty",
-                "field_type": "float",
-                "info": "Penalty for repetitions in generated text. (Default: 1.1)",
-                "advanced": True,
-            },
-            "tfs_z": {
-                "display_name": "TFS Z",
-                "field_type": "float",
-                "info": "Tail free sampling value. (Default: 1)",
-                "advanced": True,
-            },
-            "timeout": {
-                "display_name": "Timeout",
-                "field_type": "int",
-                "info": "Timeout for the request stream.",
-                "advanced": True,
-            },
-            "top_k": {
-                "display_name": "Top K",
-                "field_type": "int",
-                "info": "Limits token selection to top K. (Default: 40)",
-                "advanced": True,
-            },
-            "top_p": {
-                "display_name": "Top P",
-                "field_type": "float",
-                "info": "Works together with top-k. (Default: 0.9)",
-                "advanced": True,
-            },
-            "verbose": {
-                "display_name": "Verbose",
-                "field_type": "bool",
-                "info": "Whether to print out response text.",
-            },
-            "tags": {
-                "display_name": "Tags",
-                "field_type": "list",
-                "info": "Tags to add to the run trace.",
-                "advanced": True,
-            },
-            "stop": {
-                "display_name": "Stop Tokens",
-                "field_type": "list",
-                "info": "List of tokens to signal the model to stop generating text.",
-                "advanced": True,
-            },
-            "system": {
-                "display_name": "System",
-                "field_type": "str",
-                "info": "System to use for generating text.",
-                "advanced": True,
-            },
-            "template": {
-                "display_name": "Template",
-                "field_type": "str",
-                "info": "Template to use for generating text.",
-                "advanced": True,
-            },
-            "input_value": {"display_name": "Input", "input_types": ["Text", "Record", "Prompt"]},
-            "stream": {
-                "display_name": "Stream",
-                "info": STREAM_INFO_TEXT,
-            },
-            "system_message": {
-                "display_name": "System Message",
-                "info": "System message to pass to the model.",
-                "advanced": True,
-            },
-        }
-
     def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
         if field_name == "mirostat":
             if field_value == "Disabled":
@@ -258,40 +66,134 @@ class ChatOllamaComponent(LCModelComponent):
             return model_names
         except Exception as e:
             raise ValueError("Could not retrieve models") from e
         return [""]
 
-    def build(
-        self,
-        base_url: Optional[str],
-        model: str,
-        input_value: Text,
-        mirostat: Optional[str] = "Disabled",
-        mirostat_eta: Optional[float] = None,
-        mirostat_tau: Optional[float] = None,
-        repeat_last_n: Optional[int] = None,
-        verbose: Optional[bool] = None,
-        keep_alive: Optional[int] = None,
-        keep_alive_flag: Optional[str] = "Keep",
-        num_ctx: Optional[int] = None,
-        num_gpu: Optional[int] = None,
-        format: Optional[str] = None,
-        metadata: Optional[Dict] = None,
-        num_thread: Optional[int] = None,
-        repeat_penalty: Optional[float] = None,
-        stop: Optional[List[str]] = None,
-        system: Optional[str] = None,
-        tags: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        template: Optional[str] = None,
-        tfs_z: Optional[float] = None,
-        timeout: Optional[int] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[int] = None,
-        stream: bool = False,
-        system_message: Optional[str] = None,
-    ) -> Text:
-        if not base_url:
-            base_url = "http://localhost:11434"
+    inputs = [
+        Input(
+            name="base_url",
+            type=Optional[str],
+            display_name="Base URL",
+            info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
+            value="http://localhost:11434",
+        ),
+        Input(
+            name="model",
+            type=str,
+            display_name="Model Name",
+            options=[],  # This should be dynamically loaded if possible
+            info="Refer to https://ollama.ai/library for more models.",
+            real_time_refresh=True,
+            refresh_button=True,
+        ),
+        Input(
+            name="mirostat",
+            type=str,
+            display_name="Mirostat",
+            options=["Disabled", "Mirostat", "Mirostat 2.0"],
+            info="Enable/disable Mirostat sampling for controlling perplexity.",
+            advanced=False,
+            real_time_refresh=True,
+            refresh_button=True,
+            value="Disabled",
+        ),
+        Input(
+            name="mirostat_eta",
+            type=Optional[float],
+            display_name="Mirostat Eta",
+            info="Learning rate for Mirostat algorithm.",
+            advanced=True,
+            real_time_refresh=True,
+            value=None,  # Default can vary based on mirostat status
+        ),
+        Input(
+            name="mirostat_tau",
+            type=Optional[float],
+            display_name="Mirostat Tau",
+            info="Controls the balance between coherence and diversity of the output.",
+            advanced=True,
+            real_time_refresh=True,
+            value=None,  # Default can vary based on mirostat status
+        ),
+        Input(
+            name="temperature",
+            type=float,
+            display_name="Temperature",
+            info="Controls the creativity of model responses.",
+            value=0.8,
+        ),
+        Input(name="input_value", type=str, display_name="Input", input_types=["Text", "Record", "Prompt"]),
+        Input(name="stream", type=bool, display_name="Stream", info=STREAM_INFO_TEXT, value=False),
+        Input(
+            name="system_message",
+            type=Optional[str],
+            display_name="System Message",
+            info="System message to pass to the model.",
+            advanced=True,
+            value=None,
+        ),
+        Input(
+            name="headers",
+            type=dict,
+            display_name="Headers",
+            info="Additional headers to send with the request.",
+            advanced=True,
+        ),
+        Input(
+            name="keep_alive_flag",
+            type=str,
+            display_params=["Keep", "Immediately", "Minute", "Hour", "sec"],
+            display_name="Unload interval",
+            info="Controls how the model unload interval is managed.",
+            real_time_refresh=True,
+            refresh_button=True,
+        ),
+        Input(
+            name="keep_alive",
+            type=int,
+            display_name="Interval",
+            info="How long the model will stay loaded into memory.",
+            value=None,
+        ),
+    ]
+    outputs = [
+        Output(display_name="Text", name="text_output", method="text_response"),
+        Output(display_name="Language Model", name="model_output", method="model_response"),
+    ]
+
+    def text_response(self) -> Text:
+        input_value = self.input_value
+        stream = self.stream
+        system_message = self.system_message
+        output = self.model_response()
+        result = self.get_chat_result(output, stream, input_value, system_message)
+        self.status = result
+        return result
+
+    def model_response(self) -> BaseLanguageModel:
+        base_url = self.base_url or "http://localhost:11434"
+        model = self.model
+        mirostat = self.mirostat or "Disabled"
+        mirostat_eta = self.mirostat_eta
+        mirostat_tau = self.mirostat_tau
+        repeat_last_n = self.repeat_last_n
+        verbose = self.verbose
+        keep_alive = self.keep_alive
+        keep_alive_flag = self.keep_alive_flag or "Keep"
+        num_ctx = self.num_ctx
+        num_gpu = self.num_gpu
+        _format = self.format
+        metadata = self.metadata
+        num_thread = self.num_thread
+        repeat_penalty = self.repeat_penalty
+        stop = self.stop
+        system = self.system
+        tags = self.tags
+        temperature = self.temperature
+        template = self.template
+        tfs_z = self.tfs_z
+        timeout = self.timeout
+        top_k = self.top_k
+        top_p = self.top_p
+        headers = self.headers
 
         if keep_alive_flag == "Minute":
             keep_alive_instance = f"{keep_alive}m"
@@ -307,17 +209,15 @@ class ChatOllamaComponent(LCModelComponent):
             keep_alive_instance = "Invalid option"
 
         mirostat_instance = 0
 
         if mirostat == "disable":
             mirostat_instance = 0
 
         # Mapping system settings to their corresponding values
         llm_params = {
             "base_url": base_url,
             "model": model,
             "mirostat": mirostat_instance,
             "keep_alive": keep_alive_instance,
-            "format": format,
+            "format": _format,
             "metadata": metadata,
             "tags": tags,
             "mirostat_eta": mirostat_eta,
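An aside on the mirostat handling above: the UI exposes the options "Disabled", "Mirostat" and "Mirostat 2.0", while the Ollama API expects an integer (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0). Below is a sketch of that mapping, assuming the usual Ollama encoding; the hunk itself only shows the Disabled branch.

MIROSTAT_CODES = {"Disabled": 0, "Mirostat": 1, "Mirostat 2.0": 2}

# Unknown or unset options fall back to 0 (Mirostat disabled).
mirostat_instance = MIROSTAT_CODES.get("Mirostat 2.0", 0)  # -> 2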
@@ -336,14 +236,14 @@ class ChatOllamaComponent(LCModelComponent):
             "top_k": top_k,
             "top_p": top_p,
             "verbose": verbose,
+            "headers": headers,
         }
 
         # None Value remove
         llm_params = {k: v for k, v in llm_params.items() if v is not None}
 
         try:
-            output = ChatOllama(**llm_params)  # type: ignore
+            output = ChatOllama(**llm_params)
         except Exception as e:
             raise ValueError("Could not initialize Ollama LLM.") from e
 
-        return self.get_chat_result(output, stream, input_value, system_message)
+        return output
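Taken together, model_response maps the unload interval to Ollama's keep_alive syntax, strips every parameter the user left unset, and only then constructs the client. Below is a standalone sketch of that flow with hypothetical values ("llama2", 10 minutes); the branches other than "Minute" follow Ollama's documented keep_alive conventions (0 unloads immediately, a negative value keeps the model loaded) rather than this commit's exact code.

from langchain_community.chat_models.ollama import ChatOllama

keep_alive_flag, keep_alive = "Minute", 10
if keep_alive_flag == "Minute":
    keep_alive_instance = f"{keep_alive}m"  # "10m"
elif keep_alive_flag == "Immediately":
    keep_alive_instance = "0"               # unload right after the call
else:
    keep_alive_instance = "-1"              # keep the model loaded indefinitely

llm_params = {
    "base_url": "http://localhost:11434",
    "model": "llama2",
    "keep_alive": keep_alive_instance,
    "temperature": 0.8,
    "top_k": None,  # left unset by the user
}
# Drop unset values so ChatOllama only receives explicit configuration.
llm_params = {k: v for k, v in llm_params.items() if v is not None}

llm = ChatOllama(**llm_params)  # top_k never reaches the client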