test: add astra integration test (#2189)

* add first astra integ test framework

* use fixtures

* remove old tests from merge

* Add correct sender type

* chore: Update unit test command in GitHub workflow

---------

Co-authored-by: ogabrielluiz <gabriel@langflow.org>
This commit is contained in:
Jordan Frazier 2024-06-15 19:50:38 -07:00 committed by GitHub
commit ca660cf8df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 211 additions and 12 deletions

View file

@ -0,0 +1,160 @@
import os
import pytest
from integration.utils import MockEmbeddings, check_env_vars
from langflow.components.memories.AstraDBMessageReader import (
AstraDBMessageReaderComponent,
)
from langflow.components.memories.AstraDBMessageWriter import (
AstraDBMessageWriterComponent,
)
from langflow.components.vectorsearch.AstraDBSearch import AstraDBSearchComponent
from langflow.components.vectorstores.AstraDB import AstraDBVectorStoreComponent
from langflow.schema.record import Record
from langchain_core.documents import Document
# Astra DB collection names. Each test below targets its own collection.
COLLECTION = "test_basic"  # used by test_astra_setup
SEARCH_COLLECTION = "test_search"  # used by test_astra_embeds_and_search
MEMORY_COLLECTION = "test_memory"  # used by test_astra_memory
@pytest.fixture()
def astra_fixture(request):
    """
    Sets up the astra collection and cleans up after.

    The collection name comes from ``request.param``, so tests must supply it
    through indirect parametrization, e.g.
    ``@pytest.mark.parametrize("astra_fixture", [NAME], indirect=True)``.
    Connection settings are read from the ``ASTRA_DB_API_ENDPOINT`` and
    ``ASTRA_DB_APPLICATION_TOKEN`` environment variables.

    Yields:
        None. The fixture is used only for its setup/teardown side effects.

    Raises:
        ImportError: if the ``langchain_astradb`` package is not installed.
    """
    try:
        from langchain_astradb import AstraDBVectorStore
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Could not import langchain Astra DB integration package. Please install it with `pip install langchain-astradb`."
        ) from err

    store = AstraDBVectorStore(
        collection_name=request.param,
        embedding=MockEmbeddings(),
        api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT"),
        token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
    )

    yield

    # Teardown: drop the collection the test worked against.
    store.delete_collection()
@pytest.mark.skipif(
    not check_env_vars("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT"),
    reason="missing astra env vars",
)
@pytest.mark.parametrize("astra_fixture", [COLLECTION], indirect=True)
def test_astra_setup(astra_fixture):
    """Smoke test: building the Astra DB vector store component must not raise.

    ``astra_fixture`` is requested only for its setup/teardown of the
    ``COLLECTION`` collection; its value is not used.
    """
    build_kwargs = {
        "token": os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
        "api_endpoint": os.getenv("ASTRA_DB_API_ENDPOINT"),
        "collection_name": COLLECTION,
        "embedding": MockEmbeddings(),
    }
    AstraDBVectorStoreComponent().build(**build_kwargs)
@pytest.mark.skipif(
    not check_env_vars("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT"),
    reason="missing astra env vars",
)
@pytest.mark.parametrize("astra_fixture", [SEARCH_COLLECTION], indirect=True)
def test_astra_embeds_and_search(astra_fixture):
    """Store two records via the vector store component, then search for one.

    The search is limited to a single result, and the test checks that exactly
    one record comes back. ``astra_fixture`` is requested only for its
    setup/teardown of the ``SEARCH_COLLECTION`` collection.
    """
    # Connection settings shared by the ingest and the search calls; the same
    # embedding instance is used for both.
    connection = {
        "token": os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
        "api_endpoint": os.getenv("ASTRA_DB_API_ENDPOINT"),
        "collection_name": SEARCH_COLLECTION,
        "embedding": MockEmbeddings(),
    }

    seed_docs = [Document(page_content=text) for text in ("test1", "test2")]
    seed_records = list(map(Record.from_document, seed_docs))

    # Ingest the seed records.
    AstraDBVectorStoreComponent().build(inputs=seed_records, **connection)

    # Query for one of the seeded texts, asking for a single hit.
    hits = AstraDBSearchComponent().build(
        input_value="test1",
        number_of_results=1,
        **connection,
    )

    assert len(hits) == 1
@pytest.mark.skipif(
    not check_env_vars("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT"),
    reason="missing astra env vars",
)
def test_astra_memory():
    """Round-trip a chat message through the Astra DB memory writer and reader.

    Writes one message under session id 1, then checks that reading session 1
    returns exactly that message and that reading session 2 returns nothing.
    The ``MEMORY_COLLECTION`` collection is dropped afterwards even if an
    assertion fails, so a failed run does not leave records behind for the
    next run.

    Raises:
        ImportError: if the ``langchain_astradb`` package is not installed.
    """
    # Import up front so a missing package fails before any data is written.
    try:
        from langchain_astradb import AstraDBVectorStore
    except ImportError as err:
        raise ImportError(
            "Could not import langchain Astra DB integration package. Please install it with `pip install langchain-astradb`."
        ) from err

    application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
    api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT")

    writer = AstraDBMessageWriterComponent()
    reader = AstraDBMessageReaderComponent()

    input_value = Record.from_document(
        Document(
            page_content="memory1",
            metadata={"session_id": 1, "sender": "human", "sender_name": "Bob"},
        )
    )

    try:
        writer.build(
            input_value=input_value,
            session_id=1,
            token=application_token,
            api_endpoint=api_endpoint,
            collection_name=MEMORY_COLLECTION,
        )

        # verify reading w/ same session id pulls the same record
        records = reader.build(
            session_id=1,
            token=application_token,
            api_endpoint=api_endpoint,
            collection_name=MEMORY_COLLECTION,
        )
        assert len(records) == 1
        assert isinstance(records[0], Record)
        content = records[0].get_text()
        assert content == "memory1"

        # verify reading w/ different session id does not pull the same record
        records = reader.build(
            session_id=2,
            token=application_token,
            api_endpoint=api_endpoint,
            collection_name=MEMORY_COLLECTION,
        )
        assert len(records) == 0
    finally:
        # Cleanup store - doing here rather than fixture (see https://github.com/langchain-ai/langchain-datastax/pull/36)
        # Kept in `finally` so stale messages never leak into later runs.
        store = AstraDBVectorStore(
            collection_name=MEMORY_COLLECTION,
            embedding=MockEmbeddings(),
            api_endpoint=api_endpoint,
            token=application_token,
        )
        store.delete_collection()

View file

@ -0,0 +1,35 @@
import os
from typing import List
from langflow.field_typing import Embeddings, VectorStore
def check_env_vars(*vars):
    """
    Report whether every named environment variable has a non-empty value.

    A variable that is unset or set to the empty string counts as missing.
    Calling with no names returns True.

    Args:
        *vars (str): Names of the environment variables to look up.

    Returns:
        bool: True when all lookups yield a truthy value, False otherwise.
    """
    return all(map(os.getenv, vars))
class MockEmbeddings(Embeddings):
    """Embeddings stand-in for tests that computes vectors locally.

    Each vector is derived solely from the length of the input text, so texts
    of equal length get identical vectors. The most recent inputs are recorded
    on the instance so tests can inspect what was embedded.
    """

    def __init__(self):
        # Last arguments passed to embed_query / embed_documents; None until
        # the corresponding method is first called.
        self.embedded_query = None
        self.embedded_documents = None

    @staticmethod
    def mock_embedding(text: str):
        """Return a 3-element vector: len(text) divided by 2, 5 and 10."""
        size = len(text)
        return [size / divisor for divisor in (2, 5, 10)]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Record ``texts`` and return one mock vector per text, in order."""
        self.embedded_documents = texts
        return list(map(self.mock_embedding, texts))

    def embed_query(self, text: str) -> List[float]:
        """Record ``text`` and return its mock vector."""
        self.embedded_query = text
        return self.mock_embedding(text)

0
tests/unit/__init__.py Normal file
View file