Add support for Text and Document inputs in RetrievalQAComponent

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-06 23:11:41 -03:00
commit 05b088cdfe

View file

@ -2,8 +2,9 @@ from typing import Callable, Optional, Union
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA
from langchain_core.documents import Document
from langflow import CustomComponent
from langflow.field_typing import BaseMemory, BaseRetriever
from langflow.field_typing import BaseMemory, BaseRetriever, Text
class RetrievalQAComponent(CustomComponent):
@ -18,18 +19,20 @@ class RetrievalQAComponent(CustomComponent):
"input_key": {"display_name": "Input Key", "advanced": True},
"output_key": {"display_name": "Output Key", "advanced": True},
"return_source_documents": {"display_name": "Return Source Documents"},
"inputs": {"display_name": "Input", "input_types": ["Text", "Document"]},
}
def build(
self,
combine_documents_chain: BaseCombineDocumentsChain,
retriever: BaseRetriever,
inputs: str = "",
memory: Optional[BaseMemory] = None,
input_key: str = "query",
output_key: str = "result",
return_source_documents: bool = True,
) -> Union[BaseRetrievalQA, Callable]:
return RetrievalQA(
) -> Union[BaseRetrievalQA, Callable, Text]:
runnable = RetrievalQA(
combine_documents_chain=combine_documents_chain,
retriever=retriever,
memory=memory,
@ -37,3 +40,8 @@ class RetrievalQAComponent(CustomComponent):
output_key=output_key,
return_source_documents=return_source_documents,
)
if isinstance(inputs, Document):
inputs = inputs.page_content
result = runnable.invoke({input_key: inputs})
return result.content if hasattr(result, "content") else result