♻️ (OllamaModel.py): Remove unused imports and clean up code formatting for better readability

📝 (OllamaModel.py): Update comments and docstrings for clarity and consistency
🔧 (OllamaModel.py): Refactor the logic for setting the base URL in the get_model method to improve maintainability and readability
This commit is contained in:
ogabrielluiz 2024-05-29 10:33:12 -03:00
commit 32e8da3bf4

View file

@ -1,21 +1,11 @@
from typing import Any, Dict, List, Optional, Union
from langchain_community.chat_models.ollama import ChatOllama
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langchain_core.caches import BaseCache
from langflow.field_typing import Text
import asyncio
import json
from typing import Any, Dict, List, Optional
import httpx
from langchain_community.chat_models.ollama import ChatOllama
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Text
class ChatOllamaComponent(LCModelComponent):
@ -26,18 +16,12 @@ class ChatOllamaComponent(LCModelComponent):
field_order = [
"base_url",
"headers",
"keep_alive_flag",
"keep_alive",
"metadata",
"model",
"temperature",
"cache",
"format",
"metadata",
"mirostat",
@ -67,10 +51,7 @@ class ChatOllamaComponent(LCModelComponent):
"base_url": {
"display_name": "Base URL",
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
},
"format": {
"display_name": "Format",
"info": "Specify the format of the output (e.g., json)",
@ -79,13 +60,10 @@ class ChatOllamaComponent(LCModelComponent):
"headers": {
"display_name": "Headers",
"advanced": True,
},
"keep_alive_flag": {
"display_name": "Unload interval",
"options": ["Keep", "Immediately","Minute", "Hour", "sec" ],
"options": ["Keep", "Immediately", "Minute", "Hour", "sec"],
"real_time_refresh": True,
"refresh_button": True,
},
@ -93,9 +71,6 @@ class ChatOllamaComponent(LCModelComponent):
"display_name": "interval",
"info": "How long the model will stay loaded into memory.",
},
"model": {
"display_name": "Model Name",
"options": [],
@ -109,14 +84,6 @@ class ChatOllamaComponent(LCModelComponent):
"value": 0.8,
"info": "Controls the creativity of model responses.",
},
"format": {
"display_name": "Format",
"field_type": "str",
"info": "Specify the format of the output (e.g., json).",
"advanced": True,
},
"metadata": {
"display_name": "Metadata",
"info": "Metadata to add to the run trace.",
@ -129,7 +96,6 @@ class ChatOllamaComponent(LCModelComponent):
"advanced": False,
"real_time_refresh": True,
"refresh_button": True,
},
"mirostat_eta": {
"display_name": "Mirostat Eta",
@ -260,10 +226,14 @@ class ChatOllamaComponent(LCModelComponent):
build_config["mirostat_tau"]["value"] = 5
if field_name == "model":
base_url = build_config.get("base_url", {}).get(
"value", "http://localhost:11434")
build_config["model"]["options"] = self.get_model(
base_url + "/api/tags")
base_url_dict = build_config.get("base_url", {})
base_url_load_from_db = base_url_dict.get("load_from_db", False)
base_url_value = base_url_dict.get("value")
if base_url_load_from_db:
base_url_value = self.variables(base_url_value)
elif not base_url_value:
base_url_value = "http://localhost:11434"
build_config["model"]["options"] = self.get_model(base_url_value + "/api/tags")
if field_name == "keep_alive_flag":
if field_value == "Keep":
@ -276,9 +246,6 @@ class ChatOllamaComponent(LCModelComponent):
build_config["keep_alive"]["advanced"] = False
return build_config
def get_model(self, url: str) -> List[str]:
try:
@ -287,8 +254,7 @@ class ChatOllamaComponent(LCModelComponent):
response.raise_for_status()
data = response.json()
model_names = [model['name']
for model in data.get("models", [])]
model_names = [model["name"] for model in data.get("models", [])]
return model_names
except Exception as e:
raise ValueError("Could not retrieve models") from e
@ -299,15 +265,13 @@ class ChatOllamaComponent(LCModelComponent):
base_url: Optional[str],
model: str,
input_value: Text,
mirostat: Optional[str],
mirostat: Optional[str] = "Disabled",
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
repeat_last_n: Optional[int] = None,
verbose: Optional[bool] = None,
keep_alive: Optional[int] = None,
keep_alive_flag: Optional[str] = None,
keep_alive_flag: Optional[str] = "Keep",
num_ctx: Optional[int] = None,
num_gpu: Optional[int] = None,
format: Optional[str] = None,
@ -326,12 +290,9 @@ class ChatOllamaComponent(LCModelComponent):
stream: bool = False,
system_message: Optional[str] = None,
) -> Text:
if not base_url:
base_url = "http://localhost:11434"
if keep_alive_flag == "Minute":
keep_alive_instance = f"{keep_alive}m"
elif keep_alive_flag == "Hour":