diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index ea9af8d63..65461f141 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -15,6 +15,7 @@ from langflow.interface.custom_lists import CUSTOM_NODES from langflow.interface.importing.utils import get_function, import_by_type from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator +from langflow.interface.retrievers.base import retriever_creator from langflow.interface.utils import load_file_into_dict from langflow.utils import validate from langchain.chains.base import Chain @@ -79,10 +80,23 @@ def instantiate_based_on_type(class_object, base_type, node_type, params): return instantiate_chains(node_type, class_object, params) elif base_type == "llms": return instantiate_llm(node_type, class_object, params) + elif base_type == "retrievers": + return instantiate_retriever(node_type, class_object, params) else: return class_object(**params) +def instantiate_retriever(node_type, class_object, params): + if "retriever" in params and hasattr(params["retriever"], "as_retriever"): + params["retriever"] = params["retriever"].as_retriever() + if node_type in retriever_creator.from_method_nodes: + method = retriever_creator.from_method_nodes[node_type] + if class_method := getattr(class_object, method, None): + return class_method(**params) + raise ValueError(f"Method {method} not found in {class_object}") + return class_object(**params) + + def instantiate_llm(node_type, class_object, params: Dict): return class_object(**params)