Feat: introducing Graph RAG component (#7056)

* GraphRAG retriever component, unit tests, module configuration and extra dependencies (faker for testing and langchain-graph-retriever)

* [autofix.ci] apply automated fixes

* 🔧 (test_graph_rag_component.py): Fix linting issues by adding noqa comments to ignore S311 rule for lines with random.choice and random.randint functions.

* [autofix.ci] apply automated fixes

* Updated graph retriever version and added graph component to the same branch

* Removed uv.lock from branch

* Re-added uv.lock

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
Pedro Pacheco 2025-03-18 12:55:49 -06:00 committed by GitHub
commit b5a93b4c55
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 1457 additions and 944 deletions

View file

@ -115,6 +115,8 @@ dependencies = [
"apify-client>=1.8.1",
"pylint>=3.3.4",
"ruff>=0.9.7",
"langchain-graph-retriever==0.6.1",
"graph-retriever==0.6.1",
]
[dependency-groups]
@ -158,6 +160,7 @@ dev = [
"hypothesis>=6.123.17",
"locust>=2.32.9",
"pytest-rerunfailures>=15.0",
"faker>=37.0.0",
]
[tool.uv.sources]

View file

@ -7,6 +7,7 @@ from .clickhouse import ClickhouseVectorStoreComponent
from .couchbase import CouchbaseVectorStoreComponent
from .elasticsearch import ElasticsearchVectorStoreComponent
from .faiss import FaissVectorStoreComponent
from .graph_rag import GraphRAGComponent
from .hcd import HCDVectorStoreComponent
from .milvus import MilvusVectorStoreComponent
from .mongodb_atlas import MongoVectorStoreComponent
@ -32,6 +33,7 @@ __all__ = [
"CouchbaseVectorStoreComponent",
"ElasticsearchVectorStoreComponent",
"FaissVectorStoreComponent",
"GraphRAGComponent",
"HCDVectorStoreComponent",
"MilvusVectorStoreComponent",
"MongoVectorStoreComponent",

View file

@ -0,0 +1,141 @@
import inspect
from abc import ABC
import graph_retriever.strategies as strategies_module
from langchain_graph_retriever import GraphRetriever
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers import docs_to_data
from langflow.inputs import DropdownInput, HandleInput, MultilineInput, NestedDictInput, StrInput
from langflow.schema import Data
def traversal_strategies() -> list[str]:
    """Collect the names of the concrete strategy classes in ``strategies_module``.

    Enumerates every class member of ``strategies_module`` via ``inspect`` and keeps
    only those that are not direct subclasses of ``ABC`` (i.e. skips abstract bases).

    Returns:
        list[str]: Names of the concrete traversal strategy classes.
    """
    return [
        class_name
        for class_name, strategy_cls in inspect.getmembers(strategies_module, inspect.isclass)
        if ABC not in strategy_cls.__bases__
    ]
class GraphRAGComponent(LCVectorStoreComponent):
    """GraphRAGComponent is a component for performing Graph RAG traversal in a vector store.

    Attributes:
        display_name (str): The display name of the component.
        description (str): A brief description of the component.
        name (str): The name of the component.
        icon (str): The icon representing the component.
        inputs (list): A list of input configurations for the component.

    Methods:
        search_documents() -> list[Data]:
            Searches for documents using the specified strategy, edge definition, and query.
        _edge_definition_from_input() -> tuple:
            Parses the comma-separated edge definition input into a tuple.
        _evaluate_edge_definition_input() -> tuple:
            Converts recognized function-call strings (e.g. "Id()") in the edge
            definition into their object equivalents.
    """

    display_name: str = "Graph RAG"
    description: str = "Graph RAG traversal for vector store."
    name = "Graph RAG"
    icon: str = "AstraDB"

    inputs = [
        HandleInput(
            name="embedding_model",
            display_name="Embedding Model",
            input_types=["Embeddings"],
            info="Specify the Embedding Model. Not required for Astra Vectorize collections.",
            required=False,
        ),
        HandleInput(
            name="vector_store",
            display_name="Vector Store Connection",
            input_types=["VectorStore"],
            info="Connection to Vector Store.",
        ),
        StrInput(
            name="edge_definition",
            display_name="Edge Definition",
            info="Edge definition for the graph traversal.",
        ),
        DropdownInput(
            name="strategy",
            display_name="Traversal Strategies",
            # Options are discovered dynamically from the installed graph_retriever package.
            options=traversal_strategies(),
        ),
        MultilineInput(
            name="search_query",
            display_name="Search Query",
            tool_mode=True,
        ),
        NestedDictInput(
            name="graphrag_strategy_kwargs",
            display_name="Strategy Parameters",
            info=(
                "Optional dictionary of additional parameters for the retrieval strategy. "
                "Please see https://datastax.github.io/graph-rag/reference/graph_retriever/strategies/ for details."
            ),
            advanced=True,
        ),
    ]

    def search_documents(self) -> list[Data]:
        """Search for documents using the graph retriever built from the selected strategy, edge definition, and query.

        Returns:
            list[Data]: The retrieved documents converted to Data objects.

        Raises:
            AttributeError: If the selected strategy name does not exist in the strategies module.
            TypeError: If the strategy kwargs do not match the strategy's signature.
        """
        additional_params = self.graphrag_strategy_kwargs or {}

        # Resolve the strategy class by name, then invoke the graph retriever
        # with the evaluated edge definition and the user's query.
        strategy_class = getattr(strategies_module, self.strategy)
        retriever = GraphRetriever(
            store=self.vector_store,
            edges=[self._evaluate_edge_definition_input()],
            strategy=strategy_class(**additional_params),
        )

        return docs_to_data(retriever.invoke(self.search_query))

    def _edge_definition_from_input(self) -> tuple:
        """Parse the comma-separated edge definition input.

        Returns:
            tuple: The whitespace-stripped parts of the edge definition.
        """
        return tuple(value.strip() for value in self.edge_definition.split(","))

    def _evaluate_edge_definition_input(self) -> tuple:
        """Evaluate the edge definition, converting recognized function-call strings to objects.

        Currently only the literal string "Id()" is recognized and replaced with an
        ``Id`` instance; every other part is kept as a plain string.

        Returns:
            tuple: The evaluated edge definition.
        """
        # Imported lazily so the module can be loaded without graph_retriever's
        # edge helpers being touched at import time.
        from graph_retriever.edges.metadata import Id

        evaluated_values = []
        for value in self._edge_definition_from_input():
            if value == "Id()":
                evaluated_values.append(Id())  # Evaluate Id() as a function call
            else:
                evaluated_values.append(value)
        return tuple(evaluated_values)

View file

@ -0,0 +1,164 @@
import random
import pytest
from faker import Faker
from langchain_community.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.documents import Document
from langchain_core.vectorstores.in_memory import InMemoryVectorStore
from langflow.components.vectorstores.graph_rag import GraphRAGComponent
from tests.base import ComponentTestBaseWithoutClient
class TestGraphRAGComponent(ComponentTestBaseWithoutClient):
    """Test suite for the GraphRAGComponent class, focusing on graph traversal and retrieval functionality.

    Fixtures:
        component_class: Returns the GraphRAGComponent class to be tested.
        animals: Provides a list of Document objects representing various animals with metadata.
        embedding: Provides a DeterministicFakeEmbedding instance with a specified size.
        vector_store: Initializes an InMemoryVectorStore with the provided animals and embedding.
        file_names_mapping: Yields None since this component doesn't have version-specific files.
        default_kwargs: Provides the default strategy parameters (k, start_k, max_depth) for the traversal.

    Test Cases:
        test_graphrag: Tests the search_documents method of the GraphRAGComponent class by setting attributes and
        verifying the number of results returned.
    """
    @pytest.fixture
    def component_class(self):
        """Return the component class to test."""
        return GraphRAGComponent
    @pytest.fixture
    def animals(self, n: int = 20, match_prob: float = 0.3) -> list[Document]:
        """Animals dataset for testing.

        Generate a list of animal-related document objects with random metadata.
        (pytest ignores parameters with defaults when resolving fixtures, so ``n``
        and ``match_prob`` keep their default values here.)

        Parameters:
            n (int): Number of documents to generate.
            match_prob (float): Probability of sharing metadata across documents.

        Returns:
            List[Document]: A list of generated Document objects.
        """
        # Initialize Faker for generating random text
        fake = Faker()
        # Seed both RNGs so the dataset is deterministic across test runs.
        random.seed(42)
        fake.seed_instance(42)
        # Define possible attributes for animals
        animal_types = ["mammal", "bird", "reptile", "insect"]
        habitats = ["savanna", "marine", "wetlands", "forest", "desert"]
        diets = ["carnivorous", "herbivorous", "omnivorous"]
        origins = ["north america", "south america", "africa", "asia", "australia"]
        shared_metadata = {}  # Common metadata that may be shared across documents
        def update_metadata(meta: dict) -> dict:
            """Modify metadata based on predefined conditions and probability."""
            if random.random() < match_prob:  # noqa: S311
                meta.update(shared_metadata)  # Apply shared metadata
            elif meta["type"] == "mammal":
                meta["habitat"] = random.choice(habitats)  # noqa: S311
            elif meta["type"] == "reptile":
                meta["diet"] = random.choice(diets)  # noqa: S311
            elif meta["type"] == "insect":
                meta["origin"] = random.choice(origins)  # noqa: S311
            return meta
        # Generate and return a list of documents
        return [
            Document(
                id=fake.uuid4(),
                page_content=fake.sentence(),
                metadata=update_metadata(
                    {
                        "type": random.choice(animal_types),  # noqa: S311
                        "number_of_legs": random.choice([0, 2, 4, 6, 8]),  # noqa: S311
                        "keywords": fake.words(random.randint(2, 5)),  # noqa: S311
                        # Add optional tags with 30% probability
                        **(
                            {
                                "tags": [
                                    {"a": random.randint(1, 10), "b": random.randint(1, 10)}  # noqa: S311
                                    for _ in range(random.randint(1, 2))  # noqa: S311
                                ]
                            }
                            if random.random() < 0.3  # noqa: S311
                            else {}
                        ),
                        # Add nested metadata with 20% probability
                        **({"nested": {"a": random.randint(1, 10)}} if random.random() < 0.2 else {}),  # noqa: S311
                    }
                ),
            )
            for _ in range(n)
        ]
    @pytest.fixture
    def embedding(self):
        """Return a deterministic fake embedding model (fixed 8-dim vectors)."""
        return DeterministicFakeEmbedding(size=8)
    @pytest.fixture
    def vector_store(self, animals: list[Document], embedding: DeterministicFakeEmbedding) -> InMemoryVectorStore:
        """Return an InMemoryVectorStore populated with the animal documents."""
        store = InMemoryVectorStore(embedding=embedding)
        store.add_documents(animals)
        return store
    @pytest.fixture
    def file_names_mapping(self):
        """Yield None since this component doesn't have version-specific files."""
    @pytest.fixture
    def default_kwargs(self):
        """Return the default strategy parameters passed to the traversal strategy."""
        return {"k": 10, "start_k": 3, "max_depth": 2}
    def test_graphrag(
        self,
        component_class: GraphRAGComponent,
        embedding: DeterministicFakeEmbedding,
        vector_store: InMemoryVectorStore,
        default_kwargs,
    ):
        """Test GraphRAGComponent's document search functionality.

        This test verifies that the component correctly retrieves documents using the
        provided embedding model, vector store, and search query.

        Args:
            component_class (GraphRAGComponent): The component class to test.
            embedding (DeterministicFakeEmbedding): The embedding model for the component.
            vector_store (InMemoryVectorStore): The vector store used in retrieval.
            default_kwargs (dict): Default keyword arguments for the retrieval strategy.

        Returns:
            None: The test asserts that 10 search results are returned.
        """
        component = component_class()
        component.set_attributes(
            {
                "embedding_model": embedding,
                "vector_store": vector_store,
                "edge_definition": "type, type",
                "strategy": "Eager",
                "search_query": "information environment technology",
                "graphrag_strategy_kwargs": default_kwargs,
            }
        )
        results = component.search_documents()
        # Quantity of documents
        assert len(results) == 10
        # Ensures all the k-start_k documents returned via traversal have the same metadata as the
        # ones returned via the similarity search
        assert list({doc.data["type"] for doc in results if doc.data["_depth"] == 0}) == list(
            {doc.data["type"] for doc in results if doc.data["_depth"] >= 1}
        )

2091
uv.lock generated

File diff suppressed because it is too large Load diff