From 92eb45dd42de1320c79e063d724a05377363de4b Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 7 Jul 2023 19:06:19 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20chore(config.yaml):=20add=20docu?= =?UTF-8?q?mentation=20for=20SQLDatabase=20wrapper=20=F0=9F=90=9B=20fix(ut?= =?UTF-8?q?ils.py):=20fix=20import=20of=20wrapper=5Fcreator=20from=20langf?= =?UTF-8?q?low.interface.wrappers.base=20=F0=9F=94=A7=20chore(loading.py):?= =?UTF-8?q?=20add=20support=20for=20instantiating=20wrappers=20based=20on?= =?UTF-8?q?=20node=20type=20=F0=9F=94=A7=20chore(base.py):=20add=20support?= =?UTF-8?q?=20for=20creating=20SQLDatabase=20wrapper=20from=20URI=20The=20?= =?UTF-8?q?config.yaml=20file=20was=20modified=20to=20add=20documentation?= =?UTF-8?q?=20for=20the=20SQLDatabase=20wrapper.=20In=20utils.py,=20the=20?= =?UTF-8?q?import=20of=20wrapper=5Fcreator=20from=20langflow.interface.wra?= =?UTF-8?q?ppers.base=20was=20fixed.=20In=20loading.py,=20support=20was=20?= =?UTF-8?q?added=20for=20instantiating=20wrappers=20based=20on=20the=20nod?= =?UTF-8?q?e=20type.=20In=20base.py,=20support=20was=20added=20for=20creat?= =?UTF-8?q?ing=20the=20SQLDatabase=20wrapper=20from=20a=20URI.=20These=20c?= =?UTF-8?q?hanges=20were=20made=20to=20improve=20the=20functionality=20and?= =?UTF-8?q?=20maintainability=20of=20the=20codebase.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/config.yaml | 2 ++ .../langflow/interface/importing/utils.py | 7 ++++++- .../langflow/interface/initialize/loading.py | 12 ++++++++++++ src/backend/langflow/interface/wrappers/base.py | 17 ++++++++++++++--- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index ab4db6fd9..f4f83301a 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -280,6 +280,8 @@ vectorstores: wrappers: RequestsWrapper: documentation: "" + SQLDatabase: + documentation: "" output_parsers: StructuredOutputParser: documentation: "https://python.langchain.com/docs/modules/model_io/output_parsers/structured" diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index 3c7f89b5b..ccfd8d5dd 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -10,6 +10,7 @@ from langchain.chains.base import Chain from langchain.chat_models.base import BaseChatModel from langchain.tools import BaseTool from langflow.utils import validate +from langflow.interface.wrappers.base import wrapper_creator def import_module(module_path: str) -> Any: @@ -96,7 +97,11 @@ def import_prompt(prompt: str) -> Type[PromptTemplate]: def import_wrapper(wrapper: str) -> Any: """Import wrapper from wrapper name""" - return import_module(f"from langchain.requests import {wrapper}") + if ( + isinstance(wrapper_creator.type_dict, dict) + and wrapper in wrapper_creator.type_dict + ): + return wrapper_creator.type_dict.get(wrapper) def import_toolkit(toolkit: str) -> Any: diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 37dbdcda1..25149cd4b 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -20,6 +20,7 @@ from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator from langflow.interface.output_parsers.base import output_parser_creator from langflow.interface.retrievers.base import retriever_creator +from langflow.interface.wrappers.base import wrapper_creator from langflow.interface.utils import load_file_into_dict from langflow.utils import validate from langchain.chains.base import Chain @@ -89,10 +90,21 @@ def instantiate_based_on_type(class_object, base_type, node_type, params): return instantiate_retriever(node_type, class_object, params) elif base_type == "memory": return instantiate_memory(node_type, class_object, params) + elif base_type == "wrappers": + return instantiate_wrapper(node_type, class_object, params) else: return class_object(**params) +def instantiate_wrapper(node_type, class_object, params): + if node_type in wrapper_creator.from_method_nodes: + method = wrapper_creator.from_method_nodes[node_type] + if class_method := getattr(class_object, method, None): + return class_method(**params) + raise ValueError(f"Method {method} not found in {class_object}") + return class_object(**params) + + def instantiate_output_parser(node_type, class_object, params): if node_type in output_parser_creator.from_method_nodes: method = output_parser_creator.from_method_nodes[node_type] diff --git a/src/backend/langflow/interface/wrappers/base.py b/src/backend/langflow/interface/wrappers/base.py index f5773d07a..77e38f921 100644 --- a/src/backend/langflow/interface/wrappers/base.py +++ b/src/backend/langflow/interface/wrappers/base.py @@ -1,25 +1,36 @@ from typing import Dict, List, Optional -from langchain import requests +from langchain import requests, sql_database from langflow.interface.base import LangChainTypeCreator from langflow.utils.logger import logger -from langflow.utils.util import build_template_from_class +from langflow.utils.util import build_template_from_class, build_template_from_method class WrapperCreator(LangChainTypeCreator): type_name: str = "wrappers" + from_method_nodes = {"SQLDatabase": "from_uri"} + @property def type_to_loader_dict(self) -> Dict: if self.type_dict is None: self.type_dict = { - wrapper.__name__: wrapper for wrapper in [requests.TextRequestsWrapper] + wrapper.__name__: wrapper + for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase] } return self.type_dict def get_signature(self, name: str) -> Optional[Dict]: try: + if name in self.from_method_nodes: + return build_template_from_method( + name, + type_to_cls_dict=self.type_to_loader_dict, + add_function=True, + method_name=self.from_method_nodes[name], + ) + return build_template_from_class(name, self.type_to_loader_dict) except ValueError as exc: raise ValueError("Wrapper not found") from exc