Update OllamaLLM.py
This commit is contained in:
parent
6c0b4fb416
commit
3b6672e06d
1 changed files with 64 additions and 58 deletions
|
|
@ -14,143 +14,149 @@ class OllamaLLM(CustomComponent):
|
|||
return {
|
||||
"base_url": {
|
||||
"display_name": "Base URL",
|
||||
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified."
|
||||
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
|
||||
},
|
||||
"model": {
|
||||
"display_name": "Model Name",
|
||||
"value": "llama2",
|
||||
"info": "Refer to https://ollama.ai/library for more models."
|
||||
"info": "Refer to https://ollama.ai/library for more models.",
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
"field_type": "float",
|
||||
"value": 0.8,
|
||||
"info": "Controls the creativity of model responses."
|
||||
"info": "Controls the creativity of model responses.",
|
||||
},
|
||||
|
||||
"mirostat": {
|
||||
"display_name": "Mirostat",
|
||||
"options": ["Disabled", "Mirostat", "Mirostat 2.0"],
|
||||
"info": "Enable/disable Mirostat sampling for controlling perplexity.",
|
||||
"value": "Disabled",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"mirostat_eta": {
|
||||
"display_name": "Mirostat Eta",
|
||||
"field_type": "float",
|
||||
"info": "Learning rate influencing the algorithm's response to feedback.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
|
||||
|
||||
"mirostat_tau": {
|
||||
"display_name": "Mirostat Tau",
|
||||
"field_type": "float",
|
||||
"value": 5.0,
|
||||
"info": "Controls balance between coherence and diversity.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"num_ctx": {
|
||||
"display_name": "Context Window Size",
|
||||
"field_type": "int",
|
||||
"value": 2048,
|
||||
"info": "Size of the context window for generating the next token.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"num_gpu": {
|
||||
"display_name": "Number of GPUs",
|
||||
"field_type": "int",
|
||||
"info": "Number of GPUs to use for computation.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"num_thread": {
|
||||
"display_name": "Number of Threads",
|
||||
"field_type": "int",
|
||||
"info": "Number of threads to use during computation.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"repeat_last_n": {
|
||||
"display_name": "Repeat Last N",
|
||||
"field_type": "int",
|
||||
"value": 64,
|
||||
"info": "Sets how far back the model looks to prevent repetition.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"repeat_penalty": {
|
||||
"display_name": "Repeat Penalty",
|
||||
"field_type": "float",
|
||||
"value": 1.1,
|
||||
"info": "Penalty for repetitions in generated text.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
|
||||
"stop": {
|
||||
"display_name": "Stop Tokens",
|
||||
|
||||
"info": "List of tokens to signal the model to stop generating text.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"tfs_z": {
|
||||
"display_name": "TFS Z",
|
||||
"field_type": "float",
|
||||
"value": 1,
|
||||
"info": "Tail free sampling to reduce impact of less probable tokens.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"top_k": {
|
||||
"display_name": "Top K",
|
||||
"field_type": "int",
|
||||
"value": 40,
|
||||
"info": "Limits token selection to top K for reducing nonsense generation.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
"top_p": {
|
||||
"display_name": "Top P",
|
||||
"field_type": "int",
|
||||
"value": 0.9,
|
||||
"info": "Works with top-k to control diversity of generated text.",
|
||||
"advanced": True
|
||||
"advanced": True,
|
||||
},
|
||||
}
|
||||
|
||||
def build(self, base_url: Optional[str], model: str, mirostat: str, mirostat_eta: Optional[float],
|
||||
mirostat_tau: Optional[float], num_ctx: Optional[int], num_gpu: Optional[int],
|
||||
num_thread: Optional[int], repeat_last_n: Optional[int], repeat_penalty: Optional[float],
|
||||
temperature: Optional[float], stop: Optional[List[str]], tfs_z: Optional[float],
|
||||
top_k: Optional[int], top_p: Optional[int]) -> BaseLLM:
|
||||
|
||||
def build(
|
||||
self,
|
||||
base_url: Optional[str],
|
||||
model: str,
|
||||
temperature: Optional[float],
|
||||
mirostat: Optional[str],
|
||||
mirostat_eta: Optional[float] = None,
|
||||
mirostat_tau: Optional[float] = None,
|
||||
num_ctx: Optional[int] = None,
|
||||
num_gpu: Optional[int] = None,
|
||||
num_thread: Optional[int] = None,
|
||||
repeat_last_n: Optional[int] = None,
|
||||
repeat_penalty: Optional[float] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
tfs_z: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> BaseLLM:
|
||||
if not base_url:
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
mirostat_value = 0 # Default value for 'Disabled'
|
||||
# Mapping mirostat settings to their corresponding values
|
||||
mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
|
||||
|
||||
# Map the textual option to the corresponding integer
|
||||
if mirostat == "Mirostat":
|
||||
mirostat_value = 1
|
||||
elif mirostat == "Mirostat 2.0":
|
||||
mirostat_value = 2
|
||||
# Default to 0 for 'Disabled'
|
||||
mirostat_value = mirostat_options.get(mirostat, 0) # type: ignore
|
||||
|
||||
# Set mirostat_eta and mirostat_tau to None if mirostat is disabled
|
||||
if mirostat_value == 0:
|
||||
mirostat_eta = None
|
||||
mirostat_tau = None
|
||||
|
||||
llm_params = {
|
||||
"base_url": base_url,
|
||||
"model": model,
|
||||
"mirostat": mirostat_value,
|
||||
"mirostat_eta": mirostat_eta,
|
||||
"mirostat_tau": mirostat_tau,
|
||||
"num_ctx": num_ctx,
|
||||
"num_gpu": num_gpu,
|
||||
"num_thread": num_thread,
|
||||
"repeat_last_n": repeat_last_n,
|
||||
"repeat_penalty": repeat_penalty,
|
||||
"temperature": temperature,
|
||||
"stop": stop,
|
||||
"tfs_z": tfs_z,
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
}
|
||||
|
||||
# None Value remove
|
||||
llm_params = {k: v for k, v in llm_params.items() if v is not None}
|
||||
|
||||
params = {k: v for k, v in {
|
||||
'base_url': base_url,
|
||||
'model': model,
|
||||
'mirostat': mirostat_value,
|
||||
'mirostat_eta': mirostat_eta,
|
||||
'mirostat_tau': mirostat_tau,
|
||||
'num_ctx': num_ctx,
|
||||
'num_gpu': num_gpu,
|
||||
'num_thread': num_thread,
|
||||
'repeat_last_n': repeat_last_n,
|
||||
'repeat_penalty': repeat_penalty,
|
||||
'temperature': temperature,
|
||||
'stop': stop,
|
||||
'tfs_z': tfs_z,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'streaming' :"True"
|
||||
}.items() if v is not None}
|
||||
|
||||
try:
|
||||
llm = Ollama(**params)
|
||||
llm = Ollama(**llm_params)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to Ollama.") from e
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue