refactor: Update OllamaModel.py to include new features and improve build configuration
This commit is contained in:
parent a33a56bf26
commit bbd3ab4960
1 changed file with 67 additions and 11 deletions
OllamaModel.py
@@ -1,9 +1,13 @@
-from langchain_community.chat_models import ChatOllama
-from langchain_core.language_models.chat_models import BaseChatModel
+from typing import Any
+
+import httpx
+from langchain_community.chat_models import ChatOllama
+
 from langflow.base.constants import STREAM_INFO_TEXT
 from langflow.base.models.model import LCModelComponent
-from langflow.field_typing import LanguageModel, Text
-from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, Output, StrInput
+from langflow.field_typing import LanguageModel
+from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, MessageInput, Output, StrInput
+from langflow.schema.message import Message
 
 
 class ChatOllamaComponent(LCModelComponent):
@@ -11,24 +15,77 @@ class ChatOllamaComponent(LCModelComponent):
     description = "Generate text using Ollama Local LLMs."
     icon = "Ollama"
 
+    def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
+        if field_name == "mirostat":
+            if field_value == "Disabled":
+                build_config["mirostat_eta"]["advanced"] = True
+                build_config["mirostat_tau"]["advanced"] = True
+                build_config["mirostat_eta"]["value"] = None
+                build_config["mirostat_tau"]["value"] = None
+
+            else:
+                build_config["mirostat_eta"]["advanced"] = False
+                build_config["mirostat_tau"]["advanced"] = False
+
+                if field_value == "Mirostat 2.0":
+                    build_config["mirostat_eta"]["value"] = 0.2
+                    build_config["mirostat_tau"]["value"] = 10
+                else:
+                    build_config["mirostat_eta"]["value"] = 0.1
+                    build_config["mirostat_tau"]["value"] = 5
+
+        if field_name == "model":
+            base_url_dict = build_config.get("base_url", {})
+            base_url_load_from_db = base_url_dict.get("load_from_db", False)
+            base_url_value = base_url_dict.get("value")
+            if base_url_load_from_db:
+                base_url_value = self.variables(base_url_value)
+            elif not base_url_value:
+                base_url_value = "http://localhost:11434"
+            build_config["model"]["options"] = self.get_model(base_url_value + "/api/tags")
+
+        if field_name == "keep_alive_flag":
+            if field_value == "Keep":
+                build_config["keep_alive"]["value"] = "-1"
+                build_config["keep_alive"]["advanced"] = True
+            elif field_value == "Immediately":
+                build_config["keep_alive"]["value"] = "0"
+                build_config["keep_alive"]["advanced"] = True
+            else:
+                build_config["keep_alive"]["advanced"] = False
+
+        return build_config
+
+    def get_model(self, url: str) -> list[str]:
+        try:
+            with httpx.Client() as client:
+                response = client.get(url)
+                response.raise_for_status()
+                data = response.json()
+
+                model_names = [model["name"] for model in data.get("models", [])]
+                return model_names
+        except Exception as e:
+            raise ValueError("Could not retrieve models. Please, make sure Ollama is running.") from e
+
     inputs = [
         StrInput(
            name="base_url",
            display_name="Base URL",
            info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
            advanced=True,
            value="http://localhost:11434",
        ),
-        StrInput(
+        DropdownInput(
            name="model",
            display_name="Model Name",
            value="llama2",
            info="Refer to https://ollama.ai/library for more models.",
+           refresh_button=True,
        ),
         FloatInput(
            name="temperature",
            display_name="Temperature",
-           value=0.8,
+           value=0.2,
            info="Controls the creativity of model responses.",
        ),
         StrInput(
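Note (not part of the diff): a minimal sketch of how the new update_build_config hook behaves. It assumes ChatOllamaComponent can be instantiated without arguments and that build_config entries are plain dicts shaped like the ones the method mutates above.

# Hypothetical usage sketch, not in the commit.
component = ChatOllamaComponent()  # assumed to be directly constructible

build_config = {
    "mirostat_eta": {"advanced": False, "value": 0.1},
    "mirostat_tau": {"advanced": False, "value": 5},
    "keep_alive": {"advanced": False, "value": None},
}

# Disabling mirostat hides both tuning fields and clears their values.
component.update_build_config(build_config, "Disabled", field_name="mirostat")
assert build_config["mirostat_eta"] == {"advanced": True, "value": None}

# Selecting "Mirostat 2.0" re-exposes them with the stronger defaults.
component.update_build_config(build_config, "Mirostat 2.0", field_name="mirostat")
assert build_config["mirostat_tau"] == {"advanced": False, "value": 10}

# "Keep" pins the model in memory by setting keep_alive to -1.
component.update_build_config(build_config, "Keep", field_name="keep_alive_flag")
assert build_config["keep_alive"] == {"advanced": True, "value": "-1"}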
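get_model leans on Ollama's documented /api/tags endpoint, which returns a JSON object with a "models" array whose entries each carry a "name". The same request, standalone; the URL is the component's default base URL and the listed names are illustrative:

import httpx

# Mirrors what get_model does internally.
with httpx.Client() as client:
    response = client.get("http://localhost:11434/api/tags")
    response.raise_for_status()
    data = response.json()

# data looks like: {"models": [{"name": "llama2:latest", ...}, ...]}
model_names = [model["name"] for model in data.get("models", [])]
print(model_names)  # e.g. ["llama2:latest", "mistral:latest"]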
@@ -146,10 +203,9 @@ class ChatOllamaComponent(LCModelComponent):
            info="Template to use for generating text.",
            advanced=True,
        ),
-        StrInput(
+        MessageInput(
            name="input_value",
            display_name="Input",
-           input_types=["Text", "Data", "Prompt"],
        ),
         BoolInput(
            name="stream",
@@ -168,7 +224,7 @@ class ChatOllamaComponent(LCModelComponent):
         Output(display_name="Language Model", name="model_output", method="build_model"),
     ]
 
-    def text_response(self) -> Text:
+    def text_response(self) -> Message:
         input_value = self.input_value
         stream = self.stream
         system_message = self.system_message
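text_response now returns a langflow Message instead of raw Text, so downstream components receive a structured object rather than a bare string. A hedged sketch, assuming Message accepts a text keyword argument as it does elsewhere in langflow:

from langflow.schema.message import Message

# Hypothetical illustration: the shape callers of text_response now get back.
result = Message(text="model output goes here")
print(result.text)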
@@ -177,7 +233,7 @@ class ChatOllamaComponent(LCModelComponent):
         self.status = result
         return result
 
-    def build_model(self) -> LanguageModel | BaseChatModel:
+    def build_model(self) -> LanguageModel:
         # Mapping mirostat settings to their corresponding values
         mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
 
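The mirostat_options mapping converts the dropdown labels into the integer mode Ollama expects (0 = off, 1 = Mirostat, 2 = Mirostat 2.0). A sketch of how such values would feed ChatOllama; the keyword names are real ChatOllama fields in langchain_community, while the concrete values are illustrative and reuse the defaults this commit sets for Mirostat 2.0:

from langchain_community.chat_models import ChatOllama

mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
mirostat_value = mirostat_options.get("Mirostat 2.0", 0)  # 0 disables mirostat

llm = ChatOllama(
    base_url="http://localhost:11434",
    model="llama2",
    mirostat=mirostat_value,
    mirostat_eta=0.2,
    mirostat_tau=10.0,
    temperature=0.2,
)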