test: add astra integration test (#2189)
* add first astra integ test framework * use fixtures * remove old tests from merge * Add correct sender type * chore: Update unit test command in GitHub workflow --------- Co-authored-by: ogabrielluiz <gabriel@langflow.org>
This commit is contained in:
parent
5a04adfa1f
commit
ca660cf8df
31 changed files with 211 additions and 12 deletions
2
.github/workflows/python_test.yml
vendored
2
.github/workflows/python_test.yml
vendored
|
|
@ -42,4 +42,4 @@ jobs:
|
|||
poetry install
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
make tests args="-n auto"
|
||||
make unit_tests args="-n auto"
|
||||
|
|
|
|||
6
Makefile
6
Makefile
|
|
@ -55,9 +55,11 @@ coverage: ## run the tests and generate a coverage report
|
|||
|
||||
|
||||
# allow passing arguments to pytest
|
||||
tests: ## run the tests
|
||||
poetry run pytest tests --instafail -ra -n auto -m "not api_key_required" $(args)
|
||||
unit_tests:
|
||||
poetry run pytest tests/unit --instafail -ra -n auto -m "not api_key_required" $(args)
|
||||
|
||||
integration_tests:
|
||||
poetry run pytest tests/integration --instafail -ra -n auto $(args)
|
||||
|
||||
format: ## run code formatters
|
||||
poetry run ruff check . --fix
|
||||
|
|
|
|||
10
poetry.lock
generated
10
poetry.lock
generated
|
|
@ -2438,8 +2438,8 @@ files = [
|
|||
[package.dependencies]
|
||||
cffi = {version = ">=1.12.2", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""}
|
||||
greenlet = [
|
||||
{version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""},
|
||||
{version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""},
|
||||
{version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""},
|
||||
]
|
||||
"zope.event" = "*"
|
||||
"zope.interface" = "*"
|
||||
|
|
@ -2566,12 +2566,12 @@ files = [
|
|||
google-auth = ">=2.14.1,<3.0.dev0"
|
||||
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
|
||||
grpcio = [
|
||||
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
]
|
||||
grpcio-status = [
|
||||
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
]
|
||||
proto-plus = ">=1.22.3,<2.0.0dev"
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
|
||||
|
|
@ -4550,8 +4550,8 @@ psutil = ">=5.9.1"
|
|||
pywin32 = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
pyzmq = ">=25.0.0"
|
||||
requests = [
|
||||
{version = ">=2.26.0", markers = "python_version <= \"3.11\""},
|
||||
{version = ">=2.32.2", markers = "python_version > \"3.11\""},
|
||||
{version = ">=2.26.0", markers = "python_version <= \"3.11\""},
|
||||
]
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
Werkzeug = ">=2.0.0"
|
||||
|
|
@ -6019,9 +6019,9 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ class AstraDBMessageWriterComponent(BaseMemoryComponent):
|
|||
sender_name=sender_name,
|
||||
metadata=metadata,
|
||||
session_id=session_id,
|
||||
type=sender,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -163,4 +163,3 @@ class AstraDBVectorStoreComponent(CustomComponent):
|
|||
)
|
||||
|
||||
return vector_store
|
||||
return vector_store
|
||||
|
|
|
|||
|
|
@ -4,19 +4,21 @@ from typing import TYPE_CHECKING, List, Optional, Union
|
|||
|
||||
import duckdb
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
|
||||
from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch
|
||||
from loguru import logger
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.settings.manager import SettingsService
|
||||
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
|
||||
|
||||
|
||||
class MonitorService(Service):
|
||||
name = "monitor_service"
|
||||
|
||||
def __init__(self, settings_service: "SettingsService"):
|
||||
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
|
||||
|
||||
self.settings_service = settings_service
|
||||
self.base_cache_dir = Path(user_cache_dir("langflow"))
|
||||
self.db_path = self.base_cache_dir / "monitor.duckdb"
|
||||
|
|
@ -45,7 +47,7 @@ class MonitorService(Service):
|
|||
def add_row(
|
||||
self,
|
||||
table_name: str,
|
||||
data: Union[dict, TransactionModel, MessageModel, VertexBuildModel],
|
||||
data: Union[dict, "TransactionModel", "MessageModel", "VertexBuildModel"],
|
||||
):
|
||||
# Make sure the model passed matches the table
|
||||
|
||||
|
|
@ -127,7 +129,7 @@ class MonitorService(Service):
|
|||
|
||||
return self.exec_query(query, read_only=False)
|
||||
|
||||
def add_message(self, message: MessageModel):
|
||||
def add_message(self, message: "MessageModel"):
|
||||
self.add_row("messages", message)
|
||||
|
||||
def get_messages(
|
||||
|
|
|
|||
160
tests/integration/astra/test_astra_component.py
Normal file
160
tests/integration/astra/test_astra_component.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
import os
|
||||
import pytest
|
||||
|
||||
from integration.utils import MockEmbeddings, check_env_vars
|
||||
|
||||
from langflow.components.memories.AstraDBMessageReader import (
|
||||
AstraDBMessageReaderComponent,
|
||||
)
|
||||
from langflow.components.memories.AstraDBMessageWriter import (
|
||||
AstraDBMessageWriterComponent,
|
||||
)
|
||||
from langflow.components.vectorsearch.AstraDBSearch import AstraDBSearchComponent
|
||||
from langflow.components.vectorstores.AstraDB import AstraDBVectorStoreComponent
|
||||
from langflow.schema.record import Record
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
# Astra DB collection names, one per test so the tests do not share data.
# COLLECTION is used by test_astra_setup.
COLLECTION = "test_basic"
|
||||
# SEARCH_COLLECTION is used by test_astra_embeds_and_search.
SEARCH_COLLECTION = "test_search"
|
||||
# MEMORY_COLLECTION is used by test_astra_memory.
MEMORY_COLLECTION = "test_memory"
|
||||
|
||||
|
||||
@pytest.fixture()
def astra_fixture(request):
    """
    Sets up the astra collection and cleans up after

    The collection name is taken from ``request.param``, so tests choose it via
    ``@pytest.mark.parametrize("astra_fixture", [NAME], indirect=True)``.
    Connection settings are read from the ``ASTRA_DB_API_ENDPOINT`` and
    ``ASTRA_DB_APPLICATION_TOKEN`` environment variables.

    Raises:
        ImportError: If the ``langchain_astradb`` package cannot be imported.
    """
    try:
        from langchain_astradb import AstraDBVectorStore
    except ImportError as exc:
        # Chain the original error so the real import failure stays visible in the traceback.
        raise ImportError(
            "Could not import langchain Astra DB integration package. Please install it with `pip install langchain-astradb`."
        ) from exc

    store = AstraDBVectorStore(
        collection_name=request.param,
        embedding=MockEmbeddings(),
        api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT"),
        token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
    )

    # Nothing is yielded: tests request this fixture only for its setup/teardown.
    yield

    # Teardown: remove the collection used by the test.
    store.delete_collection()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not check_env_vars("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT"),
    reason="missing astra env vars",
)
@pytest.mark.parametrize("astra_fixture", [COLLECTION], indirect=True)
def test_astra_setup(astra_fixture):
    """Smoke test: building the Astra DB vector store component must not raise.

    Skipped when the Astra credentials are absent from the environment. The
    ``astra_fixture`` argument is requested only for its setup and teardown of
    the ``COLLECTION`` collection.
    """
    build_kwargs = {
        "token": os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
        "api_endpoint": os.getenv("ASTRA_DB_API_ENDPOINT"),
        "collection_name": COLLECTION,
        "embedding": MockEmbeddings(),
    }

    # No assertion on the result: the test passes as long as build() completes.
    AstraDBVectorStoreComponent().build(**build_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not check_env_vars("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT"),
    reason="missing astra env vars",
)
@pytest.mark.parametrize("astra_fixture", [SEARCH_COLLECTION], indirect=True)
def test_astra_embeds_and_search(astra_fixture):
    """Store two records through the vector store component, then search for one.

    Skipped when the Astra credentials are absent from the environment. The
    ``astra_fixture`` argument is requested only for its setup and teardown of
    the ``SEARCH_COLLECTION`` collection.
    """
    # Settings shared by both components; the same embedding instance is reused.
    connection = {
        "token": os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
        "api_endpoint": os.getenv("ASTRA_DB_API_ENDPOINT"),
        "collection_name": SEARCH_COLLECTION,
        "embedding": MockEmbeddings(),
    }

    source_docs = [Document(page_content=text) for text in ("test1", "test2")]
    seed_records = list(map(Record.from_document, source_docs))

    # Ingest the two records into the collection.
    AstraDBVectorStoreComponent().build(**connection, inputs=seed_records)

    # Ask for a single result and check that exactly one comes back.
    hits = AstraDBSearchComponent().build(
        **connection,
        input_value="test1",
        number_of_results=1,
    )

    assert len(hits) == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not check_env_vars("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT"),
    reason="missing astra env vars",
)
def test_astra_memory():
    """Round-trip one message through the Astra DB message writer and reader.

    Writes a single record for session 1, checks that reading session 1 returns
    exactly that record, and checks that reading session 2 returns nothing.

    The collection is deleted in a ``finally`` block so it is removed even when
    an assertion fails.

    Raises:
        ImportError: If the ``langchain_astradb`` package cannot be imported.
    """
    # Import up front so a missing optional dependency fails before anything is written remotely.
    try:
        from langchain_astradb import AstraDBVectorStore
    except ImportError as exc:
        raise ImportError(
            "Could not import langchain Astra DB integration package. Please install it with `pip install langchain-astradb`."
        ) from exc

    application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
    api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT")

    writer = AstraDBMessageWriterComponent()
    reader = AstraDBMessageReaderComponent()

    try:
        input_value = Record.from_document(
            Document(
                page_content="memory1",
                metadata={"session_id": 1, "sender": "human", "sender_name": "Bob"},
            )
        )
        writer.build(
            input_value=input_value,
            session_id=1,
            token=application_token,
            api_endpoint=api_endpoint,
            collection_name=MEMORY_COLLECTION,
        )

        # verify reading w/ same session id pulls the same record
        records = reader.build(
            session_id=1,
            token=application_token,
            api_endpoint=api_endpoint,
            collection_name=MEMORY_COLLECTION,
        )
        assert len(records) == 1
        assert isinstance(records[0], Record)
        content = records[0].get_text()
        assert content == "memory1"

        # verify reading w/ different session id does not pull the same record
        records = reader.build(
            session_id=2,
            token=application_token,
            api_endpoint=api_endpoint,
            collection_name=MEMORY_COLLECTION,
        )
        assert len(records) == 0
    finally:
        # Cleanup store - doing here rather than fixture (see https://github.com/langchain-ai/langchain-datastax/pull/36)
        # Runs on failure too; otherwise a failed assertion would leave the collection behind
        # (and presumably its stored messages) for later runs.
        store = AstraDBVectorStore(
            collection_name=MEMORY_COLLECTION,
            embedding=MockEmbeddings(),
            api_endpoint=api_endpoint,
            token=application_token,
        )
        store.delete_collection()
|
||||
35
tests/integration/utils.py
Normal file
35
tests/integration/utils.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import os
|
||||
from typing import List
|
||||
|
||||
from langflow.field_typing import Embeddings, VectorStore
|
||||
|
||||
|
||||
def check_env_vars(*names: str) -> bool:
    """
    Check if all specified environment variables are set.

    A variable that is defined but holds an empty string counts as not set,
    because the check relies on the truthiness of ``os.getenv``. Calling the
    function with no names returns True.

    Args:
        *names (str): The environment variables to check.

    Returns:
        bool: True if all environment variables are set, False otherwise.
    """
    # The parameter is not called ``vars`` so the builtin of that name is not shadowed.
    return all(os.getenv(name) for name in names)
|
||||
|
||||
|
||||
class MockEmbeddings(Embeddings):
    """Deterministic stand-in for an embedding model, for use in tests.

    Vectors are derived only from the length of the text. The most recent
    inputs are recorded on the instance so they can be inspected afterwards.
    """

    def __init__(self):
        # Last arguments seen by embed_query / embed_documents; None until first use.
        self.embedded_query = None
        self.embedded_documents = None

    @staticmethod
    def mock_embedding(text: str):
        """Return a 3-element vector: the text length divided by 2, 5 and 10."""
        size = len(text)
        return [size / 2, size / 5, size / 10]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Record ``texts`` and return one mock vector per entry, in order."""
        self.embedded_documents = texts
        vectors = []
        for entry in texts:
            vectors.append(self.mock_embedding(entry))
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Record ``text`` and return its mock vector."""
        self.embedded_query = text
        vector = self.mock_embedding(text)
        return vector
|
||||
0
tests/unit/__init__.py
Normal file
0
tests/unit/__init__.py
Normal file
Loading…
Add table
Add a link
Reference in a new issue