From 7c1c5131065e8bc7effc2f0e120b394bea77cd28 Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Wed, 10 May 2023 11:41:34 -0300
Subject: [PATCH] refactor(loading.py): refactor instantiate_class function to
 improve readability and maintainability

---
Review notes (text in this section is not part of the commit message):

* The line breaks and indentation of this patch were reconstructed from a
  copy whose whitespace had been collapsed. The line counts match the hunk
  headers and the diffstat, but if a hunk does not apply cleanly, retry with
  `git apply --ignore-whitespace`.
* Behaviour fix relative to the original commit: instantiate_prompt() and
  instantiate_tool() ended with `return None`. In the code being removed,
  a "prompts" or "tools" node type that had no special case fell through to
  the final `return class_object(**params)`. Both helpers now return
  class_object(**params) in that case, and instantiate_prompt() receives
  class_object so that it can. Only the content of added lines changed, so
  the hunk headers and the diffstat are still accurate; the new-side blob id
  on the `index` line predates this fix.

 src/backend/langflow/interface/loading.py | 183 ++++++++++++----------
 1 file changed, 104 insertions(+), 79 deletions(-)

diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py
index a2d9a799c..0e23c2b45 100644
--- a/src/backend/langflow/interface/loading.py
+++ b/src/backend/langflow/interface/loading.py
@@ -13,9 +13,9 @@ from langchain.agents.load_tools import (
 )
 from langchain.agents.loading import load_agent_from_config
 from langchain.agents.tools import Tool
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.loading import load_chain_from_config
-from langchain.base_language import BaseLanguageModel
 from langchain.llms.loading import load_llm_from_config
 from pydantic import ValidationError
 
@@ -30,88 +30,19 @@ from langflow.utils import util, validate
 
 def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
     """Instantiate class from module type and key, and params"""
+    params = convert_params_to_sets(params)
+
     if node_type in CUSTOM_AGENTS:
-        if custom_agent := CUSTOM_AGENTS.get(node_type):
-            return custom_agent.initialize(**params)  # type: ignore
-    params = process_params(params)
+        custom_agent = CUSTOM_AGENTS.get(node_type)
+        if custom_agent:
+            return custom_agent.initialize(**params)
+
     class_object = import_by_type(_type=base_type, name=node_type)
-    # check if it is a class before using issubclass
-
-    # if isinstance(class_object, type) and issubclass(class_object, BaseModel):
-    #     # validate params
-    #     fields = class_object.__fields__
-    #     params = {key: value for key, value in params.items() if key in fields}
-
-    if base_type == "agents":
-        # We need to initialize it differently
-        return load_agent_executor(class_object, params)
-    elif base_type == "prompts":
-        if node_type == "ZeroShotPrompt":
-            if "tools" not in params:
-                params["tools"] = []
-            return ZeroShotAgent.create_prompt(**params)
-    elif base_type == "tools":
-        if node_type == "JsonSpec":
-            params["dict_"] = load_file_into_dict(params.pop("path"))
-            return class_object(**params)
-        elif node_type == "PythonFunction":
-            # If the node_type is "PythonFunction"
-            # we need to get the function from the params
-            # which will be a str containing a python function
-            # and then we need to compile it and return the function
-            # as the instance
-            function_string = params["code"]
-            if isinstance(function_string, str):
-                return validate.eval_function(function_string)
-            raise ValueError("Function should be a string")
-        elif node_type.lower() == "tool":
-            return class_object(**params)
-    elif base_type == "toolkits":
-        loaded_toolkit = class_object(**params)
-        # Check if node_type has a loader
-        if toolkits_creator.has_create_function(node_type):
-            return load_toolkits_executor(node_type, loaded_toolkit, params)
-        return loaded_toolkit
-    elif base_type == "embeddings":
-        # ? Why remove model from params?
-        try:
-            params.pop("model")
-        except KeyError:
-            pass
-        # remove all params that are not in class_object.__fields__
-        try:
-            return class_object(**params)
-        except ValidationError:
-            params = {
-                key: value
-                for key, value in params.items()
-                if key in class_object.__fields__
-            }
-            return class_object(**params)
-    elif base_type == "vectorstores":
-        if len(params.get("documents", [])) == 0:
-            # Error when the pdf or other source was not correctly
-            # loaded.
-            raise ValueError(
-                "The source you provided did not load correctly or was empty."
-                "This may cause an error in the vectorstore."
-            )
-        return class_object.from_documents(**params)
-    elif base_type == "documentloaders":
-        return class_object(**params).load()
-    elif base_type == "textsplitters":
-        documents = params.pop("documents")
-        text_splitter = class_object(**params)
-        return text_splitter.split_documents(documents)
-    elif base_type == "utilities":
-        if node_type == "SQLDatabase":
-            return class_object.from_uri(params.pop("uri"))
-
-    return class_object(**params)
+    return instantiate_based_on_type(class_object, base_type, node_type, params)
 
 
-def process_params(params):
-    """Process params"""
+def convert_params_to_sets(params):
+    """Convert certain params to sets"""
     if "allowed_special" in params:
         params["allowed_special"] = set(params["allowed_special"])
     if "disallowed_special" in params:
@@ -119,6 +50,100 @@ def process_params(params):
     return params
 
 
+def instantiate_based_on_type(class_object, base_type, node_type, params):
+    if base_type == "agents":
+        return instantiate_agent(class_object, params)
+    elif base_type == "prompts":
+        return instantiate_prompt(node_type, class_object, params)
+    elif base_type == "tools":
+        return instantiate_tool(node_type, class_object, params)
+    elif base_type == "toolkits":
+        return instantiate_toolkit(node_type, class_object, params)
+    elif base_type == "embeddings":
+        return instantiate_embedding(class_object, params)
+    elif base_type == "vectorstores":
+        return instantiate_vectorstore(class_object, params)
+    elif base_type == "documentloaders":
+        return instantiate_documentloader(class_object, params)
+    elif base_type == "textsplitters":
+        return instantiate_textsplitter(class_object, params)
+    elif base_type == "utilities":
+        return instantiate_utility(node_type, class_object, params)
+    else:
+        return class_object(**params)
+
+
+def instantiate_agent(class_object, params):
+    return load_agent_executor(class_object, params)
+
+
+def instantiate_prompt(node_type, class_object, params):
+    if node_type == "ZeroShotPrompt":
+        if "tools" not in params:
+            params["tools"] = []
+        return ZeroShotAgent.create_prompt(**params)
+    return class_object(**params)
+
+
+def instantiate_tool(node_type, class_object, params):
+    if node_type == "JsonSpec":
+        params["dict_"] = load_file_into_dict(params.pop("path"))
+        return class_object(**params)
+    elif node_type == "PythonFunction":
+        function_string = params["code"]
+        if isinstance(function_string, str):
+            return validate.eval_function(function_string)
+        raise ValueError("Function should be a string")
+    elif node_type.lower() == "tool":
+        return class_object(**params)
+    return class_object(**params)
+
+
+def instantiate_toolkit(node_type, class_object, params):
+    loaded_toolkit = class_object(**params)
+    if toolkits_creator.has_create_function(node_type):
+        return load_toolkits_executor(node_type, loaded_toolkit, params)
+    return loaded_toolkit
+
+
+def instantiate_embedding(class_object, params):
+    params.pop("model", None)
+    try:
+        return class_object(**params)
+    except ValidationError:
+        params = {
+            key: value
+            for key, value in params.items()
+            if key in class_object.__fields__
+        }
+        return class_object(**params)
+
+
+def instantiate_vectorstore(class_object, params):
+    if len(params.get("documents", [])) == 0:
+        raise ValueError(
+            "The source you provided did not load correctly or was empty."
+            "This may cause an error in the vectorstore."
+        )
+    return class_object.from_documents(**params)
+
+
+def instantiate_documentloader(class_object, params):
+    return class_object(**params).load()
+
+
+def instantiate_textsplitter(class_object, params):
+    documents = params.pop("documents")
+    text_splitter = class_object(**params)
+    return text_splitter.split_documents(documents)
+
+
+def instantiate_utility(node_type, class_object, params):
+    if node_type == "SQLDatabase":
+        return class_object.from_uri(params.pop("uri"))
+    return class_object(**params)
+
+
 def load_flow_from_json(path: str, build=True):
     # This is done to avoid circular imports
     from langflow.graph import Graph