Refactor RetrievalQAWithSourcesChainComponent class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 09:51:46 -03:00
commit c1a228fe6c

View file

@ -1,11 +1,11 @@
from typing import Optional
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.documents import Document
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, BaseMemory, BaseRetriever
from langflow.field_typing import BaseLanguageModel, BaseMemory, BaseRetriever, Text
class RetrievalQAWithSourcesChainComponent(CustomComponent):
@ -31,8 +31,8 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
chain_type: str,
memory: Optional[BaseMemory] = None,
return_source_documents: Optional[bool] = True,
) -> BaseQAWithSourcesChain:
return RetrievalQAWithSourcesChain.from_chain_type(
) -> Text:
runnable = RetrievalQAWithSourcesChain.from_chain_type(
llm=llm,
chain_type=chain_type,
combine_documents_chain=combine_documents_chain,
@ -40,3 +40,19 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
return_source_documents=return_source_documents,
retriever=retriever,
)
if isinstance(inputs, Document):
inputs = inputs.page_content
self.status = runnable
input_key = runnable.input_keys[0]
result = runnable.invoke({input_key: inputs})
result = result.content if hasattr(result, "content") else result
# Result is a dict with keys "query", "result" and "source_documents"
# for now we just return the result
records = self.to_records(result.get("source_documents"))
references_str = ""
if return_source_documents:
references_str = self.create_references_from_records(records)
result_str = result.get("result")
final_result = "\n".join([result_str, references_str])
self.status = final_result
return final_result