Refactor RetrievalQAWithSourcesChainComponent class
This commit is contained in:
parent
0387de2dd2
commit
c1a228fe6c
1 changed files with 20 additions and 4 deletions
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.chains import RetrievalQAWithSourcesChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, BaseRetriever
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, BaseRetriever, Text
|
||||
|
||||
|
||||
class RetrievalQAWithSourcesChainComponent(CustomComponent):
|
||||
|
|
@ -31,8 +31,8 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
|
|||
chain_type: str,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
return_source_documents: Optional[bool] = True,
|
||||
) -> BaseQAWithSourcesChain:
|
||||
return RetrievalQAWithSourcesChain.from_chain_type(
|
||||
) -> Text:
|
||||
runnable = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=llm,
|
||||
chain_type=chain_type,
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
|
|
@ -40,3 +40,19 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
|
|||
return_source_documents=return_source_documents,
|
||||
retriever=retriever,
|
||||
)
|
||||
if isinstance(inputs, Document):
|
||||
inputs = inputs.page_content
|
||||
self.status = runnable
|
||||
input_key = runnable.input_keys[0]
|
||||
result = runnable.invoke({input_key: inputs})
|
||||
result = result.content if hasattr(result, "content") else result
|
||||
# Result is a dict with keys "query", "result" and "source_documents"
|
||||
# for now we just return the result
|
||||
records = self.to_records(result.get("source_documents"))
|
||||
references_str = ""
|
||||
if return_source_documents:
|
||||
references_str = self.create_references_from_records(records)
|
||||
result_str = result.get("result")
|
||||
final_result = "\n".join([result_str, references_str])
|
||||
self.status = final_result
|
||||
return final_result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue