Feature: Updated Chat History for Astra DB (#1895)

* Updated Chat History for Astra DB

* Fix linting issues (i hope)

* Update AstraDBMessageWriter.py

* Fixes from Nicolo's feedback
This commit is contained in:
Eric Hare 2024-05-23 06:13:15 -07:00 committed by GitHub
commit 0347709d3e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 212 additions and 0 deletions

View file

@ -0,0 +1,95 @@
from typing import Optional, cast
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.field_typing import Text
from langflow.schema.schema import Record
class AstraDBMessageReaderComponent(BaseMemoryComponent):
display_name = "Astra DB Message Reader"
description = "Retrieves stored chat messages from Astra DB."
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"collection_name": {
"display_name": "Collection Name",
"info": "Collection name for Astra DB.",
"input_types": ["Text"],
},
"token": {
"display_name": "Astra DB Application Token",
"info": "Token for the Astra DB instance.",
"password": True,
},
"api_endpoint": {
"display_name": "Astra DB API Endpoint",
"info": "API Endpoint for the Astra DB instance.",
"password": True,
},
"namespace": {
"display_name": "Namespace",
"info": "Namespace for the Astra DB instance.",
"input_types": ["Text"],
"advanced": True,
},
}
def get_messages(self, **kwargs) -> list[Record]:
"""
Retrieves messages from the AstraDBChatMessageHistory memory.
Args:
memory (AstraDBChatMessageHistory): The AstraDBChatMessageHistory instance to retrieve messages from.
Returns:
list[Record]: A list of Record objects representing the search results.
"""
memory: AstraDBChatMessageHistory = cast(
AstraDBChatMessageHistory, kwargs.get("memory")
)
if not memory:
raise ValueError("AstraDBChatMessageHistory instance is required.")
# Get messages from the memory
messages = memory.messages
results = [Record.from_lc_message(message) for message in messages]
return list(results)
def build(
self,
session_id: Text,
collection_name: str,
token: str,
api_endpoint: str,
namespace: Optional[str] = None,
) -> list[Record]:
try:
from langchain_community.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
except ImportError:
raise ImportError(
"Could not import langchain Astra DB integration package. "
"Please install it with `pip install langchain-astradb`."
)
memory = AstraDBChatMessageHistory(
session_id=session_id,
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
records = self.get_messages(memory=memory)
self.status = records
return records

View file

@ -0,0 +1,117 @@
from typing import Optional
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.field_typing import Text
from langflow.schema.schema import Record
from langchain_core.messages import BaseMessage
from langchain_community.chat_message_histories.astradb import AstraDBChatMessageHistory
class AstraDBMessageWriterComponent(BaseMemoryComponent):
display_name = "Astra DB Message Writer"
description = "Writes a message to Astra DB."
def build_config(self):
return {
"input_value": {
"display_name": "Input Record",
"info": "Record to write to Astra DB.",
},
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"collection_name": {
"display_name": "Collection Name",
"info": "Collection name for Astra DB.",
"input_types": ["Text"],
},
"token": {
"display_name": "Astra DB Application Token",
"info": "Token for the Astra DB instance.",
"password": True,
},
"api_endpoint": {
"display_name": "Astra DB API Endpoint",
"info": "API Endpoint for the Astra DB instance.",
"password": True,
},
"namespace": {
"display_name": "Namespace",
"info": "Namespace for the Astra DB instance.",
"input_types": ["Text"],
"advanced": True,
},
}
def add_message(
self,
sender: str,
sender_name: str,
text: Text,
session_id: str,
metadata: Optional[dict] = None,
**kwargs,
):
"""
Adds a message to the AstraDBChatMessageHistory memory.
Args:
sender (Text): The type of the message sender. Valid values are "Machine" or "User".
sender_name (Text): The name of the message sender.
text (Text): The content of the message.
session_id (Text): The session ID associated with the message.
metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the AstraDBChatMessageHistory instance is not provided.
"""
memory: AstraDBChatMessageHistory | None = kwargs.pop("memory", None)
if memory is None:
raise ValueError("AstraDBChatMessageHistory instance is required.")
text_list = [BaseMessage(
content=text,
sender=sender,
sender_name=sender_name,
metadata=metadata,
session_id=session_id,
)]
memory.add_messages(text_list)
def build(
self,
input_value: Record,
session_id: Text,
collection_name: str,
token: str,
api_endpoint: str,
namespace: Optional[str] = None,
) -> Record:
try:
from langchain_community.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
except ImportError:
raise ImportError(
"Could not import langchain Astra DB integration package. "
"Please install it with `pip install langchain-astradb`."
)
memory = AstraDBChatMessageHistory(
session_id=session_id,
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
self.add_message(**input_value.data, memory=memory)
self.status = f"Added message to Astra DB memory for session {session_id}"
return input_value