From 2785a8bc3aa585482535eaca84b526fa78d985fa Mon Sep 17 00:00:00 2001 From: Rodrigo Nader Date: Thu, 2 May 2024 14:29:54 -0300 Subject: [PATCH] Refactor store_message function and add StoreMessageComponent (#1817) * Refactor store_message function in chat.py and memory.py * Refactor store_message function in chat.py and memory.py * Add StoreMessageComponent to langflow components * Refactor store_message function in chat.py and memory.py to require session_id, sender, and sender_name * Refactor StoreMessageComponent to use Optional[str] for sender_name and session_id parameters --------- Co-authored-by: Gabriel Luiz Freitas Almeida --- src/backend/base/langflow/base/io/chat.py | 37 ++++------------ .../components/experimental/StoreMessage.py | 44 +++++++++++++++++++ src/backend/base/langflow/memory.py | 38 ++++++++++++++++ 3 files changed, 91 insertions(+), 28 deletions(-) create mode 100644 src/backend/base/langflow/components/experimental/StoreMessage.py diff --git a/src/backend/base/langflow/base/io/chat.py b/src/backend/base/langflow/base/io/chat.py index ea24dc968..b3cb7c01d 100644 --- a/src/backend/base/langflow/base/io/chat.py +++ b/src/backend/base/langflow/base/io/chat.py @@ -1,10 +1,9 @@ -import warnings from typing import Optional, Union from langflow.field_typing import Text from langflow.helpers.record import records_to_text from langflow.interface.custom.custom_component import CustomComponent -from langflow.memory import add_messages +from langflow.memory import store_message from langflow.schema import Record @@ -50,34 +49,16 @@ class ChatComponent(CustomComponent): sender: Optional[str] = None, sender_name: Optional[str] = None, ) -> list[Record]: - if not message: - warnings.warn("No message provided.") - return [] - if not session_id or not sender or not sender_name: - raise ValueError("All of session_id, sender, and sender_name must be provided.") - if isinstance(message, Record): - record = message - record.data.update( - { - "session_id": session_id, - "sender": sender, - "sender_name": sender_name, - } - ) - else: - record = Record( - data={ - "text": message, - "session_id": session_id, - "sender": sender, - "sender_name": sender_name, - }, - ) + records = store_message( + message, + session_id=session_id, + sender=sender, + sender_name=sender_name, + ) - self.status = record - records = add_messages([record]) - return records[0] + self.status = records + return records def build_with_record( self, diff --git a/src/backend/base/langflow/components/experimental/StoreMessage.py b/src/backend/base/langflow/components/experimental/StoreMessage.py new file mode 100644 index 000000000..ac50b78b4 --- /dev/null +++ b/src/backend/base/langflow/components/experimental/StoreMessage.py @@ -0,0 +1,44 @@ +from typing import List, Optional + +from langflow.interface.custom.custom_component import CustomComponent +from langflow.memory import get_messages, store_message +from langflow.schema import Record + + +class StoreMessageComponent(CustomComponent): + display_name = "Store Message" + description = "Stores a chat message given a Session ID." + beta: bool = True + + def build_config(self): + return { + "sender": { + "options": ["Machine", "User"], + "display_name": "Sender Type", + }, + "sender_name": {"display_name": "Sender Name"}, + "message": {"display_name": "Message"}, + "session_id": { + "display_name": "Session ID", + "info": "Session ID of the chat history.", + "input_types": ["Text"], + }, + } + + def build( + self, + sender: str = "User", + sender_name: Optional[str] = None, + session_id: Optional[str] = None, + message: str = "", + ) -> List[Record]: + + store_message( + sender=sender, + sender_name=sender_name, + session_id=session_id, + message=message, + ) + + self.status = get_messages(session_id=session_id) + return get_messages(session_id=session_id) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index aea9b3000..15cf6d6d8 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional, Union from loguru import logger @@ -51,6 +52,7 @@ def get_messages( "sender": row.sender, "sender_name": row.sender_name, "session_id": row.session_id, + "timestamp": row.timestamp, }, ) records.append(record) @@ -98,3 +100,39 @@ def delete_messages(session_id: str): """ monitor_service = get_monitor_service() monitor_service.delete_messages(session_id) + + +def store_message( + message: Union[str, Record], + session_id: Optional[str] = None, + sender: Optional[str] = None, + sender_name: Optional[str] = None, +) -> list[Record]: + + if not message: + warnings.warn("No message provided.") + return [] + + if not session_id or not sender or not sender_name: + raise ValueError("All of session_id, sender, and sender_name must be provided.") + + if isinstance(message, Record): + record = message + record.data.update( + { + "session_id": session_id, + "sender": sender, + "sender_name": sender_name, + } + ) + elif isinstance(message, str): + record = Record( + data={ + "text": message, + "session_id": session_id, + "sender": sender, + "sender_name": sender_name, + }, + ) + + return add_messages([record])