diff --git a/src/backend/base/langflow/components/memories/AstraDBMessageReader.py b/src/backend/base/langflow/components/memories/AstraDBMessageReader.py new file mode 100644 index 000000000..9b82dd308 --- /dev/null +++ b/src/backend/base/langflow/components/memories/AstraDBMessageReader.py @@ -0,0 +1,95 @@ +from typing import Optional, cast + +from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory + +from langflow.base.memory.memory import BaseMemoryComponent +from langflow.field_typing import Text +from langflow.schema.schema import Record + + +class AstraDBMessageReaderComponent(BaseMemoryComponent): + display_name = "Astra DB Message Reader" + description = "Retrieves stored chat messages from Astra DB." + + def build_config(self): + return { + "session_id": { + "display_name": "Session ID", + "info": "Session ID of the chat history.", + "input_types": ["Text"], + }, + "collection_name": { + "display_name": "Collection Name", + "info": "Collection name for Astra DB.", + "input_types": ["Text"], + }, + "token": { + "display_name": "Astra DB Application Token", + "info": "Token for the Astra DB instance.", + "password": True, + }, + "api_endpoint": { + "display_name": "Astra DB API Endpoint", + "info": "API Endpoint for the Astra DB instance.", + "password": True, + }, + "namespace": { + "display_name": "Namespace", + "info": "Namespace for the Astra DB instance.", + "input_types": ["Text"], + "advanced": True, + }, + } + + def get_messages(self, **kwargs) -> list[Record]: + """ + Retrieves messages from the AstraDBChatMessageHistory memory. + + Args: + memory (AstraDBChatMessageHistory): The AstraDBChatMessageHistory instance to retrieve messages from. + + Returns: + list[Record]: A list of Record objects representing the search results. + """ + memory: AstraDBChatMessageHistory = cast( + AstraDBChatMessageHistory, kwargs.get("memory") + ) + if not memory: + raise ValueError("AstraDBChatMessageHistory instance is required.") + + # Get messages from the memory + messages = memory.messages + results = [Record.from_lc_message(message) for message in messages] + + return list(results) + + def build( + self, + session_id: Text, + collection_name: str, + token: str, + api_endpoint: str, + namespace: Optional[str] = None, + ) -> list[Record]: + try: + from langchain_community.chat_message_histories.astradb import ( + AstraDBChatMessageHistory, + ) + except ImportError: + raise ImportError( + "Could not import langchain Astra DB integration package. " + "Please install it with `pip install langchain-astradb`." + ) + + memory = AstraDBChatMessageHistory( + session_id=session_id, + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + ) + + records = self.get_messages(memory=memory) + self.status = records + + return records diff --git a/src/backend/base/langflow/components/memories/AstraDBMessageWriter.py b/src/backend/base/langflow/components/memories/AstraDBMessageWriter.py new file mode 100644 index 000000000..33525656e --- /dev/null +++ b/src/backend/base/langflow/components/memories/AstraDBMessageWriter.py @@ -0,0 +1,117 @@ +from typing import Optional + +from langflow.base.memory.memory import BaseMemoryComponent +from langflow.field_typing import Text +from langflow.schema.schema import Record + +from langchain_core.messages import BaseMessage +from langchain_community.chat_message_histories.astradb import AstraDBChatMessageHistory + + +class AstraDBMessageWriterComponent(BaseMemoryComponent): + display_name = "Astra DB Message Writer" + description = "Writes a message to Astra DB." + + def build_config(self): + return { + "input_value": { + "display_name": "Input Record", + "info": "Record to write to Astra DB.", + }, + "session_id": { + "display_name": "Session ID", + "info": "Session ID of the chat history.", + "input_types": ["Text"], + }, + "collection_name": { + "display_name": "Collection Name", + "info": "Collection name for Astra DB.", + "input_types": ["Text"], + }, + "token": { + "display_name": "Astra DB Application Token", + "info": "Token for the Astra DB instance.", + "password": True, + }, + "api_endpoint": { + "display_name": "Astra DB API Endpoint", + "info": "API Endpoint for the Astra DB instance.", + "password": True, + }, + "namespace": { + "display_name": "Namespace", + "info": "Namespace for the Astra DB instance.", + "input_types": ["Text"], + "advanced": True, + }, + } + + def add_message( + self, + sender: str, + sender_name: str, + text: Text, + session_id: str, + metadata: Optional[dict] = None, + **kwargs, + ): + """ + Adds a message to the AstraDBChatMessageHistory memory. + + Args: + sender (Text): The type of the message sender. Valid values are "Machine" or "User". + sender_name (Text): The name of the message sender. + text (Text): The content of the message. + session_id (Text): The session ID associated with the message. + metadata (dict | None, optional): Additional metadata for the message. Defaults to None. + **kwargs: Additional keyword arguments. + + Raises: + ValueError: If the AstraDBChatMessageHistory instance is not provided. + + """ + memory: AstraDBChatMessageHistory | None = kwargs.pop("memory", None) + if memory is None: + raise ValueError("AstraDBChatMessageHistory instance is required.") + + text_list = [BaseMessage( + content=text, + sender=sender, + sender_name=sender_name, + metadata=metadata, + session_id=session_id, + )] + + memory.add_messages(text_list) + + def build( + self, + input_value: Record, + session_id: Text, + collection_name: str, + token: str, + api_endpoint: str, + namespace: Optional[str] = None, + ) -> Record: + try: + from langchain_community.chat_message_histories.astradb import ( + AstraDBChatMessageHistory, + ) + except ImportError: + raise ImportError( + "Could not import langchain Astra DB integration package. " + "Please install it with `pip install langchain-astradb`." + ) + + memory = AstraDBChatMessageHistory( + session_id=session_id, + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + ) + + self.add_message(**input_value.data, memory=memory) + self.status = f"Added message to Astra DB memory for session {session_id}" + + return input_value