refactor: Remove unused code from initialize/llm.py and initialize/utils.py

This commit is contained in:
ogabrielluiz 2024-06-17 15:04:40 -03:00
commit 9fa0a0e312
3 changed files with 0 additions and 362 deletions

View file

@ -1,7 +0,0 @@
def initialize_vertexai(class_object, params):
if credentials_path := params.get("credentials"):
from google.oauth2 import service_account # type: ignore
credentials_object = service_account.Credentials.from_service_account_file(filename=credentials_path)
params["credentials"] = credentials_object
return class_object(**params)

View file

@ -1,118 +0,0 @@
import contextlib
import json
from typing import Any, Dict, List
import orjson
from langchain.agents import ZeroShotAgent
from langflow.services.database.models.base import orjson_dumps
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
def handle_node_type(node_type, class_object, params: Dict):
if node_type == "ZeroShotPrompt":
params = check_tools_in_params(params)
prompt = ZeroShotAgent.create_prompt(**params)
elif "MessagePromptTemplate" in node_type:
prompt = instantiate_from_template(class_object, params)
elif node_type == "ChatPromptTemplate":
prompt = class_object.from_messages(**params)
elif hasattr(class_object, "from_template") and params.get("template"):
prompt = class_object.from_template(template=params.pop("template"))
else:
prompt = class_object(**params)
return params, prompt
def check_tools_in_params(params: Dict):
if "tools" not in params:
params["tools"] = []
return params
def instantiate_from_template(class_object, params: Dict):
from_template_params = {"template": params.pop("prompt", params.pop("template", ""))}
from_template_params.update(params)
if not from_template_params.get("template"):
raise ValueError("Prompt template is required")
return class_object.from_template(**from_template_params)
def handle_format_kwargs(prompt, params: Dict):
format_kwargs: Dict[str, Any] = {}
for input_variable in prompt.input_variables:
if input_variable in params:
format_kwargs = handle_variable(params, input_variable, format_kwargs)
return format_kwargs
def handle_partial_variables(prompt, format_kwargs: Dict):
partial_variables = format_kwargs.copy()
partial_variables = {key: value for key, value in partial_variables.items() if value}
# Remove handle_keys otherwise LangChain raises an error
partial_variables.pop("handle_keys", None)
if partial_variables and hasattr(prompt, "partial"):
return prompt.partial(**partial_variables)
return prompt
def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict):
variable = params[input_variable]
if isinstance(variable, str):
format_kwargs[input_variable] = variable
elif isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions"):
format_kwargs[input_variable] = variable.get_format_instructions()
elif is_instance_of_list_or_document(variable):
format_kwargs = format_document(variable, input_variable, format_kwargs)
if needs_handle_keys(variable):
format_kwargs = add_handle_keys(input_variable, format_kwargs)
return format_kwargs
def is_instance_of_list_or_document(variable):
return (
isinstance(variable, List)
and all(isinstance(item, Document) for item in variable)
or isinstance(variable, Document)
)
def format_document(variable, input_variable: str, format_kwargs: Dict):
variable = variable if isinstance(variable, List) else [variable]
content = format_content(variable)
format_kwargs[input_variable] = content
return format_kwargs
def format_content(variable):
if len(variable) > 1:
return "\n".join([item.page_content for item in variable if item.page_content])
elif len(variable) == 1:
content = variable[0].page_content
return try_to_load_json(content)
return ""
def try_to_load_json(content):
with contextlib.suppress(json.JSONDecodeError):
content = orjson.loads(content)
if isinstance(content, list):
content = ",".join([str(item) for item in content])
else:
content = orjson_dumps(content)
return content
def needs_handle_keys(variable):
return is_instance_of_list_or_document(variable) or (
isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions")
)
def add_handle_keys(input_variable: str, format_kwargs: Dict):
if "handle_keys" not in format_kwargs:
format_kwargs["handle_keys"] = []
format_kwargs["handle_keys"].append(input_variable)
return format_kwargs

View file

@ -1,237 +0,0 @@
import os
from typing import Any, Callable, Dict, Type
import orjson
from langchain_community.vectorstores import (
FAISS,
Chroma,
MongoDBAtlasVectorSearch,
Qdrant,
SupabaseVectorStore,
Weaviate,
)
from langchain_core.documents import Document
from langchain_pinecone import Pinecone
def docs_in_params(params: dict) -> bool:
"""Check if params has documents OR texts and one of them is not an empty list,
If any of them is not an empty list, return True, else return False"""
return ("documents" in params and params["documents"]) or ("texts" in params and params["texts"])
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
"""Initialize mongodb and return the class object"""
MONGODB_ATLAS_CLUSTER_URI = params.pop("mongodb_atlas_cluster_uri")
if not MONGODB_ATLAS_CLUSTER_URI:
raise ValueError("Mongodb atlas cluster uri must be provided in the params")
import certifi
from pymongo import MongoClient
client: MongoClient = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where())
db_name = params.pop("db_name", None)
collection_name = params.pop("collection_name", None)
if not db_name or not collection_name:
raise ValueError("db_name and collection_name must be provided in the params")
index_name = params.pop("index_name", None)
if not index_name:
raise ValueError("index_name must be provided in the params")
collection = client[db_name][collection_name]
if not docs_in_params(params):
# __init__ requires collection, embedding and index_name
init_args = {
"collection": collection,
"index_name": index_name,
"embedding": params.get("embedding"),
}
return class_object(**init_args)
if "texts" in params:
params["documents"] = params.pop("texts")
params["collection"] = collection
params["index_name"] = index_name
return class_object.from_documents(**params)
def initialize_supabase(class_object: Type[SupabaseVectorStore], params: dict):
"""Initialize supabase and return the class object"""
from supabase.client import Client, create_client
if "supabase_url" not in params or "supabase_service_key" not in params:
raise ValueError("Supabase url and service key must be provided in the params")
if "texts" in params:
params["documents"] = params.pop("texts")
client_kwargs = {
"supabase_url": params.pop("supabase_url"),
"supabase_key": params.pop("supabase_service_key"),
}
supabase: Client = create_client(**client_kwargs)
if not docs_in_params(params):
params.pop("documents", None)
params.pop("texts", None)
return class_object(client=supabase, **params)
# If there are docs in the params, create a new index
return class_object.from_documents(client=supabase, **params)
def initialize_weaviate(class_object: Type[Weaviate], params: dict):
"""Initialize weaviate and return the class object"""
if not docs_in_params(params):
import weaviate # type: ignore
client_kwargs_json = params.get("client_kwargs", "{}")
client_kwargs = orjson.loads(client_kwargs_json)
client_params = {
"url": params.get("weaviate_url"),
}
client_params.update(client_kwargs)
weaviate_client = weaviate.Client(**client_params)
new_params = {
"client": weaviate_client,
"index_name": params.get("index_name"),
"text_key": params.get("text_key"),
}
return class_object(**new_params)
# If there are docs in the params, create a new index
if "texts" in params:
params["documents"] = params.pop("texts")
return class_object.from_documents(**params)
def initialize_faiss(class_object: Type[FAISS], params: dict):
"""Initialize faiss and return the class object"""
if not docs_in_params(params):
return class_object.load_local
save_local = params.get("save_local")
faiss_index = class_object(**params)
if save_local:
faiss_index.save_local(folder_path=save_local)
return faiss_index
def initialize_pinecone(class_object: Type[Pinecone], params: dict):
"""Initialize pinecone and return the class object"""
import pinecone # type: ignore
pinecone_api_key = params.pop("pinecone_api_key")
pinecone_env = params.pop("pinecone_env")
if pinecone_api_key is None or pinecone_env is None:
if os.getenv("PINECONE_API_KEY") is not None:
pinecone_api_key = os.getenv("PINECONE_API_KEY")
if os.getenv("PINECONE_ENV") is not None:
pinecone_env = os.getenv("PINECONE_ENV")
if pinecone_api_key is None or pinecone_env is None:
raise ValueError("Pinecone API key and environment must be provided in the params")
# initialize pinecone
pinecone.init(
api_key=pinecone_api_key, # find at app.pinecone.io
environment=pinecone_env, # next to api key in console
)
# If there are no docs in the params, return an existing index
# but first remove any texts or docs keys from the params
if not docs_in_params(params):
existing_index_params = {
"embedding": params.pop("embedding"),
}
if "index_name" in params:
existing_index_params["index_name"] = params.pop("index_name")
if "namespace" in params:
existing_index_params["namespace"] = params.pop("namespace")
return class_object.from_existing_index(**existing_index_params)
# If there are docs in the params, create a new index
if "texts" in params:
params["documents"] = params.pop("texts")
return class_object.from_documents(**params)
def initialize_chroma(class_object: Type[Chroma], params: dict):
"""Initialize a ChromaDB object from the params"""
if ( # type: ignore
"chroma_server_host" in params or "chroma_server_http_port" in params
):
import chromadb # type: ignore
settings_params = {
key: params[key] for key, value_ in params.items() if key.startswith("chroma_server_") and value_
}
chroma_settings = chromadb.config.Settings(**settings_params)
params["client_settings"] = chroma_settings
else:
# remove all chroma_server_ keys from params
params = {key: value for key, value in params.items() if not key.startswith("chroma_server_")}
persist = params.pop("persist", False)
if not docs_in_params(params):
params.pop("documents", None)
params.pop("texts", None)
params["embedding_function"] = params.pop("embedding")
chromadb_instance = class_object(**params)
else:
if "texts" in params:
params["documents"] = params.pop("texts")
for doc in params["documents"]:
if not isinstance(doc, Document):
# remove any non-Document objects from the list
params["documents"].remove(doc)
continue
if doc.metadata is None:
doc.metadata = {}
for key, value in doc.metadata.items():
if value is None:
doc.metadata[key] = ""
chromadb_instance = class_object.from_documents(**params)
if persist:
chromadb_instance.persist()
return chromadb_instance
def initialize_qdrant(class_object: Type[Qdrant], params: dict):
if not docs_in_params(params):
if "location" not in params and "api_key" not in params:
raise ValueError("Location and API key must be provided in the params")
from qdrant_client import QdrantClient
client_params = {
"location": params.pop("location"),
"api_key": params.pop("api_key"),
}
lc_params = {
"collection_name": params.pop("collection_name"),
"embeddings": params.pop("embedding"),
}
client = QdrantClient(**client_params)
return class_object(client=client, **lc_params)
return class_object.from_documents(**params)
vecstore_initializer: Dict[str, Callable[[Type[Any], dict], Any]] = {
"Pinecone": initialize_pinecone,
"Chroma": initialize_chroma,
"Qdrant": initialize_qdrant,
"Weaviate": initialize_weaviate,
"FAISS": initialize_faiss,
"SupabaseVectorStore": initialize_supabase,
"MongoDBAtlasVectorSearch": initialize_mongodb,
}