refactor(loading.py): refactor instantiate_class function to improve readability and maintainability

This commit is contained in:
Gabriel Almeida 2023-05-10 11:41:34 -03:00
commit 7c1c513106

View file

@ -13,9 +13,9 @@ from langchain.agents.load_tools import (
)
from langchain.agents.loading import load_agent_from_config
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.loading import load_chain_from_config
from langchain.base_language import BaseLanguageModel
from langchain.llms.loading import load_llm_from_config
from pydantic import ValidationError
@ -30,88 +30,19 @@ from langflow.utils import util, validate
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
    """Instantiate the object described by a node.

    Args:
        node_type: Name of the concrete class (or custom agent) to build.
        base_type: Category of the node (e.g. ``"agents"``, ``"tools"``);
            selects both the import location and the construction strategy.
        params: Keyword arguments for the object; normalised in place by
            ``convert_params_to_sets`` before use.

    Returns:
        Whatever the custom agent initializer or the type-specific
        instantiation helper returns.
    """
    params = convert_params_to_sets(params)

    # Custom agents are not imported by type; they expose their own
    # ``initialize`` entry point and short-circuit the generic path.
    custom_agent = CUSTOM_AGENTS.get(node_type)
    if custom_agent:
        return custom_agent.initialize(**params)

    class_object = import_by_type(_type=base_type, name=node_type)
    return instantiate_based_on_type(class_object, base_type, node_type, params)
def convert_params_to_sets(params):
    """Convert the ``allowed_special`` / ``disallowed_special`` params to sets.

    Each of the two keys, when present, is replaced by ``set(value)``.
    String values are left untouched, because ``set("all")`` would yield the
    individual characters ``{"a", "l"}`` rather than a meaningful set.
    ``None`` values are left untouched as well, since ``set(None)`` raises.

    Args:
        params: Keyword arguments of the object being instantiated. The dict
            is updated in place.

    Returns:
        The same ``params`` dict that was passed in.
    """
    for key in ("allowed_special", "disallowed_special"):
        value = params.get(key)
        # Missing key, explicit None, or a plain string: nothing to convert.
        if value is None or isinstance(value, str):
            continue
        params[key] = set(value)
    return params
def instantiate_based_on_type(class_object, base_type, node_type, params):
    """Build a node's object using the strategy that matches ``base_type``.

    Args:
        class_object: The imported class (or callable) for the node.
        base_type: Category of the node; picks the instantiation helper.
        node_type: Concrete node name, forwarded to helpers that special-case
            individual node types.
        params: Keyword arguments for the object.

    Returns:
        The value produced by the matching helper, or ``class_object(**params)``
        when no helper applies.
    """
    if base_type == "agents":
        return instantiate_agent(class_object, params)
    if base_type == "prompts":
        prompt = instantiate_prompt(node_type, params)
        # instantiate_prompt returns None for node types it does not
        # special-case; those prompts are built with the plain constructor
        # instead of silently becoming None.
        if prompt is None:
            return class_object(**params)
        return prompt
    if base_type == "tools":
        return instantiate_tool(node_type, class_object, params)
    if base_type == "toolkits":
        return instantiate_toolkit(node_type, class_object, params)
    if base_type == "embeddings":
        return instantiate_embedding(class_object, params)
    if base_type == "vectorstores":
        return instantiate_vectorstore(class_object, params)
    if base_type == "documentloaders":
        return instantiate_documentloader(class_object, params)
    if base_type == "textsplitters":
        return instantiate_textsplitter(class_object, params)
    if base_type == "utilities":
        return instantiate_utility(node_type, class_object, params)
    # Any other base type needs no special handling.
    return class_object(**params)
def instantiate_agent(class_object, params):
    """Build an agent node by delegating to ``load_agent_executor``.

    Args:
        class_object: The imported agent class.
        params: Keyword arguments collected for the agent.

    Returns:
        Whatever ``load_agent_executor(class_object, params)`` returns.
    """
    loaded = load_agent_executor(class_object, params)
    return loaded
def instantiate_prompt(node_type, params):
    """Build the prompt for ``node_type`` when it has a dedicated builder.

    Only ``"ZeroShotPrompt"`` is handled: it is created through
    ``ZeroShotAgent.create_prompt`` with ``tools`` defaulting to an empty
    list (the default is written into ``params``).

    Args:
        node_type: Name of the prompt node.
        params: Keyword arguments for the prompt.

    Returns:
        The created prompt, or ``None`` for every other node type.
    """
    if node_type != "ZeroShotPrompt":
        # No dedicated builder for this prompt type.
        return None
    params.setdefault("tools", [])
    return ZeroShotAgent.create_prompt(**params)
def instantiate_tool(node_type, class_object, params):
    """Build a tool node.

    ``"PythonFunction"`` nodes return the function obtained from
    ``validate.eval_function`` on ``params["code"]``. ``"JsonSpec"`` nodes
    have their ``path`` param replaced by a ``dict_`` param loaded with
    ``load_file_into_dict``. Every tool other than ``"PythonFunction"`` is
    then built with ``class_object(**params)``, so tools without special
    handling are instantiated rather than silently turned into ``None``.

    Args:
        node_type: Name of the tool node.
        class_object: The imported tool class.
        params: Keyword arguments for the tool; may be modified in place.

    Returns:
        The evaluated function for ``"PythonFunction"``, otherwise the
        constructed tool.

    Raises:
        ValueError: If a ``"PythonFunction"`` node's ``code`` is not a string.
    """
    if node_type == "PythonFunction":
        function_string = params["code"]
        if not isinstance(function_string, str):
            raise ValueError("Function should be a string")
        # NOTE(review): presumably this executes user-supplied source code;
        # confirm that callers only pass trusted input.
        return validate.eval_function(function_string)

    if node_type == "JsonSpec":
        params["dict_"] = load_file_into_dict(params.pop("path"))

    return class_object(**params)
def instantiate_toolkit(node_type, class_object, params):
    """Build a toolkit node, wrapping it in an executor when one is registered.

    Args:
        node_type: Name of the toolkit node.
        class_object: The imported toolkit class.
        params: Keyword arguments for the toolkit.

    Returns:
        The result of ``load_toolkits_executor`` when
        ``toolkits_creator.has_create_function(node_type)`` is true,
        otherwise the constructed toolkit itself.
    """
    toolkit = class_object(**params)
    if not toolkits_creator.has_create_function(node_type):
        return toolkit
    return load_toolkits_executor(node_type, toolkit, params)
def instantiate_embedding(class_object, params):
    """Build an embedding node.

    The ``model`` entry is removed from ``params`` (in place) before
    construction. If construction fails with ``ValidationError``, it is
    retried with only the params named in ``class_object.__fields__``.

    Args:
        class_object: The imported embedding class.
        params: Keyword arguments for the embedding.

    Returns:
        The constructed embedding object.
    """
    params.pop("model", None)
    try:
        embedding = class_object(**params)
    except ValidationError:
        # Retry with just the params the class declares as fields.
        accepted = {
            name: params[name]
            for name in params
            if name in class_object.__fields__
        }
        embedding = class_object(**accepted)
    return embedding
def instantiate_vectorstore(class_object, params):
    """Build a vector store node from its documents.

    Args:
        class_object: The imported vector store class; must provide
            ``from_documents``.
        params: Keyword arguments forwarded to ``from_documents``; must
            contain a non-empty ``documents`` entry.

    Returns:
        The result of ``class_object.from_documents(**params)``.

    Raises:
        ValueError: If ``documents`` is missing, ``None`` or empty, which
            happens when the upstream source failed to load.
    """
    # Truthiness covers a missing key, None and an empty collection alike;
    # len() would raise TypeError on None.
    if not params.get("documents"):
        raise ValueError(
            "The source you provided did not load correctly or was empty. "
            "This may cause an error in the vectorstore."
        )
    return class_object.from_documents(**params)
def instantiate_documentloader(class_object, params):
    """Build a document loader and return what its ``load()`` produces.

    Args:
        class_object: The imported loader class.
        params: Keyword arguments for the loader.

    Returns:
        The result of calling ``load()`` on the constructed loader.
    """
    loader = class_object(**params)
    return loader.load()
def instantiate_textsplitter(class_object, params):
    """Build a text splitter and apply it to the node's documents.

    ``documents`` is removed from ``params`` (in place); the remaining
    entries configure the splitter.

    Args:
        class_object: The imported text splitter class.
        params: Keyword arguments for the splitter plus a ``documents`` entry.

    Returns:
        The result of ``split_documents`` on the removed documents.

    Raises:
        KeyError: If ``params`` has no ``documents`` entry.
    """
    source_documents = params.pop("documents")
    return class_object(**params).split_documents(source_documents)
def instantiate_utility(node_type, class_object, params):
    """Build a utility node.

    ``"SQLDatabase"`` is created with ``class_object.from_uri`` using the
    ``uri`` entry, which is removed from ``params``; no other entries are
    forwarded. Every other utility is built with ``class_object(**params)``.

    Args:
        node_type: Name of the utility node.
        class_object: The imported utility class.
        params: Keyword arguments for the utility.

    Returns:
        The constructed utility object.
    """
    if node_type != "SQLDatabase":
        return class_object(**params)
    return class_object.from_uri(params.pop("uri"))
def load_flow_from_json(path: str, build=True):
# This is done to avoid circular imports
from langflow.graph import Graph