feat: Add CassandraGraphVectorStoreComponent and HtmlLinkExtractorComponent (#3757)

* Add CassandraGraphVectorStoreComponent and HtmlLinkExtractorComponent

* Move uuid import to global imports

* fix test with new text splitter

* update poetry lock

* ci: add continue-on-error to py_autofix.yml

---------

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>
Co-authored-by: anovazzi1 <otavio2204@gmail.com>
This commit is contained in:
Christophe Bornet 2024-09-11 19:56:25 +02:00 committed by GitHub
commit 7273a6e78a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1008 additions and 666 deletions

View file

@ -42,3 +42,4 @@ jobs:
- uses: autofix-ci/action@dd55f44df8f7cdb7a6bf74c78677eb8acd40cd0a
- name: Diff poetry.lock
uses: nborrmann/diff-poetry-lock@main
continue-on-error: true

1296
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -111,6 +111,7 @@ jq = "^1.8.0"
clickhouse-connect = {version = "0.7.19", optional = true, extras = ["clickhouse-connect"]}
langchain-unstructured = "^0.1.2"
pydantic-settings = "2.4.0"
ragstack-ai-knowledge-store = "^0.2.1"
[tool.poetry.group.dev.dependencies]

View file

@ -0,0 +1,49 @@
from abc import abstractmethod
from typing import Any
from langchain_core.documents import BaseDocumentTransformer
from langflow.custom import Component
from langflow.io import Output
from langflow.schema import Data
from langflow.utils.util import build_loader_repr_from_data
class LCDocumentTransformerComponent(Component):
    """Base component wrapping a LangChain ``BaseDocumentTransformer``.

    Subclasses provide the input data and the concrete transformer; this base
    class handles converting Data objects to LC documents, running the
    transformer, and converting the result back to Data.
    """

    trace_type = "document_transformer"
    outputs = [
        Output(display_name="Data", name="data", method="transform_data"),
    ]

    def transform_data(self) -> list[Data]:
        """Run the transformer over the input and return the transformed Data.

        Returns:
            The transformed documents converted back to ``Data`` objects.
        """
        data_input = self.get_data_input()
        if not isinstance(data_input, list):
            data_input = [data_input]
        # Data objects are converted to LangChain documents; anything else is
        # assumed to already be a document and is passed through unchanged.
        documents = [
            _input.to_lc_document() if isinstance(_input, Data) else _input
            for _input in data_input
        ]
        transformer = self.build_document_transformer()
        docs = transformer.transform_documents(documents)
        data = self.to_data(docs)
        self.repr_value = build_loader_repr_from_data(data)
        return data

    @abstractmethod
    def get_data_input(self) -> Any:
        """Get the data input."""

    @abstractmethod
    def build_document_transformer(self) -> BaseDocumentTransformer:
        """Build the document transformer."""

View file

@ -1,15 +1,14 @@
from abc import abstractmethod
from typing import Any
from langchain_core.documents import BaseDocumentTransformer
from langchain_text_splitters import TextSplitter
from langflow.custom import Component
from langflow.base.document_transformers.model import LCDocumentTransformerComponent
from langflow.io import Output
from langflow.schema import Data
from langflow.utils.util import build_loader_repr_from_data
class LCTextSplitterComponent(Component):
class LCTextSplitterComponent(LCDocumentTransformerComponent):
trace_type = "text_splitter"
outputs = [
Output(display_name="Data", name="data", method="split_data"),
@ -25,30 +24,10 @@ class LCTextSplitterComponent(Component):
raise ValueError(f"Method '{method_name}' must be defined.")
def split_data(self) -> list[Data]:
data_input = self.get_data_input()
documents = []
return self.transform_data()
if not isinstance(data_input, list):
data_input = [data_input]
for _input in data_input:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
splitter = self.build_text_splitter()
docs = splitter.split_documents(documents)
data = self.to_data(docs)
self.repr_value = build_loader_repr_from_data(data)
return data
@abstractmethod
def get_data_input(self) -> Any:
"""
Get the data input.
"""
pass
def build_document_transformer(self) -> BaseDocumentTransformer:
return self.build_text_splitter()
@abstractmethod
def build_text_splitter(self) -> TextSplitter:

View file

@ -5,6 +5,7 @@ from . import (
embeddings,
helpers,
inputs,
link_extractors,
memories,
models,
outputs,
@ -27,6 +28,7 @@ __all__ = [
"models",
"helpers",
"inputs",
"link_extractors",
"memories",
"outputs",
"retrievers",

View file

@ -0,0 +1,33 @@
from typing import Any
from langchain_community.graph_vectorstores.extractors import LinkExtractorTransformer, HtmlLinkExtractor
from langchain_core.documents import BaseDocumentTransformer
from langflow.base.document_transformers.model import LCDocumentTransformerComponent
from langflow.inputs import DataInput, StrInput, BoolInput
class HtmlLinkExtractorComponent(LCDocumentTransformerComponent):
    """Document transformer component that tags documents with hyperlink edges."""

    display_name = "HTML Link Extractor"
    description = "Extract hyperlinks from HTML content."
    documentation = "https://python.langchain.com/v0.2/api_reference/community/graph_vectorstores/langchain_community.graph_vectorstores.extractors.html_link_extractor.HtmlLinkExtractor.html"
    name = "HtmlLinkExtractor"

    inputs = [
        StrInput(name="kind", display_name="Kind of edge", value="hyperlink", required=False),
        BoolInput(name="drop_fragments", display_name="Drop URL fragments", value=True, required=False),
        DataInput(
            name="data_input",
            display_name="Input",
            info="The texts from which to extract links.",
            input_types=["Document", "Data"],
        ),
    ]

    def get_data_input(self) -> Any:
        # The documents to scan for hyperlinks come straight from the data input.
        return self.data_input

    def build_document_transformer(self) -> BaseDocumentTransformer:
        """Create the transformer that annotates each document with link metadata."""
        extractor = HtmlLinkExtractor(kind=self.kind, drop_fragments=self.drop_fragments)
        return LinkExtractorTransformer([extractor.as_document_extractor()])

View file

@ -0,0 +1,5 @@
# Re-export the link extractor components as the package's public API.
from .HtmlLinkExtractor import HtmlLinkExtractorComponent

__all__ = [
    "HtmlLinkExtractorComponent",
]

View file

@ -0,0 +1,249 @@
from typing import List
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from loguru import logger
from uuid import UUID
from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.inputs import DictInput, FloatInput
from langflow.io import (
DataInput,
DropdownInput,
HandleInput,
IntInput,
MessageTextInput,
MultilineInput,
SecretStrInput,
)
from langflow.schema import Data
class CassandraGraphVectorStoreComponent(LCVectorStoreComponent):
    """Langflow component for a Cassandra / Astra DB graph vector store.

    Connects either to a Cassandra cluster (contact points) or to Astra DB
    (database ID + token), optionally ingests documents, and exposes
    traversal/MMR/similarity search over the stored vectors.
    """

    display_name = "Cassandra Graph"
    description = "Cassandra Graph Vector Store"
    documentation = "https://python.langchain.com/v0.2/api_reference/community/graph_vectorstores.html"
    name = "CassandraGraph"
    icon = "Cassandra"

    inputs = [
        MessageTextInput(
            name="database_ref",
            display_name="Contact Points / Astra Database ID",
            info="Contact points for the database (or AstraDB database ID)",
            required=True,
        ),
        MessageTextInput(
            name="username", display_name="Username", info="Username for the database (leave empty for AstraDB)."
        ),
        SecretStrInput(
            name="token",
            display_name="Password / AstraDB Token",
            info="User password for the database (or AstraDB token).",
            required=True,
        ),
        MessageTextInput(
            name="keyspace",
            display_name="Keyspace",
            info="Table Keyspace (or AstraDB namespace).",
            required=True,
        ),
        MessageTextInput(
            name="table_name",
            display_name="Table Name",
            info="The name of the table (or AstraDB collection) where vectors will be stored.",
            required=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the Cassandra table, with options like 'Sync' or 'Off'.",
            options=["Sync", "Off"],
            value="Sync",
            advanced=True,
        ),
        DictInput(
            name="cluster_kwargs",
            display_name="Cluster arguments",
            info="Optional dictionary of additional keyword arguments for the Cassandra cluster.",
            advanced=True,
            is_list=True,
        ),
        MultilineInput(name="search_query", display_name="Search Query"),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
            is_list=True,
        ),
        HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            value=4,
            advanced=True,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use",
            options=[
                "Traversal",
                "MMR traversal",
                "Similarity",
                "Similarity with score threshold",
                "MMR (Max Marginal Relevance)",
            ],
            value="Traversal",
            advanced=True,
        ),
        IntInput(
            name="depth",
            display_name="Depth of traversal",
            info="The maximum depth of edges to traverse. (when using 'Traversal' or 'MMR traversal')",
            value=1,
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
            is_list=True,
        ),
    ]

    @check_cached_vector_store
    def build_vector_store(self) -> CassandraGraphVectorStore:
        """Initialize the cassio connection and build the graph vector store.

        Returns:
            The configured ``CassandraGraphVectorStore``, populated with the
            ingest data when any was provided.

        Raises:
            ImportError: If the ``cassio`` package is not installed.
        """
        try:
            import cassio
            from langchain_community.utilities.cassandra import SetupMode
        except ImportError as e:
            raise ImportError(
                "Could not import cassio integration package. " "Please install it with `pip install cassio`."
            ) from e

        database_ref = self.database_ref

        # A database_ref that parses as a UUID is treated as an Astra DB
        # database ID; anything else is treated as Cassandra contact points.
        try:
            UUID(self.database_ref)
            is_astra = True
        except ValueError:
            is_astra = False
            if "," in self.database_ref:
                # use a copy because we can't change the type of the parameter
                database_ref = self.database_ref.split(",")

        if is_astra:
            cassio.init(
                database_id=database_ref,
                token=self.token,
                cluster_kwargs=self.cluster_kwargs,
            )
        else:
            cassio.init(
                contact_points=database_ref,
                username=self.username,
                password=self.token,
                cluster_kwargs=self.cluster_kwargs,
            )

        # Convert Data objects to LC documents; pass anything else through as-is.
        documents = []
        for _input in self.ingest_data or []:
            if isinstance(_input, Data):
                documents.append(_input.to_lc_document())
            else:
                documents.append(_input)

        setup_mode = SetupMode.OFF if self.setup_mode == "Off" else SetupMode.SYNC

        if documents:
            logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
            # NOTE(review): setup_mode is not forwarded on this path, so the
            # table is always set up when documents are ingested — confirm
            # whether from_documents should also receive setup_mode.
            store = CassandraGraphVectorStore.from_documents(
                documents=documents,
                embedding=self.embedding,
                node_table=self.table_name,
                keyspace=self.keyspace,
            )
        else:
            logger.debug("No documents to add to the Vector Store.")
            store = CassandraGraphVectorStore(
                embedding=self.embedding,
                node_table=self.table_name,
                keyspace=self.keyspace,
                setup_mode=setup_mode,
            )
        return store

    def _map_search_type(self) -> str:
        """Translate the UI dropdown value to the vector store's search-type key.

        The keys here must match the ``search_type`` dropdown options exactly;
        the previous comparison against "MMR Traversal" (capital T) never
        matched the "MMR traversal" option and silently fell back to traversal.
        """
        search_type_mapping = {
            "Similarity": "similarity",
            "Similarity with score threshold": "similarity_score_threshold",
            "MMR (Max Marginal Relevance)": "mmr",
            "MMR traversal": "mmr_traversal",
        }
        # Default to graph traversal, which is also the dropdown's default.
        return search_type_mapping.get(self.search_type, "traversal")

    def search_documents(self) -> List[Data]:
        """Search the vector store with the configured query and search type.

        Returns:
            The matching documents as ``Data`` objects, or an empty list when
            no (non-blank) search query was provided.
        """
        vector_store = self.build_vector_store()

        logger.debug(f"Search input: {self.search_query}")
        logger.debug(f"Search type: {self.search_type}")
        logger.debug(f"Number of results: {self.number_of_results}")

        if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
            try:
                search_type = self._map_search_type()
                search_args = self._build_search_args()

                logger.debug(f"Search args: {str(search_args)}")

                docs = vector_store.search(query=self.search_query, search_type=search_type, **search_args)
            except KeyError as e:
                # A missing 'content' field means the collection was populated
                # outside Langflow/LangChain and lacks the expected schema.
                if "content" in str(e):
                    raise ValueError(
                        "You should ingest data through Langflow (or LangChain) to query it in Langflow. Your collection does not contain a field name 'content'."
                    ) from e
                else:
                    raise e

            logger.debug(f"Retrieved documents: {len(docs)}")

            data = docs_to_data(docs)
            self.status = data
            return data
        else:
            return []

    def _build_search_args(self) -> dict:
        """Assemble the keyword arguments passed to the store's search call."""
        args = {
            "k": self.number_of_results,
            "score_threshold": self.search_score_threshold,
            "depth": self.depth,
        }

        if self.search_filter:
            # Drop entries with empty keys or values before filtering.
            clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
            if len(clean_filter) > 0:
                args["filter"] = clean_filter
        return args

    def get_retriever_kwargs(self):
        """Expose the search configuration in retriever-kwargs form."""
        search_args = self._build_search_args()
        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }

View file

@ -86,6 +86,7 @@ import {
Laptop2,
Layers,
Link,
Link2,
Loader2,
Lock,
LogIn,
@ -340,6 +341,7 @@ export const nodeNames: { [char: string]: string } = {
langchain_utilities: "Utilities",
output_parsers: "Output Parsers",
custom_components: "Custom",
link_extractors: "Link Extractors",
unknown: "Other",
};
@ -454,6 +456,7 @@ export const nodeIconsLucide: iconsType = {
saved_components: GradientSave,
ScrollText,
documentloaders: Paperclip,
link_extractors: Link2,
vectorstores: Layers,
vectorsearch: TextSearch,
toolkits: Package2,

View file

@ -112,7 +112,6 @@ test("user must see on handle click the possibility connections - LLMChain", asy
"disclosure-utilities",
"disclosure-prototypes",
"disclosure-retrievers",
"disclosure-text splitters",
"disclosure-tools",
];
@ -123,7 +122,6 @@ test("user must see on handle click the possibility connections - LLMChain", asy
"toolsSearch API",
"prototypesSub Flow",
"retrieversSelf Query Retriever",
"textsplittersCharacterTextSplitter",
];
await Promise.all(