refactor: Remove unused code from initialize/llm.py and initialize/utils.py
This commit is contained in:
parent
4b643a8b67
commit
9fa0a0e312
3 changed files with 0 additions and 362 deletions
|
|
@ -1,7 +0,0 @@
|
|||
def initialize_vertexai(class_object, params):
|
||||
if credentials_path := params.get("credentials"):
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
|
||||
credentials_object = service_account.Credentials.from_service_account_file(filename=credentials_path)
|
||||
params["credentials"] = credentials_object
|
||||
return class_object(**params)
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
import contextlib
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import orjson
|
||||
from langchain.agents import ZeroShotAgent
|
||||
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
|
||||
|
||||
def handle_node_type(node_type, class_object, params: Dict):
|
||||
if node_type == "ZeroShotPrompt":
|
||||
params = check_tools_in_params(params)
|
||||
prompt = ZeroShotAgent.create_prompt(**params)
|
||||
elif "MessagePromptTemplate" in node_type:
|
||||
prompt = instantiate_from_template(class_object, params)
|
||||
elif node_type == "ChatPromptTemplate":
|
||||
prompt = class_object.from_messages(**params)
|
||||
elif hasattr(class_object, "from_template") and params.get("template"):
|
||||
prompt = class_object.from_template(template=params.pop("template"))
|
||||
else:
|
||||
prompt = class_object(**params)
|
||||
return params, prompt
|
||||
|
||||
|
||||
def check_tools_in_params(params: Dict):
|
||||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return params
|
||||
|
||||
|
||||
def instantiate_from_template(class_object, params: Dict):
|
||||
from_template_params = {"template": params.pop("prompt", params.pop("template", ""))}
|
||||
|
||||
from_template_params.update(params)
|
||||
if not from_template_params.get("template"):
|
||||
raise ValueError("Prompt template is required")
|
||||
return class_object.from_template(**from_template_params)
|
||||
|
||||
|
||||
def handle_format_kwargs(prompt, params: Dict):
|
||||
format_kwargs: Dict[str, Any] = {}
|
||||
for input_variable in prompt.input_variables:
|
||||
if input_variable in params:
|
||||
format_kwargs = handle_variable(params, input_variable, format_kwargs)
|
||||
return format_kwargs
|
||||
|
||||
|
||||
def handle_partial_variables(prompt, format_kwargs: Dict):
|
||||
partial_variables = format_kwargs.copy()
|
||||
partial_variables = {key: value for key, value in partial_variables.items() if value}
|
||||
# Remove handle_keys otherwise LangChain raises an error
|
||||
partial_variables.pop("handle_keys", None)
|
||||
if partial_variables and hasattr(prompt, "partial"):
|
||||
return prompt.partial(**partial_variables)
|
||||
return prompt
|
||||
|
||||
|
||||
def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict):
|
||||
variable = params[input_variable]
|
||||
if isinstance(variable, str):
|
||||
format_kwargs[input_variable] = variable
|
||||
elif isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions"):
|
||||
format_kwargs[input_variable] = variable.get_format_instructions()
|
||||
elif is_instance_of_list_or_document(variable):
|
||||
format_kwargs = format_document(variable, input_variable, format_kwargs)
|
||||
if needs_handle_keys(variable):
|
||||
format_kwargs = add_handle_keys(input_variable, format_kwargs)
|
||||
return format_kwargs
|
||||
|
||||
|
||||
def is_instance_of_list_or_document(variable):
|
||||
return (
|
||||
isinstance(variable, List)
|
||||
and all(isinstance(item, Document) for item in variable)
|
||||
or isinstance(variable, Document)
|
||||
)
|
||||
|
||||
|
||||
def format_document(variable, input_variable: str, format_kwargs: Dict):
|
||||
variable = variable if isinstance(variable, List) else [variable]
|
||||
content = format_content(variable)
|
||||
format_kwargs[input_variable] = content
|
||||
return format_kwargs
|
||||
|
||||
|
||||
def format_content(variable):
|
||||
if len(variable) > 1:
|
||||
return "\n".join([item.page_content for item in variable if item.page_content])
|
||||
elif len(variable) == 1:
|
||||
content = variable[0].page_content
|
||||
return try_to_load_json(content)
|
||||
return ""
|
||||
|
||||
|
||||
def try_to_load_json(content):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
content = orjson.loads(content)
|
||||
if isinstance(content, list):
|
||||
content = ",".join([str(item) for item in content])
|
||||
else:
|
||||
content = orjson_dumps(content)
|
||||
return content
|
||||
|
||||
|
||||
def needs_handle_keys(variable):
|
||||
return is_instance_of_list_or_document(variable) or (
|
||||
isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions")
|
||||
)
|
||||
|
||||
|
||||
def add_handle_keys(input_variable: str, format_kwargs: Dict):
|
||||
if "handle_keys" not in format_kwargs:
|
||||
format_kwargs["handle_keys"] = []
|
||||
format_kwargs["handle_keys"].append(input_variable)
|
||||
return format_kwargs
|
||||
|
|
@ -1,237 +0,0 @@
|
|||
import os
|
||||
from typing import Any, Callable, Dict, Type
|
||||
|
||||
import orjson
|
||||
from langchain_community.vectorstores import (
|
||||
FAISS,
|
||||
Chroma,
|
||||
MongoDBAtlasVectorSearch,
|
||||
Qdrant,
|
||||
SupabaseVectorStore,
|
||||
Weaviate,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_pinecone import Pinecone
|
||||
|
||||
|
||||
def docs_in_params(params: dict) -> bool:
|
||||
"""Check if params has documents OR texts and one of them is not an empty list,
|
||||
If any of them is not an empty list, return True, else return False"""
|
||||
return ("documents" in params and params["documents"]) or ("texts" in params and params["texts"])
|
||||
|
||||
|
||||
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
|
||||
"""Initialize mongodb and return the class object"""
|
||||
|
||||
MONGODB_ATLAS_CLUSTER_URI = params.pop("mongodb_atlas_cluster_uri")
|
||||
if not MONGODB_ATLAS_CLUSTER_URI:
|
||||
raise ValueError("Mongodb atlas cluster uri must be provided in the params")
|
||||
import certifi
|
||||
from pymongo import MongoClient
|
||||
|
||||
client: MongoClient = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where())
|
||||
db_name = params.pop("db_name", None)
|
||||
collection_name = params.pop("collection_name", None)
|
||||
if not db_name or not collection_name:
|
||||
raise ValueError("db_name and collection_name must be provided in the params")
|
||||
|
||||
index_name = params.pop("index_name", None)
|
||||
if not index_name:
|
||||
raise ValueError("index_name must be provided in the params")
|
||||
|
||||
collection = client[db_name][collection_name]
|
||||
if not docs_in_params(params):
|
||||
# __init__ requires collection, embedding and index_name
|
||||
init_args = {
|
||||
"collection": collection,
|
||||
"index_name": index_name,
|
||||
"embedding": params.get("embedding"),
|
||||
}
|
||||
|
||||
return class_object(**init_args)
|
||||
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
||||
params["collection"] = collection
|
||||
params["index_name"] = index_name
|
||||
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_supabase(class_object: Type[SupabaseVectorStore], params: dict):
|
||||
"""Initialize supabase and return the class object"""
|
||||
from supabase.client import Client, create_client
|
||||
|
||||
if "supabase_url" not in params or "supabase_service_key" not in params:
|
||||
raise ValueError("Supabase url and service key must be provided in the params")
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
||||
client_kwargs = {
|
||||
"supabase_url": params.pop("supabase_url"),
|
||||
"supabase_key": params.pop("supabase_service_key"),
|
||||
}
|
||||
|
||||
supabase: Client = create_client(**client_kwargs)
|
||||
if not docs_in_params(params):
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
return class_object(client=supabase, **params)
|
||||
# If there are docs in the params, create a new index
|
||||
|
||||
return class_object.from_documents(client=supabase, **params)
|
||||
|
||||
|
||||
def initialize_weaviate(class_object: Type[Weaviate], params: dict):
|
||||
"""Initialize weaviate and return the class object"""
|
||||
if not docs_in_params(params):
|
||||
import weaviate # type: ignore
|
||||
|
||||
client_kwargs_json = params.get("client_kwargs", "{}")
|
||||
client_kwargs = orjson.loads(client_kwargs_json)
|
||||
client_params = {
|
||||
"url": params.get("weaviate_url"),
|
||||
}
|
||||
client_params.update(client_kwargs)
|
||||
weaviate_client = weaviate.Client(**client_params)
|
||||
|
||||
new_params = {
|
||||
"client": weaviate_client,
|
||||
"index_name": params.get("index_name"),
|
||||
"text_key": params.get("text_key"),
|
||||
}
|
||||
return class_object(**new_params)
|
||||
# If there are docs in the params, create a new index
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_faiss(class_object: Type[FAISS], params: dict):
|
||||
"""Initialize faiss and return the class object"""
|
||||
|
||||
if not docs_in_params(params):
|
||||
return class_object.load_local
|
||||
|
||||
save_local = params.get("save_local")
|
||||
faiss_index = class_object(**params)
|
||||
if save_local:
|
||||
faiss_index.save_local(folder_path=save_local)
|
||||
return faiss_index
|
||||
|
||||
|
||||
def initialize_pinecone(class_object: Type[Pinecone], params: dict):
|
||||
"""Initialize pinecone and return the class object"""
|
||||
|
||||
import pinecone # type: ignore
|
||||
|
||||
pinecone_api_key = params.pop("pinecone_api_key")
|
||||
pinecone_env = params.pop("pinecone_env")
|
||||
|
||||
if pinecone_api_key is None or pinecone_env is None:
|
||||
if os.getenv("PINECONE_API_KEY") is not None:
|
||||
pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
||||
if os.getenv("PINECONE_ENV") is not None:
|
||||
pinecone_env = os.getenv("PINECONE_ENV")
|
||||
|
||||
if pinecone_api_key is None or pinecone_env is None:
|
||||
raise ValueError("Pinecone API key and environment must be provided in the params")
|
||||
|
||||
# initialize pinecone
|
||||
pinecone.init(
|
||||
api_key=pinecone_api_key, # find at app.pinecone.io
|
||||
environment=pinecone_env, # next to api key in console
|
||||
)
|
||||
|
||||
# If there are no docs in the params, return an existing index
|
||||
# but first remove any texts or docs keys from the params
|
||||
if not docs_in_params(params):
|
||||
existing_index_params = {
|
||||
"embedding": params.pop("embedding"),
|
||||
}
|
||||
if "index_name" in params:
|
||||
existing_index_params["index_name"] = params.pop("index_name")
|
||||
if "namespace" in params:
|
||||
existing_index_params["namespace"] = params.pop("namespace")
|
||||
|
||||
return class_object.from_existing_index(**existing_index_params)
|
||||
# If there are docs in the params, create a new index
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_chroma(class_object: Type[Chroma], params: dict):
|
||||
"""Initialize a ChromaDB object from the params"""
|
||||
if ( # type: ignore
|
||||
"chroma_server_host" in params or "chroma_server_http_port" in params
|
||||
):
|
||||
import chromadb # type: ignore
|
||||
|
||||
settings_params = {
|
||||
key: params[key] for key, value_ in params.items() if key.startswith("chroma_server_") and value_
|
||||
}
|
||||
chroma_settings = chromadb.config.Settings(**settings_params)
|
||||
params["client_settings"] = chroma_settings
|
||||
else:
|
||||
# remove all chroma_server_ keys from params
|
||||
params = {key: value for key, value in params.items() if not key.startswith("chroma_server_")}
|
||||
|
||||
persist = params.pop("persist", False)
|
||||
if not docs_in_params(params):
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
params["embedding_function"] = params.pop("embedding")
|
||||
chromadb_instance = class_object(**params)
|
||||
else:
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
for doc in params["documents"]:
|
||||
if not isinstance(doc, Document):
|
||||
# remove any non-Document objects from the list
|
||||
params["documents"].remove(doc)
|
||||
continue
|
||||
if doc.metadata is None:
|
||||
doc.metadata = {}
|
||||
for key, value in doc.metadata.items():
|
||||
if value is None:
|
||||
doc.metadata[key] = ""
|
||||
|
||||
chromadb_instance = class_object.from_documents(**params)
|
||||
if persist:
|
||||
chromadb_instance.persist()
|
||||
return chromadb_instance
|
||||
|
||||
|
||||
def initialize_qdrant(class_object: Type[Qdrant], params: dict):
|
||||
if not docs_in_params(params):
|
||||
if "location" not in params and "api_key" not in params:
|
||||
raise ValueError("Location and API key must be provided in the params")
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client_params = {
|
||||
"location": params.pop("location"),
|
||||
"api_key": params.pop("api_key"),
|
||||
}
|
||||
lc_params = {
|
||||
"collection_name": params.pop("collection_name"),
|
||||
"embeddings": params.pop("embedding"),
|
||||
}
|
||||
client = QdrantClient(**client_params)
|
||||
|
||||
return class_object(client=client, **lc_params)
|
||||
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
vecstore_initializer: Dict[str, Callable[[Type[Any], dict], Any]] = {
|
||||
"Pinecone": initialize_pinecone,
|
||||
"Chroma": initialize_chroma,
|
||||
"Qdrant": initialize_qdrant,
|
||||
"Weaviate": initialize_weaviate,
|
||||
"FAISS": initialize_faiss,
|
||||
"SupabaseVectorStore": initialize_supabase,
|
||||
"MongoDBAtlasVectorSearch": initialize_mongodb,
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue