diff --git a/tests/integration/astra/test_astra_component.py b/tests/integration/astra/test_astra_component.py index 161e6436c..2a67924c4 100644 --- a/tests/integration/astra/test_astra_component.py +++ b/tests/integration/astra/test_astra_component.py @@ -1,20 +1,15 @@ import os + import pytest +from langchain_core.documents import Document +from langflow.components.memories.AstraDBMessageReader import AstraDBMessageReaderComponent +from langflow.components.memories.AstraDBMessageWriter import AstraDBMessageWriterComponent +from langflow.components.vectorsearch.AstraDBSearch import AstraDBSearchComponent +from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent +from langflow.schema.data import Data from integration.utils import MockEmbeddings, check_env_vars -from langflow.components.memories.AstraDBMessageReader import ( - AstraDBMessageReaderComponent, -) -from langflow.components.memories.AstraDBMessageWriter import ( - AstraDBMessageWriterComponent, -) -from langflow.components.vectorsearch.AstraDBSearch import AstraDBSearchComponent -from langflow.components.vectorstores.AstraDB import AstraDBVectorStoreComponent -from langflow.schema.record import Record - -from langchain_core.documents import Document - COLLECTION = "test_basic" SEARCH_COLLECTION = "test_search" MEMORY_COLLECTION = "test_memory" @@ -54,7 +49,7 @@ def test_astra_setup(astra_fixture): api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT") embedding = MockEmbeddings() - component = AstraDBVectorStoreComponent() + component = AstraVectorStoreComponent() component.build( token=application_token, api_endpoint=api_endpoint, @@ -74,9 +69,9 @@ def test_astra_embeds_and_search(astra_fixture): embedding = MockEmbeddings() documents = [Document(page_content="test1"), Document(page_content="test2")] - records = [Record.from_document(d) for d in documents] + records = [Data.from_document(d) for d in documents] - component = AstraDBVectorStoreComponent() + component = AstraVectorStoreComponent() component.build( token=application_token, api_endpoint=api_endpoint, @@ -109,7 +104,7 @@ def test_astra_memory(): writer = AstraDBMessageWriterComponent() reader = AstraDBMessageReaderComponent() - input_value = Record.from_document( + input_value = Data.from_document( Document( page_content="memory1", metadata={"session_id": 1, "sender": "human", "sender_name": "Bob"}, @@ -131,7 +126,7 @@ def test_astra_memory(): collection_name=MEMORY_COLLECTION, ) assert len(records) == 1 - assert isinstance(records[0], Record) + assert isinstance(records[0], Data) content = records[0].get_text() assert content == "memory1"