From 05b088cdfeeb1842f0cd6a442fb6e924d3f5355c Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Feb 2024 23:11:41 -0300 Subject: [PATCH] Add support for Text and Document inputs in RetrievalQAComponent --- .../langflow/components/chains/RetrievalQA.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/backend/langflow/components/chains/RetrievalQA.py b/src/backend/langflow/components/chains/RetrievalQA.py index 5f1232443..3fc1933d9 100644 --- a/src/backend/langflow/components/chains/RetrievalQA.py +++ b/src/backend/langflow/components/chains/RetrievalQA.py @@ -2,8 +2,9 @@ from typing import Callable, Optional, Union from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA +from langchain_core.documents import Document from langflow import CustomComponent -from langflow.field_typing import BaseMemory, BaseRetriever +from langflow.field_typing import BaseMemory, BaseRetriever, Text class RetrievalQAComponent(CustomComponent): @@ -18,18 +19,20 @@ class RetrievalQAComponent(CustomComponent): "input_key": {"display_name": "Input Key", "advanced": True}, "output_key": {"display_name": "Output Key", "advanced": True}, "return_source_documents": {"display_name": "Return Source Documents"}, + "inputs": {"display_name": "Input", "input_types": ["Text", "Document"]}, } def build( self, combine_documents_chain: BaseCombineDocumentsChain, retriever: BaseRetriever, + inputs: str = "", memory: Optional[BaseMemory] = None, input_key: str = "query", output_key: str = "result", return_source_documents: bool = True, - ) -> Union[BaseRetrievalQA, Callable]: - return RetrievalQA( + ) -> Union[BaseRetrievalQA, Callable, Text]: + runnable = RetrievalQA( combine_documents_chain=combine_documents_chain, retriever=retriever, memory=memory, @@ -37,3 +40,8 @@ class RetrievalQAComponent(CustomComponent): output_key=output_key, return_source_documents=return_source_documents, ) + if isinstance(inputs, Document): + inputs = inputs.page_content + + result = runnable.invoke({input_key: inputs}) + return result.content if hasattr(result, "content") else result