Fixed Ollama base url handling and Qdrant component (#2007)

This pull request includes several code refactorings to improve the
readability and maintainability of the codebase. The changes include
reordering and organizing the initialization parameters in the
QdrantComponent class, and — in OllamaModel.py — removing unused
imports, cleaning up code formatting, updating comments and docstrings
for clarity and consistency, and refactoring the logic for setting the
base URL used by the get_model method to improve maintainability and
readability. These changes aim to make the codebase more organized and
easier to understand for future development and maintenance.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-05-29 06:40:17 -07:00 committed by GitHub
commit 92a11b7ce4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 61 deletions

View file

@ -1,21 +1,11 @@
from typing import Any, Dict, List, Optional, Union
from langchain_community.chat_models.ollama import ChatOllama
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langchain_core.caches import BaseCache
from langflow.field_typing import Text
import asyncio
import json
from typing import Any, Dict, List, Optional
import httpx
from langchain_community.chat_models.ollama import ChatOllama
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Text
class ChatOllamaComponent(LCModelComponent):
@ -26,18 +16,12 @@ class ChatOllamaComponent(LCModelComponent):
field_order = [
"base_url",
"headers",
"keep_alive_flag",
"keep_alive",
"metadata",
"model",
"temperature",
"cache",
"format",
"metadata",
"mirostat",
@ -67,10 +51,7 @@ class ChatOllamaComponent(LCModelComponent):
"base_url": {
"display_name": "Base URL",
"info": "Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
},
"format": {
"display_name": "Format",
"info": "Specify the format of the output (e.g., json)",
@ -79,13 +60,10 @@ class ChatOllamaComponent(LCModelComponent):
"headers": {
"display_name": "Headers",
"advanced": True,
},
"keep_alive_flag": {
"display_name": "Unload interval",
"options": ["Keep", "Immediately","Minute", "Hour", "sec" ],
"options": ["Keep", "Immediately", "Minute", "Hour", "sec"],
"real_time_refresh": True,
"refresh_button": True,
},
@ -93,9 +71,6 @@ class ChatOllamaComponent(LCModelComponent):
"display_name": "interval",
"info": "How long the model will stay loaded into memory.",
},
"model": {
"display_name": "Model Name",
"options": [],
@ -109,14 +84,6 @@ class ChatOllamaComponent(LCModelComponent):
"value": 0.8,
"info": "Controls the creativity of model responses.",
},
"format": {
"display_name": "Format",
"field_type": "str",
"info": "Specify the format of the output (e.g., json).",
"advanced": True,
},
"metadata": {
"display_name": "Metadata",
"info": "Metadata to add to the run trace.",
@ -129,7 +96,6 @@ class ChatOllamaComponent(LCModelComponent):
"advanced": False,
"real_time_refresh": True,
"refresh_button": True,
},
"mirostat_eta": {
"display_name": "Mirostat Eta",
@ -260,10 +226,14 @@ class ChatOllamaComponent(LCModelComponent):
build_config["mirostat_tau"]["value"] = 5
if field_name == "model":
base_url = build_config.get("base_url", {}).get(
"value", "http://localhost:11434")
build_config["model"]["options"] = self.get_model(
base_url + "/api/tags")
base_url_dict = build_config.get("base_url", {})
base_url_load_from_db = base_url_dict.get("load_from_db", False)
base_url_value = base_url_dict.get("value")
if base_url_load_from_db:
base_url_value = self.variables(base_url_value)
elif not base_url_value:
base_url_value = "http://localhost:11434"
build_config["model"]["options"] = self.get_model(base_url_value + "/api/tags")
if field_name == "keep_alive_flag":
if field_value == "Keep":
@ -276,9 +246,6 @@ class ChatOllamaComponent(LCModelComponent):
build_config["keep_alive"]["advanced"] = False
return build_config
def get_model(self, url: str) -> List[str]:
try:
@ -287,8 +254,7 @@ class ChatOllamaComponent(LCModelComponent):
response.raise_for_status()
data = response.json()
model_names = [model['name']
for model in data.get("models", [])]
model_names = [model["name"] for model in data.get("models", [])]
return model_names
except Exception as e:
raise ValueError("Could not retrieve models") from e
@ -299,15 +265,13 @@ class ChatOllamaComponent(LCModelComponent):
base_url: Optional[str],
model: str,
input_value: Text,
mirostat: Optional[str],
mirostat: Optional[str] = "Disabled",
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
repeat_last_n: Optional[int] = None,
verbose: Optional[bool] = None,
keep_alive: Optional[int] = None,
keep_alive_flag: Optional[str] = None,
keep_alive_flag: Optional[str] = "Keep",
num_ctx: Optional[int] = None,
num_gpu: Optional[int] = None,
format: Optional[str] = None,
@ -326,12 +290,9 @@ class ChatOllamaComponent(LCModelComponent):
stream: bool = False,
system_message: Optional[str] = None,
) -> Text:
if not base_url:
base_url = "http://localhost:11434"
if keep_alive_flag == "Minute":
keep_alive_instance = f"{keep_alive}m"
elif keep_alive_flag == "Hour":

View file

@ -67,22 +67,19 @@ class QdrantComponent(CustomComponent):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
if documents is None:
if not documents:
from qdrant_client import QdrantClient
client = QdrantClient(
location=location,
url=host,
url=url,
port=port,
grpc_port=grpc_port,
https=https,
prefix=prefix,
timeout=timeout,
prefer_grpc=prefer_grpc,
metadata_payload_key=metadata_payload_key,
content_payload_key=content_payload_key,
api_key=api_key,
collection_name=collection_name,
host=host,
path=path,
)
@ -90,6 +87,8 @@ class QdrantComponent(CustomComponent):
client=client,
collection_name=collection_name,
embeddings=embedding,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
)
return vs
else: