refactor(loading.py): refactor instantiate_class function to improve readability and maintainability
This commit is contained in:
parent
31062f2a8c
commit
7c1c513106
1 changed files with 104 additions and 79 deletions
|
|
@ -13,9 +13,9 @@ from langchain.agents.load_tools import (
|
|||
)
|
||||
from langchain.agents.loading import load_agent_from_config
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.loading import load_chain_from_config
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.llms.loading import load_llm_from_config
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -30,88 +30,19 @@ from langflow.utils import util, validate
|
|||
|
||||
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
|
||||
if node_type in CUSTOM_AGENTS:
|
||||
if custom_agent := CUSTOM_AGENTS.get(node_type):
|
||||
return custom_agent.initialize(**params) # type: ignore
|
||||
params = process_params(params)
|
||||
custom_agent = CUSTOM_AGENTS.get(node_type)
|
||||
if custom_agent:
|
||||
return custom_agent.initialize(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
# check if it is a class before using issubclass
|
||||
|
||||
# if isinstance(class_object, type) and issubclass(class_object, BaseModel):
|
||||
# # validate params
|
||||
# fields = class_object.__fields__
|
||||
# params = {key: value for key, value in params.items() if key in fields}
|
||||
|
||||
if base_type == "agents":
|
||||
# We need to initialize it differently
|
||||
return load_agent_executor(class_object, params)
|
||||
elif base_type == "prompts":
|
||||
if node_type == "ZeroShotPrompt":
|
||||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return ZeroShotAgent.create_prompt(**params)
|
||||
elif base_type == "tools":
|
||||
if node_type == "JsonSpec":
|
||||
params["dict_"] = load_file_into_dict(params.pop("path"))
|
||||
return class_object(**params)
|
||||
elif node_type == "PythonFunction":
|
||||
# If the node_type is "PythonFunction"
|
||||
# we need to get the function from the params
|
||||
# which will be a str containing a python function
|
||||
# and then we need to compile it and return the function
|
||||
# as the instance
|
||||
function_string = params["code"]
|
||||
if isinstance(function_string, str):
|
||||
return validate.eval_function(function_string)
|
||||
raise ValueError("Function should be a string")
|
||||
elif node_type.lower() == "tool":
|
||||
return class_object(**params)
|
||||
elif base_type == "toolkits":
|
||||
loaded_toolkit = class_object(**params)
|
||||
# Check if node_type has a loader
|
||||
if toolkits_creator.has_create_function(node_type):
|
||||
return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
return loaded_toolkit
|
||||
elif base_type == "embeddings":
|
||||
# ? Why remove model from params?
|
||||
try:
|
||||
params.pop("model")
|
||||
except KeyError:
|
||||
pass
|
||||
# remove all params that are not in class_object.__fields__
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if key in class_object.__fields__
|
||||
}
|
||||
return class_object(**params)
|
||||
elif base_type == "vectorstores":
|
||||
if len(params.get("documents", [])) == 0:
|
||||
# Error when the pdf or other source was not correctly
|
||||
# loaded.
|
||||
raise ValueError(
|
||||
"The source you provided did not load correctly or was empty."
|
||||
"This may cause an error in the vectorstore."
|
||||
)
|
||||
return class_object.from_documents(**params)
|
||||
elif base_type == "documentloaders":
|
||||
return class_object(**params).load()
|
||||
elif base_type == "textsplitters":
|
||||
documents = params.pop("documents")
|
||||
text_splitter = class_object(**params)
|
||||
return text_splitter.split_documents(documents)
|
||||
elif base_type == "utilities":
|
||||
if node_type == "SQLDatabase":
|
||||
return class_object.from_uri(params.pop("uri"))
|
||||
|
||||
return class_object(**params)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
|
||||
|
||||
def process_params(params):
|
||||
"""Process params"""
|
||||
def convert_params_to_sets(params):
|
||||
"""Convert certain params to sets"""
|
||||
if "allowed_special" in params:
|
||||
params["allowed_special"] = set(params["allowed_special"])
|
||||
if "disallowed_special" in params:
|
||||
|
|
@ -119,6 +50,100 @@ def process_params(params):
|
|||
return params
|
||||
|
||||
|
||||
def instantiate_based_on_type(class_object, base_type, node_type, params):
    """Build the object for a node by dispatching on its base type.

    Each known ``base_type`` is forwarded to its dedicated ``instantiate_*``
    helper. Any other base type is built with ``class_object(**params)``.

    For ``"prompts"`` and ``"tools"`` the helper is tried first; if it
    returns ``None`` (it has no special handling for ``node_type``), the
    object is built with ``class_object(**params)`` instead of propagating
    ``None`` to the caller.

    Args:
        class_object: The imported class (or callable) for the node.
        base_type: Category of the node, e.g. ``"agents"`` or ``"tools"``.
        node_type: Concrete node name within the category.
        params: Keyword arguments used to build the object.

    Returns:
        Whatever the selected helper or ``class_object(**params)`` returns.
    """
    if base_type == "agents":
        return instantiate_agent(class_object, params)
    if base_type == "toolkits":
        return instantiate_toolkit(node_type, class_object, params)
    if base_type == "embeddings":
        return instantiate_embedding(class_object, params)
    if base_type == "vectorstores":
        return instantiate_vectorstore(class_object, params)
    if base_type == "documentloaders":
        return instantiate_documentloader(class_object, params)
    if base_type == "textsplitters":
        return instantiate_textsplitter(class_object, params)
    if base_type == "utilities":
        return instantiate_utility(node_type, class_object, params)

    # Prompts and tools only special-case a few node types; everything else
    # must still be constructed through the default path below.
    instance = None
    if base_type == "prompts":
        instance = instantiate_prompt(node_type, params)
    elif base_type == "tools":
        instance = instantiate_tool(node_type, class_object, params)
    if instance is not None:
        return instance

    return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_agent(class_object, params):
    """Build an agent by delegating to ``load_agent_executor``.

    Agents are not constructed with ``class_object(**params)``; the class
    and its params are handed to ``load_agent_executor`` and its result is
    returned unchanged.
    """
    executor = load_agent_executor(class_object, params)
    return executor
