fix: update Self Query Retriever Component (#3653)
* 🔧 (pyproject.toml): add lark dependency to support parsing and processing of grammars in the project
  ♻️ (SelfQueryRetriever.py): refactor input types in metadata fields to improve clarity and maintainability
* 📝 (SelfQueryRetriever.py): Update class name and imports for consistency and clarity
  📝 (SelfQueryRetriever.py): Refactor input and output definitions for better readability and maintainability
  📝 (SelfQueryRetriever.py): Refactor method signatures and variable names for improved code organization and understanding
* [autofix.ci] apply automated fixes
* ♻️ (SelfQueryRetriever.py): Remove unused import 'VectorStore' to clean up the code and improve maintainability.
---------
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
076f4f0772
commit
706d559003
3 changed files with 75 additions and 51 deletions
19
poetry.lock
generated
19
poetry.lock
generated
|
|
@@ -5111,6 +5111,23 @@ langchain = ["langchain (>=0.2.0,<0.3.0)"]
|
|||
litellm = ["litellm (>=1.40.15,<2.0.0)"]
|
||||
openai = ["openai (>=1.42.0,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "lark"
|
||||
version = "1.2.2"
|
||||
description = "a modern parsing library"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c"},
|
||||
{file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
atomic-cache = ["atomicwrites"]
|
||||
interegular = ["interegular (>=0.3.1,<0.4.0)"]
|
||||
nearley = ["js2py"]
|
||||
regex = ["regex"]
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.44.8"
|
||||
|
|
@@ -11826,4 +11843,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "d6fd6b327ba3ded7e8eefd2505c0cc6f15d4a5a9f1fd34020dc25324e9f13be1"
|
||||
content-hash = "0be9d1ea13484a0ccf92511c188edaab862ab3b883813efaca2f9bfbbfccd2a8"
|
||||
|
|
|
|||
|
|
@@ -106,6 +106,7 @@ composio-langchain = "^0.5.8"
|
|||
spider-client = "^0.0.27"
|
||||
nltk = "^3.9.1"
|
||||
bson = "^0.5.10"
|
||||
lark = "^1.2.2"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
|
|
|||
|
|
@@ -1,70 +1,76 @@
|
|||
# from langflow.field_typing import Data
|
||||
from typing import List
|
||||
|
||||
from langchain.chains.query_constructor.base import AttributeInfo
|
||||
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.field_typing import LanguageModel, Text
|
||||
from langflow.custom import Component
|
||||
from langflow.inputs import HandleInput, MessageTextInput
|
||||
from langflow.io import Output
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
class SelfQueryRetrieverComponent(CustomComponent):
|
||||
display_name: str = "Self Query Retriever"
|
||||
description: str = "Retriever that uses a vector store and an LLM to generate the vector store queries."
|
||||
class SelfQueryRetrieverComponent(Component):
|
||||
display_name = "Self Query Retriever"
|
||||
description = "Retriever that uses a vector store and an LLM to generate the vector store queries."
|
||||
name = "SelfQueryRetriever"
|
||||
icon = "LangChain"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"query": {
|
||||
"display_name": "Query",
|
||||
"input_types": ["Message", "Text"],
|
||||
"info": "Query to be passed as input.",
|
||||
},
|
||||
"vectorstore": {
|
||||
"display_name": "Vector Store",
|
||||
"info": "Vector Store to be passed as input.",
|
||||
},
|
||||
"attribute_infos": {
|
||||
"display_name": "Metadata Field Info",
|
||||
"info": "Metadata Field Info to be passed as input.",
|
||||
},
|
||||
"document_content_description": {
|
||||
"display_name": "Document Content Description",
|
||||
"info": "Document Content Description to be passed as input.",
|
||||
},
|
||||
"llm": {
|
||||
"display_name": "LLM",
|
||||
"info": "LLM to be passed as input.",
|
||||
},
|
||||
}
|
||||
inputs = [
|
||||
HandleInput(
|
||||
name="query",
|
||||
display_name="Query",
|
||||
info="Query to be passed as input.",
|
||||
input_types=["Message", "Text"],
|
||||
),
|
||||
HandleInput(
|
||||
name="vectorstore",
|
||||
display_name="Vector Store",
|
||||
info="Vector Store to be passed as input.",
|
||||
input_types=["VectorStore"],
|
||||
),
|
||||
HandleInput(
|
||||
name="attribute_infos",
|
||||
display_name="Metadata Field Info",
|
||||
info="Metadata Field Info to be passed as input.",
|
||||
input_types=["Data"],
|
||||
is_list=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="document_content_description",
|
||||
display_name="Document Content Description",
|
||||
info="Document Content Description to be passed as input.",
|
||||
),
|
||||
HandleInput(
|
||||
name="llm",
|
||||
display_name="LLM",
|
||||
info="LLM to be passed as input.",
|
||||
input_types=["LanguageModel"],
|
||||
),
|
||||
]
|
||||
|
||||
def build(
|
||||
self,
|
||||
query: Message,
|
||||
vectorstore: VectorStore,
|
||||
attribute_infos: list[Data],
|
||||
document_content_description: Text,
|
||||
llm: LanguageModel,
|
||||
) -> Data:
|
||||
metadata_field_infos = [AttributeInfo(**value.data) for value in attribute_infos]
|
||||
outputs = [
|
||||
Output(display_name="Retrieved Documents", name="documents", method="retrieve_documents"),
|
||||
]
|
||||
|
||||
def retrieve_documents(self) -> List[Data]:
|
||||
metadata_field_infos = [AttributeInfo(**value.data) for value in self.attribute_infos]
|
||||
self_query_retriever = SelfQueryRetriever.from_llm(
|
||||
llm=llm,
|
||||
vectorstore=vectorstore,
|
||||
document_contents=document_content_description,
|
||||
llm=self.llm,
|
||||
vectorstore=self.vectorstore,
|
||||
document_contents=self.document_content_description,
|
||||
metadata_field_info=metadata_field_infos,
|
||||
enable_limit=True,
|
||||
)
|
||||
|
||||
if isinstance(query, Message):
|
||||
input_text = query.text
|
||||
elif isinstance(query, str):
|
||||
input_text = query
|
||||
if isinstance(self.query, Message):
|
||||
input_text = self.query.text
|
||||
elif isinstance(self.query, str):
|
||||
input_text = self.query
|
||||
else:
|
||||
raise ValueError(f"Query type {type(self.query)} not supported.")
|
||||
|
||||
if not isinstance(query, str):
|
||||
raise ValueError(f"Query type {type(query)} not supported.")
|
||||
documents = self_query_retriever.invoke(input=input_text, config={"callbacks": self.get_langchain_callbacks()})
|
||||
data = [Data.from_document(document) for document in documents]
|
||||
self.status = data
|
||||
return data # type: ignore
|
||||
return data
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue