From 72b6681f5a798cb15629842d86dba638face030b Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 29 Jun 2023 11:17:59 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20chore(loading.py):=20add=20suppo?= =?UTF-8?q?rt=20for=20instantiating=20retrievers=20in=20the=20initializati?= =?UTF-8?q?on=20process=20=F0=9F=9A=80=20feat(loading.py):=20implement=20t?= =?UTF-8?q?he=20ability=20to=20instantiate=20retrievers=20based=20on=20nod?= =?UTF-8?q?e=20type=20and=20class=20object=20The=20`instantiate=5Fbased=5F?= =?UTF-8?q?on=5Ftype`=20function=20now=20includes=20a=20new=20condition=20?= =?UTF-8?q?to=20handle=20the=20instantiation=20of=20retrievers.=20The=20`i?= =?UTF-8?q?nstantiate=5Fretriever`=20function=20is=20introduced=20to=20han?= =?UTF-8?q?dle=20the=20specific=20logic=20for=20creating=20retrievers.=20T?= =?UTF-8?q?his=20change=20allows=20for=20the=20dynamic=20creation=20of=20r?= =?UTF-8?q?etrievers=20based=20on=20the=20provided=20node=20type=20and=20c?= =?UTF-8?q?lass=20object.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/initialize/loading.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index ea9af8d63..65461f141 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -15,6 +15,7 @@ from langflow.interface.custom_lists import CUSTOM_NODES from langflow.interface.importing.utils import get_function, import_by_type from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator +from langflow.interface.retrievers.base import retriever_creator from langflow.interface.utils import load_file_into_dict from langflow.utils import validate from langchain.chains.base import Chain @@ -79,10 +80,23 @@ def instantiate_based_on_type(class_object, base_type, node_type, params): return instantiate_chains(node_type, class_object, params) elif base_type == "llms": return instantiate_llm(node_type, class_object, params) + elif base_type == "retrievers": + return instantiate_retriever(node_type, class_object, params) else: return class_object(**params) +def instantiate_retriever(node_type, class_object, params): + if "retriever" in params and hasattr(params["retriever"], "as_retriever"): + params["retriever"] = params["retriever"].as_retriever() + if node_type in retriever_creator.from_method_nodes: + method = retriever_creator.from_method_nodes[node_type] + if class_method := getattr(class_object, method, None): + return class_method(**params) + raise ValueError(f"Method {method} not found in {class_object}") + return class_object(**params) + + def instantiate_llm(node_type, class_object, params: Dict): return class_object(**params)