Update ChatOllamaEndpoint.py

This commit is contained in:
yamonkjd 2023-12-23 03:57:38 +09:00 committed by GitHub
commit bea5065237
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,9 +1,15 @@
from typing import Optional, List
from typing import Optional, List, Dict, Any
from langchain.chat_models.base import BaseChatModel
from langchain_community.chat_models import ChatOllama
# from langchain_community.chat_models import ChatOllama
from langchain.chat_models import ChatOllama
# from langchain.chat_models import ChatOllama
from langflow import CustomComponent
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# When a callback component is added to Langflow, the comment must be uncommented.
# from langchain.callbacks.manager import CallbackManager
class ChatOllamaComponent(CustomComponent):
@ -14,174 +20,208 @@ class ChatOllamaComponent(CustomComponent):
return {
"base_url": {
"display_name": "Base URL",
"value": "http://localhost:11434",
"info": "Endpoint of the Ollama API."
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
},
"model": {
"display_name": "Model Name",
"value": "llama2",
"info": "Refer to https://ollama.ai/library for more models."
"info": "Refer to https://ollama.ai/library for more models.",
},
"temperature": {
"display_name": "Temperature",
"field_type": "float",
"value": 0.8,
"info": "Controls the creativity of model responses."
"info": "Controls the creativity of model responses.",
},
"cache": {
"display_name": "Cache",
"field_type": "bool",
"info": "Enable or disable caching.",
"advanced": True,
"value": False
},
"callback_manager": {
"display_name": "Callback Manager",
"info": "Optional callback manager for additional functionality.",
"advanced": True,
"value": None
},
"callbacks": {
"display_name": "Callbacks",
"info": "Callbacks to execute during model runtime.",
"advanced": True,
"value": None
"value": False,
},
### When a callback component is added to Langflow, the comment must be uncommented. ###
# "callback_manager": {
# "display_name": "Callback Manager",
# "info": "Optional callback manager for additional functionality.",
# "advanced": True,
# },
# "callbacks": {
# "display_name": "Callbacks",
# "info": "Callbacks to execute during model runtime.",
# "advanced": True,
# },
########################################################################################
"format": {
"display_name": "Format",
"field_type": "str",
"info": "Specify the format of the output (e.g., json).",
"advanced": True,
"value": None
},
"metadata": {
"display_name": "Metadata",
"info": "Metadata to add to the run trace.",
"advanced": True,
"value": None
},
"mirostat": {
"display_name": "Mirostat",
"field_type": "int",
"info": "Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)",
"options": ["Disabled", "Mirostat", "Mirostat 2.0"],
"info": "Enable/disable Mirostat sampling for controlling perplexity.",
"value": "Disabled",
"advanced": True,
"value": 0
},
"mirostat_eta": {
"display_name": "Mirostat Eta",
"field_type": "float",
"info": "Learning rate for Mirostat algorithm. (Default: 0.1)",
"advanced": True,
"value": 0.1
},
"mirostat_tau": {
"display_name": "Mirostat Tau",
"field_type": "float",
"info": "Controls the balance between coherence and diversity of the output. (Default: 5.0)",
"advanced": True,
"value": 5.0
},
"num_ctx": {
"display_name": "Context Window Size",
"field_type": "int",
"info": "Size of the context window for generating tokens. (Default: 2048)",
"advanced": True,
"value": 2048
},
"num_gpu": {
"display_name": "Number of GPUs",
"field_type": "int",
"info": "Number of GPUs to use for computation. (Default: 1 on macOS, 0 to disable)",
"advanced": True,
"value": 0
},
"num_thread": {
"display_name": "Number of Threads",
"field_type": "int",
"info": "Number of threads to use during computation. (Default: detected for optimal performance)",
"advanced": True,
"value": None
},
"repeat_last_n": {
"display_name": "Repeat Last N",
"field_type": "int",
"info": "How far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
"advanced": True,
"value": 64
},
"repeat_penalty": {
"display_name": "Repeat Penalty",
"field_type": "float",
"info": "Penalty for repetitions in generated text. (Default: 1.1)",
"advanced": True,
"value": 1.1
},
"tfs_z": {
"display_name": "TFS Z",
"field_type": "float",
"info": "Tail free sampling value. (Default: 1)",
"advanced": True,
"value": 1.0
},
"timeout": {
"display_name": "Timeout",
"field_type": "int",
"info": "Timeout for the request stream.",
"advanced": True,
"value": None
},
"top_k": {
"display_name": "Top K",
"field_type": "int",
"info": "Limits token selection to top K. (Default: 40)",
"advanced": True,
"value": 40
},
"top_p": {
"display_name": "Top P",
"field_type": "float",
"info": "Works together with top-k. (Default: 0.9)",
"advanced": True,
"value": 0.9
},
"verbose": {
"display_name": "Verbose",
"field_type": "bool",
"info": "Whether to print out response text.",
"value": None
},
"tags": {
"display_name": "Tags",
"field_type": "list",
"info": "Tags to add to the run trace.",
"advanced": True,
"value": None
},
"stop": {
"display_name": "Stop Tokens",
"field_type": "list",
"info": "List of tokens to signal the model to stop generating text.",
"advanced": True,
},
"system": {
"display_name": "System",
"field_type": "str",
"info": "System to use for generating text.",
"advanced": True,
},
"template": {
"display_name": "Template",
"field_type": "str",
"info": "Template to use for generating text.",
},
}
def build(self, base_url: str, model: str, mirostat: Optional[int],
mirostat_eta: Optional[float], mirostat_tau: Optional[float],
num_ctx: Optional[int], num_gpu: Optional[int],
repeat_last_n: Optional[int],
repeat_penalty: Optional[float], temperature: Optional[float],
tfs_z: Optional[float],
num_thread: Optional[int] = None,
stop: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
system: Optional[str] = None,
template: Optional[str] = None,
timeout: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None, verbose: Optional[bool] = None
) -> BaseChatModel:
def build(
self,
base_url: Optional[str],
model: str,
mirostat: Optional[str],
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
### When a callback component is added to Langflow, the comment must be uncommented. ###
# callback_manager: Optional[CallbackManager] = None,
# callbacks: Optional[List[Callbacks]] = None,
#######################################################################################
repeat_last_n: Optional[int] = None,
verbose: Optional[bool] = None,
cache: Optional[bool] = None,
num_ctx: Optional[int] = None,
num_gpu: Optional[int] = None,
format: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
num_thread: Optional[int] = None,
repeat_penalty: Optional[float] = None,
stop: Optional[List[str]] = None,
system: Optional[str] = None,
tags: Optional[List[str]] = None,
temperature: Optional[float] = None,
template: Optional[str] = None,
tfs_z: Optional[float] = None,
timeout: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
) -> BaseChatModel:
if not base_url:
base_url = "http://localhost:11434"
callback_manager = CallbackManager(
[StreamingStdOutCallbackHandler()])
# Mapping mirostat settings to their corresponding values
mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
# Default to 0 for 'Disabled'
mirostat_value = mirostat_options.get(mirostat, 0) # type: ignore
# Set mirostat_eta and mirostat_tau to None if mirostat is disabled
if mirostat_value == 0:
mirostat_eta = None
mirostat_tau = None
llm_params = {
"base_url": base_url,
"cache": cache,
"model": model,
"mirostat": mirostat,
"mirostat": mirostat_value,
"format": format,
"metadata": metadata,
"tags": tags,
## When a callback component is added to Langflow, the comment must be uncommented.##
# "callback_manager": callback_manager,
# "callbacks": callbacks,
#####################################################################################
"mirostat_eta": mirostat_eta,
"mirostat_tau": mirostat_tau,
"num_ctx": num_ctx,
@ -198,10 +238,9 @@ class ChatOllamaComponent(CustomComponent):
"top_k": top_k,
"top_p": top_p,
"verbose": verbose,
"callback_manager": callback_manager
}
# None Value Remove
# Remove None values so only explicitly-set params are passed to ChatOllama
llm_params = {k: v for k, v in llm_params.items() if v is not None}
try:
@ -209,4 +248,4 @@ class ChatOllamaComponent(CustomComponent):
except Exception as e:
raise ValueError("Could not initialize Ollama LLM.") from e
return output
return output # type: ignore