Fixed Ollama base url handling and Qdrant component (#2007)
This pull request includes several code refactorings to improve the readability and maintainability of the codebase. The changes include reordering and organizing the initialization parameters in the QdrantComponent class, removing unused imports and cleaning up code formatting in the OllamaModel.py file, updating comments and docstrings for clarity and consistency in the OllamaModel.py file, and refactoring the logic for setting the base URL in the get_model method of the OllamaModel.py file to improve maintainability and readability. These changes aim to make the codebase more organized and easier to understand for future development and maintenance.
This commit is contained in:
commit
92a11b7ce4
2 changed files with 21 additions and 61 deletions
|
|
@ -1,21 +1,11 @@
|
|||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
|
||||
from langchain_community.chat_models.ollama import ChatOllama
|
||||
from langflow.base.constants import STREAM_INFO_TEXT
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langchain_core.caches import BaseCache
|
||||
|
||||
from langflow.field_typing import Text
|
||||
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from langchain_community.chat_models.ollama import ChatOllama
|
||||
|
||||
from langflow.base.constants import STREAM_INFO_TEXT
|
||||
from langflow.base.models.model import LCModelComponent
|
||||
from langflow.field_typing import Text
|
||||
|
||||
|
||||
class ChatOllamaComponent(LCModelComponent):
|
||||
|
|
@ -26,18 +16,12 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
field_order = [
|
||||
"base_url",
|
||||
"headers",
|
||||
|
||||
"keep_alive_flag",
|
||||
"keep_alive",
|
||||
|
||||
"metadata",
|
||||
"model",
|
||||
|
||||
|
||||
"temperature",
|
||||
"cache",
|
||||
|
||||
|
||||
"format",
|
||||
"metadata",
|
||||
"mirostat",
|
||||
|
|
@ -67,10 +51,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"base_url": {
|
||||
"display_name": "Base URL",
|
||||
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
|
||||
|
||||
},
|
||||
|
||||
|
||||
"format": {
|
||||
"display_name": "Format",
|
||||
"info": "Specify the format of the output (e.g., json)",
|
||||
|
|
@ -79,13 +60,10 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"headers": {
|
||||
"display_name": "Headers",
|
||||
"advanced": True,
|
||||
|
||||
|
||||
},
|
||||
|
||||
"keep_alive_flag": {
|
||||
"display_name": "Unload interval",
|
||||
"options": ["Keep", "Immediately","Minute", "Hour", "sec" ],
|
||||
"options": ["Keep", "Immediately", "Minute", "Hour", "sec"],
|
||||
"real_time_refresh": True,
|
||||
"refresh_button": True,
|
||||
},
|
||||
|
|
@ -93,9 +71,6 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"display_name": "interval",
|
||||
"info": "How long the model will stay loaded into memory.",
|
||||
},
|
||||
|
||||
|
||||
|
||||
"model": {
|
||||
"display_name": "Model Name",
|
||||
"options": [],
|
||||
|
|
@ -109,14 +84,6 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"value": 0.8,
|
||||
"info": "Controls the creativity of model responses.",
|
||||
},
|
||||
|
||||
|
||||
"format": {
|
||||
"display_name": "Format",
|
||||
"field_type": "str",
|
||||
"info": "Specify the format of the output (e.g., json).",
|
||||
"advanced": True,
|
||||
},
|
||||
"metadata": {
|
||||
"display_name": "Metadata",
|
||||
"info": "Metadata to add to the run trace.",
|
||||
|
|
@ -129,7 +96,6 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
"advanced": False,
|
||||
"real_time_refresh": True,
|
||||
"refresh_button": True,
|
||||
|
||||
},
|
||||
"mirostat_eta": {
|
||||
"display_name": "Mirostat Eta",
|
||||
|
|
@ -260,10 +226,14 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
build_config["mirostat_tau"]["value"] = 5
|
||||
|
||||
if field_name == "model":
|
||||
base_url = build_config.get("base_url", {}).get(
|
||||
"value", "http://localhost:11434")
|
||||
build_config["model"]["options"] = self.get_model(
|
||||
base_url + "/api/tags")
|
||||
base_url_dict = build_config.get("base_url", {})
|
||||
base_url_load_from_db = base_url_dict.get("load_from_db", False)
|
||||
base_url_value = base_url_dict.get("value")
|
||||
if base_url_load_from_db:
|
||||
base_url_value = self.variables(base_url_value)
|
||||
elif not base_url_value:
|
||||
base_url_value = "http://localhost:11434"
|
||||
build_config["model"]["options"] = self.get_model(base_url_value + "/api/tags")
|
||||
|
||||
if field_name == "keep_alive_flag":
|
||||
if field_value == "Keep":
|
||||
|
|
@ -276,9 +246,6 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
build_config["keep_alive"]["advanced"] = False
|
||||
|
||||
return build_config
|
||||
|
||||
|
||||
|
||||
|
||||
def get_model(self, url: str) -> List[str]:
|
||||
try:
|
||||
|
|
@ -287,8 +254,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
model_names = [model['name']
|
||||
for model in data.get("models", [])]
|
||||
model_names = [model["name"] for model in data.get("models", [])]
|
||||
return model_names
|
||||
except Exception as e:
|
||||
raise ValueError("Could not retrieve models") from e
|
||||
|
|
@ -299,15 +265,13 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
base_url: Optional[str],
|
||||
model: str,
|
||||
input_value: Text,
|
||||
|
||||
mirostat: Optional[str],
|
||||
mirostat: Optional[str] = "Disabled",
|
||||
mirostat_eta: Optional[float] = None,
|
||||
mirostat_tau: Optional[float] = None,
|
||||
|
||||
repeat_last_n: Optional[int] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
keep_alive: Optional[int] = None,
|
||||
keep_alive_flag: Optional[str] = None,
|
||||
keep_alive_flag: Optional[str] = "Keep",
|
||||
num_ctx: Optional[int] = None,
|
||||
num_gpu: Optional[int] = None,
|
||||
format: Optional[str] = None,
|
||||
|
|
@ -326,12 +290,9 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
stream: bool = False,
|
||||
system_message: Optional[str] = None,
|
||||
) -> Text:
|
||||
|
||||
if not base_url:
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
|
||||
|
||||
if keep_alive_flag == "Minute":
|
||||
keep_alive_instance = f"{keep_alive}m"
|
||||
elif keep_alive_flag == "Hour":
|
||||
|
|
|
|||
|
|
@ -67,22 +67,19 @@ class QdrantComponent(CustomComponent):
|
|||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
documents.append(_input)
|
||||
if documents is None:
|
||||
if not documents:
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client = QdrantClient(
|
||||
location=location,
|
||||
url=host,
|
||||
url=url,
|
||||
port=port,
|
||||
grpc_port=grpc_port,
|
||||
https=https,
|
||||
prefix=prefix,
|
||||
timeout=timeout,
|
||||
prefer_grpc=prefer_grpc,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
content_payload_key=content_payload_key,
|
||||
api_key=api_key,
|
||||
collection_name=collection_name,
|
||||
host=host,
|
||||
path=path,
|
||||
)
|
||||
|
|
@ -90,6 +87,8 @@ class QdrantComponent(CustomComponent):
|
|||
client=client,
|
||||
collection_name=collection_name,
|
||||
embeddings=embedding,
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
)
|
||||
return vs
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue