♻️ (OllamaModel.py): Remove unused imports and clean up code formatting for better readability
📝 (OllamaModel.py): Update comments and docstrings for clarity and consistency
🔧 (OllamaModel.py): Refactor the logic for resolving the base URL passed to get_model to improve maintainability and readability
parent 642acf8172
commit 32e8da3bf4
1 changed file with 17 additions and 56 deletions
OllamaModel.py

@@ -1,21 +1,11 @@
-from typing import Any, Dict, List, Optional, Union
-from langchain_community.chat_models.ollama import ChatOllama
-from langflow.base.constants import STREAM_INFO_TEXT
-from langflow.base.models.model import LCModelComponent
-from langchain_core.caches import BaseCache
-from langflow.field_typing import Text
+import asyncio
+import json
+from typing import Any, Dict, List, Optional
+
+import httpx
+from langchain_community.chat_models.ollama import ChatOllama
+
+from langflow.base.constants import STREAM_INFO_TEXT
+from langflow.base.models.model import LCModelComponent
+from langflow.field_typing import Text
 
 class ChatOllamaComponent(LCModelComponent):
@@ -26,18 +16,12 @@ class ChatOllamaComponent(LCModelComponent):
     field_order = [
         "base_url",
         "headers",
         "keep_alive_flag",
         "keep_alive",
         "metadata",
         "model",
         "temperature",
         "cache",
         "format",
         "metadata",
         "mirostat",
@@ -67,10 +51,7 @@
             "base_url": {
                 "display_name": "Base URL",
                 "info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
             },
             "format": {
                 "display_name": "Format",
                 "info": "Specify the format of the output (e.g., json)",
@@ -79,13 +60,10 @@
             "headers": {
                 "display_name": "Headers",
                 "advanced": True,
             },
             "keep_alive_flag": {
                 "display_name": "Unload interval",
-                "options": ["Keep", "Immediately","Minute", "Hour", "sec" ],
+                "options": ["Keep", "Immediately", "Minute", "Hour", "sec"],
                 "real_time_refresh": True,
                 "refresh_button": True,
             },
@@ -93,9 +71,6 @@
                 "display_name": "interval",
                 "info": "How long the model will stay loaded into memory.",
             },
             "model": {
                 "display_name": "Model Name",
                 "options": [],
@@ -109,14 +84,6 @@
                 "value": 0.8,
                 "info": "Controls the creativity of model responses.",
             },
-            "format": {
-                "display_name": "Format",
-                "field_type": "str",
-                "info": "Specify the format of the output (e.g., json).",
-                "advanced": True,
-            },
             "metadata": {
                 "display_name": "Metadata",
                 "info": "Metadata to add to the run trace.",
@@ -129,7 +96,6 @@
                 "advanced": False,
                 "real_time_refresh": True,
                 "refresh_button": True,
             },
             "mirostat_eta": {
                 "display_name": "Mirostat Eta",
@@ -260,10 +226,14 @@
                 build_config["mirostat_tau"]["value"] = 5
 
         if field_name == "model":
-            base_url = build_config.get("base_url", {}).get(
-                "value", "http://localhost:11434")
-            build_config["model"]["options"] = self.get_model(
-                base_url + "/api/tags")
+            base_url_dict = build_config.get("base_url", {})
+            base_url_load_from_db = base_url_dict.get("load_from_db", False)
+            base_url_value = base_url_dict.get("value")
+            if base_url_load_from_db:
+                base_url_value = self.variables(base_url_value)
+            elif not base_url_value:
+                base_url_value = "http://localhost:11434"
+            build_config["model"]["options"] = self.get_model(base_url_value + "/api/tags")
 
         if field_name == "keep_alive_flag":
             if field_value == "Keep":
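The refactored branch resolves a base URL stored as a saved variable (load_from_db) through the component's self.variables helper, otherwise uses the literal value, and only falls back to the local default when no value is set at all. A minimal standalone sketch of that resolution order; resolve_base_url, DEFAULT_OLLAMA_URL and resolve_variable are illustrative names, not part of the diff:

from typing import Callable, Dict, Optional

DEFAULT_OLLAMA_URL = "http://localhost:11434"

def resolve_base_url(base_url_field: Dict, resolve_variable: Callable[[str], Optional[str]]) -> Optional[str]:
    """Mirror of the refactored logic: variable lookup first, then literal value, then default."""
    value = base_url_field.get("value")
    if base_url_field.get("load_from_db", False):
        # The stored value is the *name* of a saved variable; resolve it to the real URL.
        value = resolve_variable(value)
    elif not value:
        value = DEFAULT_OLLAMA_URL
    return value

# Example: a field whose value references a saved variable named "OLLAMA_URL".
field = {"value": "OLLAMA_URL", "load_from_db": True}
print(resolve_base_url(field, lambda name: {"OLLAMA_URL": "http://10.0.0.5:11434"}.get(name)))

Keeping this lookup in one place means get_model always receives a fully resolved URL instead of a variable name.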
@@ -276,9 +246,6 @@
                 build_config["keep_alive"]["advanced"] = False
 
         return build_config
 
     def get_model(self, url: str) -> List[str]:
         try:
@@ -287,8 +254,7 @@
                 response.raise_for_status()
                 data = response.json()
-                model_names = [model['name']
-                               for model in data.get("models", [])]
+                model_names = [model["name"] for model in data.get("models", [])]
                 return model_names
         except Exception as e:
             raise ValueError("Could not retrieve models") from e
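For orientation: these lines sit inside get_model, which queries Ollama's /api/tags endpoint and returns the available model names; the HTTP call itself is just above the hunk and not shown. A self-contained sketch of the whole routine, assuming httpx is the client in use (get_ollama_models is an illustrative name):

from typing import List

import httpx

def get_ollama_models(url: str) -> List[str]:
    """Fetch the list of locally available models from an Ollama server.

    `url` is expected to be the tags endpoint, e.g. "http://localhost:11434/api/tags".
    """
    try:
        with httpx.Client() as client:
            response = client.get(url)
            response.raise_for_status()
            data = response.json()
            # /api/tags returns {"models": [{"name": "...", ...}, ...]}
            return [model["name"] for model in data.get("models", [])]
    except Exception as e:
        raise ValueError("Could not retrieve models") from e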
@@ -299,15 +265,13 @@
         base_url: Optional[str],
         model: str,
         input_value: Text,
-        mirostat: Optional[str],
+        mirostat: Optional[str] = "Disabled",
         mirostat_eta: Optional[float] = None,
         mirostat_tau: Optional[float] = None,
         repeat_last_n: Optional[int] = None,
         verbose: Optional[bool] = None,
         keep_alive: Optional[int] = None,
-        keep_alive_flag: Optional[str] = None,
+        keep_alive_flag: Optional[str] = "Keep",
         num_ctx: Optional[int] = None,
         num_gpu: Optional[int] = None,
         format: Optional[str] = None,
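mirostat arrives here as a dropdown string defaulting to "Disabled", while ChatOllama ultimately expects mirostat as an integer sampling mode (0 = off, 1 = Mirostat, 2 = Mirostat 2.0). The translation happens outside this hunk; a plausible sketch, with the non-"Disabled" option strings assumed rather than taken from the diff:

from typing import Optional

def mirostat_mode(mirostat: Optional[str]) -> int:
    """Map the dropdown string to the integer mode Ollama/ChatOllama expect.

    0 disables Mirostat sampling, 1 selects Mirostat, 2 selects Mirostat 2.0.
    Only "Disabled" is visible in the diff; the other strings are assumptions.
    """
    if mirostat == "Mirostat":
        return 1
    if mirostat == "Mirostat 2.0":
        return 2
    return 0  # "Disabled" or unset

assert mirostat_mode("Disabled") == 0
assert mirostat_mode("Mirostat 2.0") == 2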
@@ -326,12 +290,9 @@
         stream: bool = False,
         system_message: Optional[str] = None,
     ) -> Text:
         if not base_url:
             base_url = "http://localhost:11434"
 
         if keep_alive_flag == "Minute":
             keep_alive_instance = f"{keep_alive}m"
         elif keep_alive_flag == "Hour":
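The keep_alive_flag/keep_alive pair is folded into the single keep_alive value Ollama understands: a duration string such as "10m" or "1h", 0 to unload immediately, or a negative value to keep the model loaded. Only the "Minute" and "Hour" branches are visible above; a sketch of the full mapping implied by the "Keep"/"Immediately"/"Minute"/"Hour"/"sec" options (the -1 and 0 values are assumptions based on Ollama's keep_alive semantics, not lines from this diff):

from typing import Optional, Union

def keep_alive_value(flag: Optional[str], amount: Optional[int]) -> Union[int, str]:
    """Translate the UI's unload-interval options into an Ollama keep_alive value."""
    if flag == "Keep":
        return -1  # negative duration: keep the model loaded indefinitely
    if flag == "Immediately":
        return 0   # unload as soon as the request finishes
    if flag == "Minute":
        return f"{amount}m"
    if flag == "Hour":
        return f"{amount}h"
    if flag == "sec":
        return f"{amount}s"
    raise ValueError(f"Unknown keep_alive_flag: {flag}")

assert keep_alive_value("Minute", 10) == "10m"
assert keep_alive_value("Keep", None) == -1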