From 894fd16e8ed99b3b0aba621ae13cec06ed83cb88 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 29 Jun 2023 09:56:59 -0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(retrievers):=20add=20base=20re?= =?UTF-8?q?triever=20class=20and=20frontend=20node=20class=20=F0=9F=90=9B?= =?UTF-8?q?=20fix(util.py):=20handle=20non-string=20types=20in=20format=5F?= =?UTF-8?q?dict=20function=20The=20base=20retriever=20class=20is=20added?= =?UTF-8?q?=20to=20provide=20a=20common=20interface=20for=20all=20retrieve?= =?UTF-8?q?rs=20in=20the=20language=20chain.=20The=20frontend=20node=20cla?= =?UTF-8?q?ss=20for=20retrievers=20is=20also=20added=20to=20handle=20the?= =?UTF-8?q?=20formatting=20of=20fields=20specific=20to=20retrievers.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the util.py file, a fix is made to handle non-string types in the format_dict function. Previously, if the type of a field was not a string, an error would occur. This fix ensures that the type is converted to a string before further processing. --- .../langflow/interface/retrievers/base.py | 49 +++++++++++++++++++ .../template/frontend_node/retrievers.py | 15 ++++++ src/backend/langflow/utils/util.py | 3 ++ 3 files changed, 67 insertions(+) create mode 100644 src/backend/langflow/interface/retrievers/base.py create mode 100644 src/backend/langflow/template/frontend_node/retrievers.py diff --git a/src/backend/langflow/interface/retrievers/base.py b/src/backend/langflow/interface/retrievers/base.py new file mode 100644 index 000000000..3eb38861e --- /dev/null +++ b/src/backend/langflow/interface/retrievers/base.py @@ -0,0 +1,49 @@ +from typing import Any, Dict, List, Optional, Type + +from langchain import retrievers + +from langflow.interface.base import LangChainTypeCreator +from langflow.interface.importing.utils import import_class +from langflow.settings import settings +from langflow.template.frontend_node.retrievers import RetrieverFrontendNode +from langflow.utils.logger import logger +from langflow.utils.util import build_template_from_method + + +class RetrieverCreator(LangChainTypeCreator): + type_name: str = "retrievers" + + @property + def frontend_node_class(self) -> Type[RetrieverFrontendNode]: + return RetrieverFrontendNode + + @property + def type_to_loader_dict(self) -> Dict: + if self.type_dict is None: + self.type_dict: dict[str, Any] = { + retriever_name: import_class(f"langchain.retrievers.{retriever_name}") + for retriever_name in retrievers.__all__ + } + return self.type_dict + + def get_signature(self, name: str) -> Optional[Dict]: + """Get the signature of an embedding.""" + try: + return build_template_from_method( + name, type_to_cls_dict=self.type_to_loader_dict, method_name="from_llm" + ) + except ValueError as exc: + raise ValueError(f"Retriever {name} not found") from exc + except AttributeError as exc: + logger.error(f"Retriever {name} not loaded: {exc}") + return None + + def to_list(self) -> List[str]: + return [ + retriever + for retriever in self.type_to_loader_dict.keys() + if retriever in settings.retrievers or settings.dev + ] + + +retriever_creator = RetrieverCreator() diff --git a/src/backend/langflow/template/frontend_node/retrievers.py b/src/backend/langflow/template/frontend_node/retrievers.py new file mode 100644 index 000000000..b482c8b84 --- /dev/null +++ b/src/backend/langflow/template/frontend_node/retrievers.py @@ -0,0 +1,15 @@ +from typing import Optional + +from langflow.template.field.base import TemplateField +from langflow.template.frontend_node.base import FrontendNode + + +class RetrieverFrontendNode(FrontendNode): + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + FrontendNode.format_field(field, name) + # Define common field attributes + field.show = True + if field.name == "parser_key": + field.display_name = "Parser Key" + field.password = False diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 7fcf1f4d4..02a9520b4 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -233,6 +233,9 @@ def format_dict(d, name: Optional[str] = None): _type = value["type"] + if not isinstance(_type, str): + _type = _type.__name__ + # Remove 'Optional' wrapper if "Optional" in _type: _type = _type.replace("Optional[", "")[:-1]