Refactor RetrievalQAComponent and RetrievalQAWithSourcesChainComponent
This commit is contained in:
parent
0d3435efeb
commit
39e7a92f45
2 changed files with 15 additions and 8 deletions
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.retrieval_qa.base import RetrievalQA
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.field_typing import BaseMemory, BaseRetriever, Text
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, BaseRetriever, Text
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.schema.schema import Record
|
||||
|
||||
|
||||
class RetrievalQAComponent(CustomComponent):
|
||||
|
|
@ -14,7 +14,8 @@ class RetrievalQAComponent(CustomComponent):
|
|||
|
||||
def build_config(self):
|
||||
return {
|
||||
"combine_documents_chain": {"display_name": "Combine Documents Chain"},
|
||||
"llm": {"display_name": "LLM"},
|
||||
"chain_type": {"display_name": "Chain Type", "options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"]},
|
||||
"retriever": {"display_name": "Retriever"},
|
||||
"memory": {"display_name": "Memory", "required": False},
|
||||
"input_key": {"display_name": "Input Key", "advanced": True},
|
||||
|
|
@ -22,13 +23,14 @@ class RetrievalQAComponent(CustomComponent):
|
|||
"return_source_documents": {"display_name": "Return Source Documents"},
|
||||
"input_value": {
|
||||
"display_name": "Input",
|
||||
"input_types": ["Text", "Document"],
|
||||
"input_types": ["Record", "Document"],
|
||||
},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
combine_documents_chain: BaseCombineDocumentsChain,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str,
|
||||
retriever: BaseRetriever,
|
||||
input_value: str = "",
|
||||
memory: Optional[BaseMemory] = None,
|
||||
|
|
@ -36,8 +38,10 @@ class RetrievalQAComponent(CustomComponent):
|
|||
output_key: str = "result",
|
||||
return_source_documents: bool = True,
|
||||
) -> Text:
|
||||
runnable = RetrievalQA(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
chain_type = chain_type.lower().replace(" ", "_")
|
||||
runnable = RetrievalQA.from_chain_type(
|
||||
llm=llm,
|
||||
chain_type=chain_type,
|
||||
retriever=retriever,
|
||||
memory=memory,
|
||||
input_key=input_key,
|
||||
|
|
@ -46,6 +50,8 @@ class RetrievalQAComponent(CustomComponent):
|
|||
)
|
||||
if isinstance(input_value, Document):
|
||||
input_value = input_value.page_content
|
||||
if isinstance(input_value, Record):
|
||||
input_value = input_value.get_text()
|
||||
self.status = runnable
|
||||
result = runnable.invoke({input_key: input_value})
|
||||
result = result.content if hasattr(result, "content") else result
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
|
|||
"llm": {"display_name": "LLM"},
|
||||
"chain_type": {
|
||||
"display_name": "Chain Type",
|
||||
"options": ["stuff", "map_reduce", "map_rerank", "refine"],
|
||||
"options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"],
|
||||
"info": "The type of chain to use to combined Documents.",
|
||||
},
|
||||
"memory": {"display_name": "Memory"},
|
||||
|
|
@ -37,6 +37,7 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
|
|||
memory: Optional[BaseMemory] = None,
|
||||
return_source_documents: Optional[bool] = True,
|
||||
) -> Text:
|
||||
chain_type = chain_type.lower().replace(" ", "_")
|
||||
runnable = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=llm,
|
||||
chain_type=chain_type,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue