diff --git a/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py b/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py index 3c46cd8bd..4ec1ce886 100644 --- a/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py +++ b/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py @@ -1,11 +1,11 @@ from typing import Optional from langchain.chains import RetrievalQAWithSourcesChain -from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain +from langchain_core.documents import Document from langflow import CustomComponent -from langflow.field_typing import BaseLanguageModel, BaseMemory, BaseRetriever +from langflow.field_typing import BaseLanguageModel, BaseMemory, BaseRetriever, Text class RetrievalQAWithSourcesChainComponent(CustomComponent): @@ -31,8 +31,8 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent): chain_type: str, memory: Optional[BaseMemory] = None, return_source_documents: Optional[bool] = True, - ) -> BaseQAWithSourcesChain: - return RetrievalQAWithSourcesChain.from_chain_type( + ) -> Text: + runnable = RetrievalQAWithSourcesChain.from_chain_type( llm=llm, chain_type=chain_type, combine_documents_chain=combine_documents_chain, @@ -40,3 +40,19 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent): return_source_documents=return_source_documents, retriever=retriever, ) + if isinstance(inputs, Document): + inputs = inputs.page_content + self.status = runnable + input_key = runnable.input_keys[0] + result = runnable.invoke({input_key: inputs}) + result = result.content if hasattr(result, "content") else result + # Result is a dict with keys "query", "result" and "source_documents" + # for now we just return the result + records = self.to_records(result.get("source_documents")) + references_str = "" + if return_source_documents: + references_str = self.create_references_from_records(records) + result_str = result.get("result") + final_result = "\n".join([result_str, references_str]) + self.status = final_result + return final_result