A better implementation of the Ollama component (#1701)

* Update OllamaModel.py

A draft to synchronize the model using the latest Langflow architecture and to improve it according to the latest Langchain specifications.

* Update OllamaModel.py

Checkout Models from api

* Update OllamaModel.py

* Update OllamaModel.py

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
yamonbt 2024-05-29 00:20:58 +09:00 committed by GitHub
commit 72f09d65b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,16 +1,21 @@
from typing import Dict, List, Optional
# from langchain_community.chat_models import ChatOllama
from langchain_community.chat_models import ChatOllama
from typing import Any, Dict, List, Optional, Union
from langchain_community.chat_models.ollama import ChatOllama
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langchain_core.caches import BaseCache
# from langchain.chat_models import ChatOllama
from langflow.field_typing import Text
# whe When a callback component is added to Langflow, the comment must be uncommented.
# from langchain.callbacks.manager import CallbackManager
import asyncio
import json
import httpx
class ChatOllamaComponent(LCModelComponent):
@ -20,11 +25,19 @@ class ChatOllamaComponent(LCModelComponent):
field_order = [
"base_url",
"headers",
"keep_alive_flag",
"keep_alive",
"metadata",
"model",
"temperature",
"cache",
"callback_manager",
"callbacks",
"format",
"metadata",
"mirostat",
@ -54,12 +67,41 @@ class ChatOllamaComponent(LCModelComponent):
"base_url": {
"display_name": "Base URL",
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
},
"format": {
"display_name": "Format",
"info": "Specify the format of the output (e.g., json)",
"advanced": True,
},
"headers": {
"display_name": "Headers",
"advanced": True,
},
"keep_alive_flag": {
"display_name": "Unload interval",
"options": ["Keep", "Immediately","Minute", "Hour", "sec" ],
"real_time_refresh": True,
"refresh_button": True,
},
"keep_alive": {
"display_name": "interval",
"info": "How long the model will stay loaded into memory.",
},
"model": {
"display_name": "Model Name",
"value": "llama2",
"options": [],
"info": "Refer to https://ollama.ai/library for more models.",
"real_time_refresh": True,
"refresh_button": True,
},
"temperature": {
"display_name": "Temperature",
@ -67,25 +109,8 @@ class ChatOllamaComponent(LCModelComponent):
"value": 0.8,
"info": "Controls the creativity of model responses.",
},
"cache": {
"display_name": "Cache",
"field_type": "bool",
"info": "Enable or disable caching.",
"advanced": True,
"value": False,
},
### When a callback component is added to Langflow, the comment must be uncommented. ###
# "callback_manager": {
# "display_name": "Callback Manager",
# "info": "Optional callback manager for additional functionality.",
# "advanced": True,
# },
# "callbacks": {
# "display_name": "Callbacks",
# "info": "Callbacks to execute during model runtime.",
# "advanced": True,
# },
########################################################################################
"format": {
"display_name": "Format",
"field_type": "str",
@ -101,20 +126,24 @@ class ChatOllamaComponent(LCModelComponent):
"display_name": "Mirostat",
"options": ["Disabled", "Mirostat", "Mirostat 2.0"],
"info": "Enable/disable Mirostat sampling for controlling perplexity.",
"value": "Disabled",
"advanced": True,
"advanced": False,
"real_time_refresh": True,
"refresh_button": True,
},
"mirostat_eta": {
"display_name": "Mirostat Eta",
"field_type": "float",
"info": "Learning rate for Mirostat algorithm. (Default: 0.1)",
"advanced": True,
"real_time_refresh": True,
},
"mirostat_tau": {
"display_name": "Mirostat Tau",
"field_type": "float",
"info": "Controls the balance between coherence and diversity of the output. (Default: 5.0)",
"advanced": True,
"real_time_refresh": True,
},
"num_ctx": {
"display_name": "Context Window Size",
@ -211,21 +240,74 @@ class ChatOllamaComponent(LCModelComponent):
},
}
def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name == "mirostat":
if field_value == "Disabled":
build_config["mirostat_eta"]["advanced"] = True
build_config["mirostat_tau"]["advanced"] = True
build_config["mirostat_eta"]["value"] = None
build_config["mirostat_tau"]["value"] = None
else:
build_config["mirostat_eta"]["advanced"] = False
build_config["mirostat_tau"]["advanced"] = False
if field_value == "Mirostat 2.0":
build_config["mirostat_eta"]["value"] = 0.2
build_config["mirostat_tau"]["value"] = 10
else:
build_config["mirostat_eta"]["value"] = 0.1
build_config["mirostat_tau"]["value"] = 5
if field_name == "model":
base_url = build_config.get("base_url", {}).get(
"value", "http://localhost:11434")
build_config["model"]["options"] = self.get_model(
base_url + "/api/tags")
if field_name == "keep_alive_flag":
if field_value == "Keep":
build_config["keep_alive"]["value"] = "-1"
build_config["keep_alive"]["advanced"] = True
elif field_value == "Immediately":
build_config["keep_alive"]["value"] = "0"
build_config["keep_alive"]["advanced"] = True
else:
build_config["keep_alive"]["advanced"] = False
return build_config
def get_model(self, url: str) -> List[str]:
try:
with httpx.Client() as client:
response = client.get(url)
response.raise_for_status()
data = response.json()
model_names = [model['name']
for model in data.get("models", [])]
return model_names
except Exception as e:
raise ValueError("Could not retrieve models") from e
return [""]
def build(
self,
base_url: Optional[str],
model: str,
input_value: Text,
mirostat: Optional[str],
mirostat: Optional[str],
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
### When a callback component is added to Langflow, the comment must be uncommented.###
# callback_manager: Optional[CallbackManager] = None,
# callbacks: Optional[List[Callbacks]] = None,
#######################################################################################
repeat_last_n: Optional[int] = None,
verbose: Optional[bool] = None,
cache: Optional[bool] = None,
keep_alive: Optional[int] = None,
keep_alive_flag: Optional[str] = None,
num_ctx: Optional[int] = None,
num_gpu: Optional[int] = None,
format: Optional[str] = None,
@ -244,33 +326,39 @@ class ChatOllamaComponent(LCModelComponent):
stream: bool = False,
system_message: Optional[str] = None,
) -> Text:
if not base_url:
base_url = "http://localhost:11434"
# Mapping mirostat settings to their corresponding values
mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
# Default to 0 for 'Disabled'
mirostat_value = mirostat_options.get(mirostat, 0) # type: ignore
# Set mirostat_eta and mirostat_tau to None if mirostat is disabled
if mirostat_value == 0:
mirostat_eta = None
mirostat_tau = None
if keep_alive_flag == "Minute":
keep_alive_instance = f"{keep_alive}m"
elif keep_alive_flag == "Hour":
keep_alive_instance = f"{keep_alive}h"
elif keep_alive_flag == "sec":
keep_alive_instance = f"{keep_alive}s"
elif keep_alive_flag == "Keep":
keep_alive_instance = "-1"
elif keep_alive_flag == "Immediately":
keep_alive_instance = "0"
else:
keep_alive_instance = "Invalid option"
mirostat_instance = 0
if mirostat == "disable":
mirostat_instance = 0
# Mapping system settings to their corresponding values
llm_params = {
"base_url": base_url,
"cache": cache,
"model": model,
"mirostat": mirostat_value,
"mirostat": mirostat_instance,
"keep_alive": keep_alive_instance,
"format": format,
"metadata": metadata,
"tags": tags,
## When a callback component is added to Langflow, the comment must be uncommented.##
# "callback_manager": callback_manager,
# "callbacks": callbacks,
#####################################################################################
"mirostat_eta": mirostat_eta,
"mirostat_tau": mirostat_tau,
"num_ctx": num_ctx,