♻️ (OllamaModel.py): Refactor OllamaModel.py to new Component standard

ogabrielluiz 2024-06-11 23:05:19 -03:00
commit 9e6ede9fb0

@@ -1,11 +1,12 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
import httpx
from langchain_community.chat_models.ollama import ChatOllama
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
-from langflow.field_typing import Text
+from langflow.field_typing import BaseLanguageModel, Text
+from langflow.template.field.base import Input, Output
class ChatOllamaComponent(LCModelComponent):
@@ -13,199 +14,6 @@ class ChatOllamaComponent(LCModelComponent):
description = "Generate text using Ollama Local LLMs."
icon = "Ollama"
field_order = [
"base_url",
"headers",
"keep_alive_flag",
"keep_alive",
"metadata",
"model",
"temperature",
"cache",
"format",
"metadata",
"mirostat",
"mirostat_eta",
"mirostat_tau",
"num_ctx",
"num_gpu",
"num_thread",
"repeat_last_n",
"repeat_penalty",
"tfs_z",
"timeout",
"top_k",
"top_p",
"verbose",
"tags",
"stop",
"system",
"template",
"input_value",
"system_message",
"stream",
]
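# field_order determines the order in which these fields are shown in the UI.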
def build_config(self) -> dict:
return {
"base_url": {
"display_name": "Base URL",
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
},
"format": {
"display_name": "Format",
"info": "Specify the format of the output (e.g., json)",
"advanced": True,
},
"headers": {
"display_name": "Headers",
"advanced": True,
},
"keep_alive_flag": {
"display_name": "Unload interval",
"options": ["Keep", "Immediately", "Minute", "Hour", "sec"],
"real_time_refresh": True,
"refresh_button": True,
},
"keep_alive": {
"display_name": "interval",
"info": "How long the model will stay loaded into memory.",
},
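# Paired with "keep_alive_flag" above: the flag selects the unit and
# keep_alive the amount; the two are combined into a duration string later on.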
"model": {
"display_name": "Model Name",
"options": [],
"info": "Refer to https://ollama.ai/library for more models.",
"real_time_refresh": True,
"refresh_button": True,
},
"temperature": {
"display_name": "Temperature",
"field_type": "float",
"value": 0.8,
"info": "Controls the creativity of model responses.",
},
"metadata": {
"display_name": "Metadata",
"info": "Metadata to add to the run trace.",
"advanced": True,
},
"mirostat": {
"display_name": "Mirostat",
"options": ["Disabled", "Mirostat", "Mirostat 2.0"],
"info": "Enable/disable Mirostat sampling for controlling perplexity.",
"advanced": False,
"real_time_refresh": True,
"refresh_button": True,
},
"mirostat_eta": {
"display_name": "Mirostat Eta",
"field_type": "float",
"info": "Learning rate for Mirostat algorithm. (Default: 0.1)",
"advanced": True,
"real_time_refresh": True,
},
"mirostat_tau": {
"display_name": "Mirostat Tau",
"field_type": "float",
"info": "Controls the balance between coherence and diversity of the output. (Default: 5.0)",
"advanced": True,
"real_time_refresh": True,
},
"num_ctx": {
"display_name": "Context Window Size",
"field_type": "int",
"info": "Size of the context window for generating tokens. (Default: 2048)",
"advanced": True,
},
"num_gpu": {
"display_name": "Number of GPUs",
"field_type": "int",
"info": "Number of GPUs to use for computation. (Default: 1 on macOS, 0 to disable)",
"advanced": True,
},
"num_thread": {
"display_name": "Number of Threads",
"field_type": "int",
"info": "Number of threads to use during computation. (Default: detected for optimal performance)",
"advanced": True,
},
"repeat_last_n": {
"display_name": "Repeat Last N",
"field_type": "int",
"info": "How far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
"advanced": True,
},
"repeat_penalty": {
"display_name": "Repeat Penalty",
"field_type": "float",
"info": "Penalty for repetitions in generated text. (Default: 1.1)",
"advanced": True,
},
"tfs_z": {
"display_name": "TFS Z",
"field_type": "float",
"info": "Tail free sampling value. (Default: 1)",
"advanced": True,
},
"timeout": {
"display_name": "Timeout",
"field_type": "int",
"info": "Timeout for the request stream.",
"advanced": True,
},
"top_k": {
"display_name": "Top K",
"field_type": "int",
"info": "Limits token selection to top K. (Default: 40)",
"advanced": True,
},
"top_p": {
"display_name": "Top P",
"field_type": "float",
"info": "Works together with top-k. (Default: 0.9)",
"advanced": True,
},
"verbose": {
"display_name": "Verbose",
"field_type": "bool",
"info": "Whether to print out response text.",
},
"tags": {
"display_name": "Tags",
"field_type": "list",
"info": "Tags to add to the run trace.",
"advanced": True,
},
"stop": {
"display_name": "Stop Tokens",
"field_type": "list",
"info": "List of tokens to signal the model to stop generating text.",
"advanced": True,
},
"system": {
"display_name": "System",
"field_type": "str",
"info": "System to use for generating text.",
"advanced": True,
},
"template": {
"display_name": "Template",
"field_type": "str",
"info": "Template to use for generating text.",
"advanced": True,
},
"input_value": {"display_name": "Input", "input_types": ["Text", "Record", "Prompt"]},
"stream": {
"display_name": "Stream",
"info": STREAM_INFO_TEXT,
},
"system_message": {
"display_name": "System Message",
"info": "System message to pass to the model.",
"advanced": True,
},
}
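# Invoked for the fields declared with real_time_refresh: updates the build
# config when one of those fields changes (e.g. the mirostat selection below).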
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "mirostat":
if field_value == "Disabled":
@@ -258,40 +66,134 @@
return model_names
except Exception as e:
raise ValueError("Could not retrieve models") from e
return [""]
def build(
self,
base_url: Optional[str],
model: str,
input_value: Text,
mirostat: Optional[str] = "Disabled",
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
repeat_last_n: Optional[int] = None,
verbose: Optional[bool] = None,
keep_alive: Optional[int] = None,
keep_alive_flag: Optional[str] = "Keep",
num_ctx: Optional[int] = None,
num_gpu: Optional[int] = None,
format: Optional[str] = None,
metadata: Optional[Dict] = None,
num_thread: Optional[int] = None,
repeat_penalty: Optional[float] = None,
stop: Optional[List[str]] = None,
system: Optional[str] = None,
tags: Optional[List[str]] = None,
temperature: Optional[float] = None,
template: Optional[str] = None,
tfs_z: Optional[float] = None,
timeout: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
stream: bool = False,
system_message: Optional[str] = None,
) -> Text:
if not base_url:
base_url = "http://localhost:11434"
inputs = [
Input(
name="base_url",
type=Optional[str],
display_name="Base URL",
info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
value="http://localhost:11434",
),
Input(
name="model",
type=str,
display_name="Model Name",
options=[], # This should be dynamically loaded if possible
info="Refer to https://ollama.ai/library for more models.",
real_time_refresh=True,
refresh_button=True,
),
Input(
name="mirostat",
type=str,
display_name="Mirostat",
options=["Disabled", "Mirostat", "Mirostat 2.0"],
info="Enable/disable Mirostat sampling for controlling perplexity.",
advanced=False,
real_time_refresh=True,
refresh_button=True,
value="Disabled",
),
Input(
name="mirostat_eta",
type=Optional[float],
display_name="Mirostat Eta",
info="Learning rate for Mirostat algorithm.",
advanced=True,
real_time_refresh=True,
value=None, # Default can vary based on mirostat status
),
Input(
name="mirostat_tau",
type=Optional[float],
display_name="Mirostat Tau",
info="Controls the balance between coherence and diversity of the output.",
advanced=True,
real_time_refresh=True,
value=None, # Default can vary based on mirostat status
),
Input(
name="temperature",
type=float,
display_name="Temperature",
info="Controls the creativity of model responses.",
value=0.8,
),
Input(name="input_value", type=str, display_name="Input", input_types=["Text", "Record", "Prompt"]),
Input(name="stream", type=bool, display_name="Stream", info=STREAM_INFO_TEXT, value=False),
Input(
name="system_message",
type=Optional[str],
display_name="System Message",
info="System message to pass to the model.",
advanced=True,
value=None,
),
Input(
name="headers",
type=dict,
display_name="Headers",
info="Additional headers to send with the request.",
advanced=True,
),
Input(
name="keep_alive_flag",
type=str,
options=["Keep", "Immediately", "Minute", "Hour", "sec"],
display_name="Unload interval",
info="Controls how the model unload interval is managed.",
real_time_refresh=True,
refresh_button=True,
),
Input(
name="keep_alive",
type=int,
display_name="Interval",
info="How long the model will stay loaded into memory.",
value=None,
),
]
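# On the new Component standard, each Input above is exposed as an attribute
# of the same name (self.base_url, self.model, ...), which the methods below read.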
outputs = [
Output(display_name="Text", name="text_output", method="text_response"),
Output(display_name="Language Model", name="model_output", method="model_response"),
]
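# Each Output's `method` names the component method that produces it:
# "text_output" is served by text_response() and "model_output" by model_response().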
def text_response(self) -> Text:
input_value = self.input_value
stream = self.stream
system_message = self.system_message
output = self.model_response()
result = self.get_chat_result(output, stream, input_value, system_message)
self.status = result
return result
def model_response(self) -> BaseLanguageModel:
base_url = self.base_url or "http://localhost:11434"
model = self.model
mirostat = self.mirostat or "Disabled"
mirostat_eta = self.mirostat_eta
mirostat_tau = self.mirostat_tau
repeat_last_n = self.repeat_last_n
verbose = self.verbose
keep_alive = self.keep_alive
keep_alive_flag = self.keep_alive_flag or "Keep"
num_ctx = self.num_ctx
num_gpu = self.num_gpu
_format = self.format
metadata = self.metadata
num_thread = self.num_thread
repeat_penalty = self.repeat_penalty
stop = self.stop
system = self.system
tags = self.tags
temperature = self.temperature
template = self.template
tfs_z = self.tfs_z
timeout = self.timeout
top_k = self.top_k
top_p = self.top_p
headers = self.headers
if keep_alive_flag == "Minute":
keep_alive_instance = f"{keep_alive}m"
@@ -307,17 +209,15 @@ class ChatOllamaComponent(LCModelComponent):
keep_alive_instance = "Invalid option"
mirostat_instance = 0
-if mirostat == "disable":
-    mirostat_instance = 0
# Map the collected settings to their corresponding ChatOllama parameters
llm_params = {
"base_url": base_url,
"model": model,
"mirostat": mirostat_instance,
"keep_alive": keep_alive_instance,
"format": format,
"format": _format,
"metadata": metadata,
"tags": tags,
"mirostat_eta": mirostat_eta,
@@ -336,14 +236,14 @@
"top_k": top_k,
"top_p": top_p,
"verbose": verbose,
"headers": headers,
}
# Drop None values so ChatOllama falls back to its own defaults
llm_params = {k: v for k, v in llm_params.items() if v is not None}
try:
-output = ChatOllama(**llm_params)  # type: ignore
+output = ChatOllama(**llm_params)
except Exception as e:
raise ValueError("Could not initialize Ollama LLM.") from e
-return self.get_chat_result(output, stream, input_value, system_message)
+return output
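
For context on the pattern this commit adopts: the new Component standard replaces the imperative build() signature with declarative inputs/outputs lists and one method per Output. The sketch below condenses that structure using only the imports and calls visible in this diff; the component name and the reduced field set are illustrative, not part of the commit.

from typing import Optional

from langchain_community.chat_models.ollama import ChatOllama
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import BaseLanguageModel, Text
from langflow.template.field.base import Input, Output


class MinimalOllamaComponent(LCModelComponent):
    # Illustrative subset of the fields defined in the commit above.
    inputs = [
        Input(
            name="base_url",
            type=Optional[str],
            display_name="Base URL",
            value="http://localhost:11434",
        ),
        Input(name="input_value", type=str, display_name="Input"),
        Input(name="stream", type=bool, display_name="Stream", value=False),
    ]
    outputs = [
        Output(display_name="Text", name="text_output", method="text_response"),
        Output(display_name="Language Model", name="model_output", method="model_response"),
    ]

    def model_response(self) -> BaseLanguageModel:
        # Inputs come back as attributes named after each Input's `name`.
        llm_params = {"base_url": self.base_url or "http://localhost:11434"}
        # Dropping None values, as the commit does, lets ChatOllama keep its defaults.
        llm_params = {k: v for k, v in llm_params.items() if v is not None}
        return ChatOllama(**llm_params)

    def text_response(self) -> Text:
        # get_chat_result is inherited from LCModelComponent, as used in the diff.
        output = self.model_response()
        result = self.get_chat_result(output, self.stream, self.input_value, None)
        self.status = result
        return result

Routing both outputs through model_response() keeps the text output and the raw language-model output on a single ChatOllama configuration, which is the structural gain over the old build() method.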