|
||||
|
||||
|
||||
def instantiate_prompt(node_type, params):
    """Build a prompt for node types that need special construction.

    Only ``"ZeroShotPrompt"`` is handled: it is created through
    ``ZeroShotAgent.create_prompt(**params)``, with ``params["tools"]``
    set to an empty list first when the key is absent (``params`` is
    updated in place).

    Returns:
        The created prompt for ``"ZeroShotPrompt"``; ``None`` for every
        other node type.
    """
    if node_type != "ZeroShotPrompt":
        # No special construction is defined for this node type.
        return None

    params.setdefault("tools", [])
    return ZeroShotAgent.create_prompt(**params)
|
||||
|
||||
|
||||
def instantiate_tool(node_type, class_object, params):
    """Build a tool, special-casing the node types that need it.

    * ``"JsonSpec"``: the file named by ``params["path"]`` is loaded with
      ``load_file_into_dict`` and passed to the class as ``dict_`` (the
      ``path`` key is removed from ``params``).
    * ``"PythonFunction"``: ``params["code"]`` must be a string; it is
      handed to ``validate.eval_function`` and that result is returned.
    * Any other node type (including ``"Tool"``): ``class_object(**params)``.

    Raises:
        ValueError: If ``params["code"]`` is not a string for a
            ``"PythonFunction"`` node.
    """
    if node_type == "JsonSpec":
        params["dict_"] = load_file_into_dict(params.pop("path"))
        return class_object(**params)

    if node_type == "PythonFunction":
        function_string = params["code"]
        if not isinstance(function_string, str):
            raise ValueError("Function should be a string")
        # NOTE(review): the source text comes straight from params and is
        # presumably compiled/executed by eval_function -- confirm that only
        # trusted input can reach this branch.
        return validate.eval_function(function_string)

    # Default construction. Returning None here would silently drop every
    # tool type that has no special case above.
    return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_toolkit(node_type, class_object, params):
    """Build a toolkit and, when a creator exists for it, its executor.

    The toolkit is always constructed with ``class_object(**params)``. If
    ``toolkits_creator.has_create_function(node_type)`` is true, the toolkit
    is passed on to ``load_toolkits_executor`` and that result is returned;
    otherwise the toolkit itself is returned.
    """
    toolkit = class_object(**params)

    if not toolkits_creator.has_create_function(node_type):
        return toolkit

    return load_toolkits_executor(node_type, toolkit, params)
|
||||
|
||||
|
||||
def instantiate_embedding(class_object, params):
    """Build an embedding object, retrying with only its declared fields.

    The ``model`` key is always removed from ``params`` (in place) before
    construction; the reason is not documented in this module. If the first
    construction raises ``ValidationError``, it is retried with ``params``
    restricted to keys present in ``class_object.__fields__``.
    """
    if "model" in params:
        del params["model"]

    try:
        embedding = class_object(**params)
    except ValidationError:
        # Drop every key the class does not declare and try once more.
        accepted = {
            name: params[name]
            for name in params
            if name in class_object.__fields__
        }
        return class_object(**accepted)
    return embedding
|
||||
|
||||
|
||||
def instantiate_vectorstore(class_object, params):
    """Build a vector store from documents via ``class_object.from_documents``.

    Args:
        class_object: Vector store class exposing ``from_documents``.
        params: Keyword arguments for ``from_documents``; must contain a
            non-empty ``documents`` entry.

    Raises:
        ValueError: If ``documents`` is missing, ``None`` or empty, which
            indicates the upstream source failed to load.
    """
    # A falsy check also covers documents=None, which would otherwise fail
    # with a TypeError from len(None) instead of the intended message.
    if not params.get("documents"):
        raise ValueError(
            "The source you provided did not load correctly or was empty. "
            "This may cause an error in the vectorstore."
        )
    return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def instantiate_documentloader(class_object, params):
    """Construct the loader with ``params`` and return what ``load()`` yields.

    The loader instance itself is not returned, only the result of calling
    its ``load`` method.
    """
    loader = class_object(**params)
    return loader.load()
|
||||
|
||||
|
||||
def instantiate_textsplitter(class_object, params):
    """Split ``params["documents"]`` with a splitter built from the rest.

    ``documents`` is removed from ``params`` (in place) before the splitter
    is constructed, so it is not passed to the splitter's constructor. The
    return value is the result of ``split_documents`` on those documents.

    Raises:
        KeyError: If ``params`` has no ``documents`` entry.
    """
    docs = params.pop("documents")
    splitter = class_object(**params)
    return splitter.split_documents(docs)
|
||||
|
||||
|
||||
def instantiate_utility(node_type, class_object, params):
    """Build a utility object, special-casing ``"SQLDatabase"``.

    ``"SQLDatabase"`` is created with ``class_object.from_uri`` using the
    ``uri`` entry popped from ``params``; the remaining params are not
    passed along. Every other node type is built with
    ``class_object(**params)``.
    """
    if node_type != "SQLDatabase":
        return class_object(**params)

    uri = params.pop("uri")
    return class_object.from_uri(uri)
|
||||
|
||||
|
||||
def load_flow_from_json(path: str, build=True):
|
||||
# This is done to avoid circular imports
|
||||
from langflow.graph import Graph
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue