refactor: Update SelfQueryRetrieverComponent build method to handle different input types

This commit is contained in:
ogabrielluiz 2024-06-10 14:29:16 -03:00
commit f872a3e753

View file

@ -4,7 +4,7 @@ from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.vectorstores import VectorStore
from langflow.custom import CustomComponent
from langflow.field_typing import BaseLanguageModel
from langflow.field_typing import BaseLanguageModel, Text
from langflow.schema import Record
from langflow.schema.message import Message
@ -14,25 +14,54 @@ class SelfQueryRetrieverComponent(CustomComponent):
description: str = "Retriever that uses a vector store and an LLM to generate the vector store queries."
icon = "LangChain"
def build_config(self):
return {
"query": {
"display_name": "Query",
"input_types": ["Message", "Text"],
"info": "Query to be passed as input.",
},
"vectorstore": {
"display_name": "Vector Store",
"info": "Vector Store to be passed as input.",
},
"attribute_infos": {
"display_name": "Metadata Field Info",
"info": "Metadata Field Info to be passed as input.",
},
"document_content_description": {
"display_name": "Document Content Description",
"info": "Document Content Description to be passed as input.",
},
"llm": {
"display_name": "LLM",
"info": "LLM to be passed as input.",
},
}
def build(
self,
query: Message,
vectorstore: VectorStore,
metadata_field_info: list[AttributeInfo],
document_content_description: str,
attribute_infos: list[Record],
document_content_description: Text,
llm: BaseLanguageModel,
) -> Record:
metadata_field_info = [i[0] for i in metadata_field_info]
metadata_field_infos = [AttributeInfo(**record.data) for record in attribute_infos]
self_query_retriever = SelfQueryRetriever.from_llm(
llm,
vectorstore,
document_content_description,
metadata_field_info,
llm=llm,
vectorstore=vectorstore,
document_contents=document_content_description,
metadata_field_info=metadata_field_infos,
enable_limit=True,
)
input_text = query.text
if isinstance(query, Message):
input_text = query.text
elif isinstance(query, str):
input_text = query
else:
raise ValueError(f"Query type {type(query)} not supported.")
documents = self_query_retriever.invoke(input=input_text)
records = [Record.from_document(document) for document in documents]
self.status = records