Feat: introducing Graph RAG component (#7056)
* GraphRAG retriever componet, unit test, module confiugration and extra dependencies (faker for testing and langchain-graph-retriever
* [autofix.ci] apply automated fixes
* 🔧 (test_graph_rag_component.py): Fix linting issues by adding noqa comments to ignore S311 rule for lines with random.choice and random.randint functions.
* [autofix.ci] apply automated fixes
* Updated graph retriever version and added graph component to the same branch
* Removed uv.lock from branch
* Re-added uv.lock
---------
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
parent
ee43c51297
commit
b5a93b4c55
5 changed files with 1457 additions and 944 deletions
|
|
@ -115,6 +115,8 @@ dependencies = [
|
|||
"apify-client>=1.8.1",
|
||||
"pylint>=3.3.4",
|
||||
"ruff>=0.9.7",
|
||||
"langchain-graph-retriever==0.6.1",
|
||||
"graph-retriever==0.6.1",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
@ -158,6 +160,7 @@ dev = [
|
|||
"hypothesis>=6.123.17",
|
||||
"locust>=2.32.9",
|
||||
"pytest-rerunfailures>=15.0",
|
||||
"faker>=37.0.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from .clickhouse import ClickhouseVectorStoreComponent
|
|||
from .couchbase import CouchbaseVectorStoreComponent
|
||||
from .elasticsearch import ElasticsearchVectorStoreComponent
|
||||
from .faiss import FaissVectorStoreComponent
|
||||
from .graph_rag import GraphRAGComponent
|
||||
from .hcd import HCDVectorStoreComponent
|
||||
from .milvus import MilvusVectorStoreComponent
|
||||
from .mongodb_atlas import MongoVectorStoreComponent
|
||||
|
|
@ -32,6 +33,7 @@ __all__ = [
|
|||
"CouchbaseVectorStoreComponent",
|
||||
"ElasticsearchVectorStoreComponent",
|
||||
"FaissVectorStoreComponent",
|
||||
"GraphRAGComponent",
|
||||
"HCDVectorStoreComponent",
|
||||
"MilvusVectorStoreComponent",
|
||||
"MongoVectorStoreComponent",
|
||||
|
|
|
|||
141
src/backend/base/langflow/components/vectorstores/graph_rag.py
Normal file
141
src/backend/base/langflow/components/vectorstores/graph_rag.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
import inspect
|
||||
from abc import ABC
|
||||
|
||||
import graph_retriever.strategies as strategies_module
|
||||
from langchain_graph_retriever import GraphRetriever
|
||||
|
||||
from langflow.base.vectorstores.model import LCVectorStoreComponent
|
||||
from langflow.helpers import docs_to_data
|
||||
from langflow.inputs import DropdownInput, HandleInput, MultilineInput, NestedDictInput, StrInput
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
def traversal_strategies() -> list[str]:
|
||||
"""Retrieves a list of class names from the strategies_module.
|
||||
|
||||
This function uses the `inspect` module to get all the class members
|
||||
from the `strategies_module` and returns their names as a list of strings.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of strategy class names.
|
||||
"""
|
||||
classes = inspect.getmembers(strategies_module, inspect.isclass)
|
||||
return [name for name, cls in classes if ABC not in cls.__bases__]
|
||||
|
||||
|
||||
class GraphRAGComponent(LCVectorStoreComponent):
|
||||
"""GraphRAGComponent is a component for performing Graph RAG traversal in a vector store.
|
||||
|
||||
Attributes:
|
||||
display_name (str): The display name of the component.
|
||||
description (str): A brief description of the component.
|
||||
name (str): The name of the component.
|
||||
icon (str): The icon representing the component.
|
||||
inputs (list): A list of input configurations for the component.
|
||||
|
||||
Methods:
|
||||
_build_search_args():
|
||||
Builds the arguments required for the search operation.
|
||||
search_documents() -> list[Data]:
|
||||
Searches for documents using the specified strategy, edge definition, and query.
|
||||
_edge_definition_from_input() -> tuple:
|
||||
Processes the edge definition input and returns it as a tuple.
|
||||
"""
|
||||
|
||||
display_name: str = "Graph RAG"
|
||||
description: str = "Graph RAG traversal for vector store."
|
||||
name = "Graph RAG"
|
||||
icon: str = "AstraDB"
|
||||
|
||||
inputs = [
|
||||
HandleInput(
|
||||
name="embedding_model",
|
||||
display_name="Embedding Model",
|
||||
input_types=["Embeddings"],
|
||||
info="Specify the Embedding Model. Not required for Astra Vectorize collections.",
|
||||
required=False,
|
||||
),
|
||||
HandleInput(
|
||||
name="vector_store",
|
||||
display_name="Vector Store Connection",
|
||||
input_types=["VectorStore"],
|
||||
info="Connection to Vector Store.",
|
||||
),
|
||||
StrInput(
|
||||
name="edge_definition",
|
||||
display_name="Edge Definition",
|
||||
info="Edge definition for the graph traversal.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="strategy",
|
||||
display_name="Traversal Strategies",
|
||||
options=traversal_strategies(),
|
||||
),
|
||||
MultilineInput(
|
||||
name="search_query",
|
||||
display_name="Search Query",
|
||||
tool_mode=True,
|
||||
),
|
||||
NestedDictInput(
|
||||
name="graphrag_strategy_kwargs",
|
||||
display_name="Strategy Parameters",
|
||||
info=(
|
||||
"Optional dictionary of additional parameters for the retrieval strategy. "
|
||||
"Please see https://datastax.github.io/graph-rag/reference/graph_retriever/strategies/ for details."
|
||||
),
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
def search_documents(self) -> list[Data]:
|
||||
"""Searches for documents using the graph retriever based on the selected strategy, edge definition, and query.
|
||||
|
||||
Returns:
|
||||
list[Data]: A list of retrieved documents.
|
||||
|
||||
Raises:
|
||||
AttributeError: If there is an issue with attribute access.
|
||||
TypeError: If there is a type mismatch.
|
||||
ValueError: If there is a value error.
|
||||
"""
|
||||
additional_params = self.graphrag_strategy_kwargs or {}
|
||||
|
||||
# Invoke the graph retriever based on the selected strategy, edge definition, and query
|
||||
strategy_class = getattr(strategies_module, self.strategy)
|
||||
retriever = GraphRetriever(
|
||||
store=self.vector_store,
|
||||
edges=[self._evaluate_edge_definition_input()],
|
||||
strategy=strategy_class(**additional_params),
|
||||
)
|
||||
|
||||
return docs_to_data(retriever.invoke(self.search_query))
|
||||
|
||||
def _edge_definition_from_input(self) -> tuple:
|
||||
"""Generates the edge definition from the input data.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple representing the edge definition.
|
||||
"""
|
||||
values = self.edge_definition.split(",")
|
||||
values = [value.strip() for value in values]
|
||||
|
||||
return tuple(values)
|
||||
|
||||
def _evaluate_edge_definition_input(self) -> tuple:
|
||||
from graph_retriever.edges.metadata import Id
|
||||
|
||||
"""Evaluates the edge definition, converting any function calls from strings.
|
||||
|
||||
Args:
|
||||
edge_definition (tuple): The edge definition to evaluate.
|
||||
|
||||
Returns:
|
||||
tuple: The evaluated edge definition.
|
||||
"""
|
||||
evaluated_values = []
|
||||
for value in self._edge_definition_from_input():
|
||||
if value == "Id()":
|
||||
evaluated_values.append(Id()) # Evaluate Id() as a function call
|
||||
else:
|
||||
evaluated_values.append(value)
|
||||
return tuple(evaluated_values)
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
import random
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from langchain_community.embeddings.fake import DeterministicFakeEmbedding
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.vectorstores.in_memory import InMemoryVectorStore
|
||||
from langflow.components.vectorstores.graph_rag import GraphRAGComponent
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
|
||||
|
||||
class TestGraphRAGComponent(ComponentTestBaseWithoutClient):
|
||||
"""Test suite for the GraphRAGComponent class, focusing on graph traversal and retrieval functionality.
|
||||
|
||||
Fixtures:
|
||||
component_class: Returns the GraphRAGComponent class to be tested.
|
||||
animals: Provides a list of Document objects representing various animals with metadata.
|
||||
embedding: Provides a FakeEmbeddings instance with a specified size.
|
||||
vector_store: Initializes an InMemoryVectorStore with the provided animals and embedding.
|
||||
file_names_mapping: Returns an empty list since this component doesn't have version-specific files.
|
||||
default_kwargs: Returns an empty dictionary since this component doesn't have any default arguments.
|
||||
|
||||
Test Cases:
|
||||
test_graphrag: Tests the search_documents method of the GraphRAGComponent class by setting attributes and
|
||||
verifying the number of results returned.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
"""Return the component class to test."""
|
||||
return GraphRAGComponent
|
||||
|
||||
@pytest.fixture
|
||||
def animals(self, n: int = 20, match_prob: float = 0.3) -> list[Document]:
|
||||
"""Animals dataset for testing.
|
||||
|
||||
Generate a list of animal-related document objects with random metadata.
|
||||
|
||||
Parameters:
|
||||
n (int): Number of documents to generate.
|
||||
match_prob (float): Probability of sharing metadata across documents.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of generated Document objects.
|
||||
"""
|
||||
# Initialize Faker for generating random text
|
||||
fake = Faker()
|
||||
random.seed(42)
|
||||
fake.seed_instance(42)
|
||||
|
||||
# Define possible attributes for animals
|
||||
animal_types = ["mammal", "bird", "reptile", "insect"]
|
||||
habitats = ["savanna", "marine", "wetlands", "forest", "desert"]
|
||||
diets = ["carnivorous", "herbivorous", "omnivorous"]
|
||||
origins = ["north america", "south america", "africa", "asia", "australia"]
|
||||
|
||||
shared_metadata = {} # Common metadata that may be shared across documents
|
||||
|
||||
def update_metadata(meta: dict) -> dict:
|
||||
"""Modify metadata based on predefined conditions and probability."""
|
||||
if random.random() < match_prob: # noqa: S311
|
||||
meta.update(shared_metadata) # Apply shared metadata
|
||||
elif meta["type"] == "mammal":
|
||||
meta["habitat"] = random.choice(habitats) # noqa: S311
|
||||
elif meta["type"] == "reptile":
|
||||
meta["diet"] = random.choice(diets) # noqa: S311
|
||||
elif meta["type"] == "insect":
|
||||
meta["origin"] = random.choice(origins) # noqa: S311
|
||||
return meta
|
||||
|
||||
# Generate and return a list of documents
|
||||
return [
|
||||
Document(
|
||||
id=fake.uuid4(),
|
||||
page_content=fake.sentence(),
|
||||
metadata=update_metadata(
|
||||
{
|
||||
"type": random.choice(animal_types), # noqa: S311
|
||||
"number_of_legs": random.choice([0, 2, 4, 6, 8]), # noqa: S311
|
||||
"keywords": fake.words(random.randint(2, 5)), # noqa: S311
|
||||
# Add optional tags with 30% probability
|
||||
**(
|
||||
{
|
||||
"tags": [
|
||||
{"a": random.randint(1, 10), "b": random.randint(1, 10)} # noqa: S311
|
||||
for _ in range(random.randint(1, 2)) # noqa: S311
|
||||
]
|
||||
}
|
||||
if random.random() < 0.3 # noqa: S311
|
||||
else {}
|
||||
),
|
||||
# Add nested metadata with 20% probability
|
||||
**({"nested": {"a": random.randint(1, 10)}} if random.random() < 0.2 else {}), # noqa: S311
|
||||
}
|
||||
),
|
||||
)
|
||||
for _ in range(n)
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def embedding(self):
|
||||
return DeterministicFakeEmbedding(size=8)
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(self, animals: list[Document], embedding: DeterministicFakeEmbedding) -> InMemoryVectorStore:
|
||||
"""Return an empty list since this component doesn't have version-specific files."""
|
||||
store = InMemoryVectorStore(embedding=embedding)
|
||||
store.add_documents(animals)
|
||||
return store
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self):
|
||||
"""Return an empty list since this component doesn't have version-specific files."""
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self):
|
||||
"""Return an empty dictionary since this component doesn't have any default arguments."""
|
||||
return {"k": 10, "start_k": 3, "max_depth": 2}
|
||||
|
||||
def test_graphrag(
|
||||
self,
|
||||
component_class: GraphRAGComponent,
|
||||
embedding: DeterministicFakeEmbedding,
|
||||
vector_store: InMemoryVectorStore,
|
||||
default_kwargs,
|
||||
):
|
||||
"""Test GraphRAGComponent's document search functionality.
|
||||
|
||||
This test verifies that the component correctly retrieves documents using the
|
||||
provided embedding model, vector store, and search query.
|
||||
|
||||
Args:
|
||||
component_class (GraphRAGComponent): The component class to test.
|
||||
embedding (FakeEmbeddings): The embedding model for the component.
|
||||
vector_store (InMemoryVectorStore): The vector store used in retrieval.
|
||||
default_kwargs (dict): Default keyword arguments for the retrieval strategy.
|
||||
|
||||
Returns:
|
||||
None: The test asserts that 10 search results are returned.
|
||||
"""
|
||||
component = component_class()
|
||||
|
||||
component.set_attributes(
|
||||
{
|
||||
"embedding_model": embedding,
|
||||
"vector_store": vector_store,
|
||||
"edge_definition": "type, type",
|
||||
"strategy": "Eager",
|
||||
"search_query": "information environment technology",
|
||||
"graphrag_strategy_kwargs": default_kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
results = component.search_documents()
|
||||
|
||||
# Quantity of documents
|
||||
assert len(results) == 10
|
||||
|
||||
# Ensures all the k-start_k documents returned via traversal have the same metadata as the
|
||||
# ones returned via the similarity search
|
||||
assert list({doc.data["type"] for doc in results if doc.data["_depth"] == 0}) == list(
|
||||
{doc.data["type"] for doc in results if doc.data["_depth"] >= 1}
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue