feat: refactor memories (#2621)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Nicolò Boschi 2024-07-12 18:30:04 +02:00 committed by GitHub
commit c88e9af121
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 717 additions and 674 deletions

View file

@ -49,18 +49,6 @@ class ChatComponent(Component):
},
}
def store_message(
self,
message: Message,
) -> list[Message]:
messages = store_message(
message,
flow_id=self.graph.flow_id,
)
self.status = messages
return messages
def build_with_data(
self,
sender: Optional[str] = "User",
@ -86,5 +74,9 @@ class ChatComponent(Component):
self.status = message_text
if session_id and isinstance(message, Message) and isinstance(message.text, str):
self.store_message(message)
messages = store_message(
message,
flow_id=self.graph.flow_id,
)
self.status = messages
return message_text # type: ignore

View file

@ -11,13 +11,13 @@ class LCChatMemoryComponent(Component):
outputs = [
Output(
display_name="Memory",
name="base_memory",
method="build_base_memory",
name="memory",
method="build_message_history",
)
]
def _validate_outputs(self):
required_output_methods = ["build_base_memory"]
required_output_methods = ["build_message_history"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
if method_name not in output_names:

View file

@ -1,18 +1,27 @@
from langflow.custom import Component
from langflow.helpers.data import data_to_text
from langflow.inputs import HandleInput
from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output
from langflow.memory import get_messages
from langflow.memory import get_messages, LCBuiltinChatMemory
from langflow.schema import Data
from langflow.schema.message import Message
from langflow.field_typing import BaseChatMemory
from langchain.memory import ConversationBufferMemory
class MemoryComponent(Component):
display_name = "Chat Memory"
description = "Retrieves stored chat messages."
description = "Retrieves stored chat messages from Langflow tables or an external memory."
icon = "message-square-more"
name = "Memory"
inputs = [
HandleInput(
name="memory",
display_name="External Memory",
input_types=["BaseChatMessageHistory"],
info="Retrieve messages from an external memory. If empty, it will use the Langflow tables.",
),
DropdownInput(
name="sender",
display_name="Sender Type",
@ -58,8 +67,9 @@ class MemoryComponent(Component):
]
outputs = [
Output(display_name="Chat History", name="messages", method="retrieve_messages"),
Output(display_name="Messages (Data)", name="messages", method="retrieve_messages"),
Output(display_name="Messages (Text)", name="messages_text", method="retrieve_messages_as_text"),
Output(display_name="Memory", name="lc_memory", method="build_lc_memory"),
]
def retrieve_messages(self) -> Data:
@ -72,17 +82,38 @@ class MemoryComponent(Component):
if sender == "Machine and User":
sender = None
messages = get_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
limit=n_messages,
order=order,
)
self.status = messages
return messages
if self.memory:
# override session_id
self.memory.session_id = session_id
stored = self.memory.messages
if sender:
expected_type = "Machine" if sender == "Machine" else "User"
stored = [m for m in stored if m.type == expected_type]
if order == "ASC":
stored = stored[::-1]
if n_messages:
stored = stored[:n_messages]
stored = [Message.from_lc_message(m) for m in stored]
else:
stored = get_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
limit=n_messages,
order=order,
)
self.status = stored
return stored
def retrieve_messages_as_text(self) -> Message:
messages_text = data_to_text(self.template, self.retrieve_messages())
self.status = messages_text
return Message(text=messages_text)
stored_text = data_to_text(self.template, self.retrieve_messages())
self.status = stored_text
return Message(text=stored_text)
def build_lc_memory(self) -> BaseChatMemory:
    """Build a ConversationBufferMemory over the configured chat history.

    Uses the externally supplied memory when one is connected; otherwise
    falls back to a Langflow-table-backed LCBuiltinChatMemory scoped to
    this flow and session.
    """
    history = self.memory or LCBuiltinChatMemory(flow_id=self.graph.flow_id, session_id=self.session_id)
    return ConversationBufferMemory(chat_memory=history)

View file

@ -1,5 +1,5 @@
from langflow.custom import Component
from langflow.inputs import MessageInput, StrInput
from langflow.inputs import MessageInput, StrInput, HandleInput
from langflow.schema.message import Message
from langflow.template import Output
from langflow.memory import get_messages, store_message
@ -7,12 +7,18 @@ from langflow.memory import get_messages, store_message
class StoreMessageComponent(Component):
display_name = "Store Message"
description = "Stores a chat message or text."
description = "Stores a chat message or text into Langflow tables or an external memory."
icon = "save"
name = "StoreMessage"
inputs = [
MessageInput(name="message", display_name="Message", info="The chat message to be stored.", required=True),
HandleInput(
name="memory",
display_name="External Memory",
input_types=["BaseChatMessageHistory"],
info="The external memory to store the message. If empty, it will use the Langflow tables.",
),
StrInput(
name="sender",
display_name="Sender",
@ -42,7 +48,17 @@ class StoreMessageComponent(Component):
message.sender = self.sender or message.sender
message.sender_name = self.sender_name or message.sender_name
store_message(message, flow_id=self.graph.flow_id)
stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender)
if self.memory:
# override session_id
self.memory.session_id = message.session_id
lc_message = message.to_lc_message()
self.memory.add_messages([lc_message])
stored = self.memory.messages
stored = [Message.from_lc_message(m) for m in stored]
if message.sender:
stored = [m for m in stored if m.sender == message.sender]
else:
store_message(message, flow_id=self.graph.flow_id)
stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender)
self.status = stored
return stored

View file

@ -1,6 +1,8 @@
from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES
from langflow.base.io.chat import ChatComponent
from langflow.inputs import BoolInput
from langflow.io import DropdownInput, FileInput, MessageTextInput, MultilineInput, Output
from langflow.memory import store_message
from langflow.schema.message import Message
@ -17,6 +19,12 @@ class ChatInput(ChatComponent):
value="",
info="Message to be passed as input.",
),
BoolInput(
name="store_message",
display_name="Store Messages",
info="Store the message in the history.",
value=True,
),
DropdownInput(
name="sender",
display_name="Sender Type",
@ -56,8 +64,12 @@ class ChatInput(ChatComponent):
session_id=self.session_id,
files=self.files,
)
if self.session_id and isinstance(message, Message) and isinstance(message.text, str):
self.store_message(message)
store_message(
message,
flow_id=self.graph.flow_id,
)
self.message.value = message
self.status = message

View file

@ -1,5 +1,7 @@
from langflow.base.io.chat import ChatComponent
from langflow.inputs import BoolInput
from langflow.io import DropdownInput, MessageTextInput, Output
from langflow.memory import store_message
from langflow.schema.message import Message
@ -15,6 +17,12 @@ class ChatOutput(ChatComponent):
display_name="Text",
info="Message to be passed as output.",
),
BoolInput(
name="store_message",
display_name="Store Messages",
info="Store the message in the history.",
value=True,
),
DropdownInput(
name="sender",
display_name="Sender Type",
@ -49,7 +57,10 @@ class ChatOutput(ChatComponent):
session_id=self.session_id,
)
if self.session_id and isinstance(message, Message) and isinstance(message.text, str):
self.store_message(message)
store_message(
message,
flow_id=self.graph.flow_id,
)
self.message.value = message
self.status = message

View file

@ -1,4 +1,5 @@
import warnings
from typing import List, Sequence
from uuid import UUID
from loguru import logger
@ -9,6 +10,8 @@ from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.database.utils import migrate_messages_from_monitor_service_to_database
from langflow.services.deps import session_scope
from langflow.field_typing import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
def get_messages(
@ -19,7 +22,7 @@ def get_messages(
order: str | None = "DESC",
flow_id: UUID | None = None,
limit: int | None = None,
):
) -> List[Message]:
"""
Retrieves messages from the monitor service based on the provided filters.
@ -136,3 +139,29 @@ def store_message(
raise ValueError("All of session_id, sender, and sender_name must be provided.")
return add_messages([message], flow_id=flow_id)
class LCBuiltinChatMemory(BaseChatMessageHistory):
    """LangChain chat-message history backed by Langflow's message tables.

    Adapts the Langflow storage helpers (get_messages / store_message /
    delete_messages) to LangChain's BaseChatMessageHistory interface,
    scoping every read and write to a single flow and session.
    """

    def __init__(
        self,
        flow_id: str,
        session_id: str,
    ) -> None:
        self.flow_id = flow_id
        self.session_id = session_id

    @property
    def messages(self) -> List[BaseMessage]:
        """Return all stored messages for this session as LangChain messages."""
        stored = get_messages(session_id=self.session_id)
        return [record.to_lc_message() for record in stored]

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Persist each LangChain message under this flow and session."""
        for lc_msg in messages:
            converted = Message.from_lc_message(lc_msg)
            converted.session_id = self.session_id
            store_message(converted, flow_id=self.flow_id)

    def clear(self) -> None:
        """Delete every message stored for this session."""
        delete_messages(self.session_id)

View file

@ -102,6 +102,19 @@ class Message(Data):
return AIMessage(content=text) # type: ignore
@classmethod
def from_lc_message(cls, lc_message: BaseMessage) -> "Message":
    """Create a Message from a LangChain BaseMessage.

    Maps LangChain message types to Langflow sender labels ("human" ->
    "User", "ai" -> "Machine", "system" -> "System"); any other type is
    used verbatim as the sender. sender_name mirrors sender.
    """
    type_to_sender = {"human": "User", "ai": "Machine", "system": "System"}
    sender = type_to_sender.get(lc_message.type, lc_message.type)
    return cls(text=lc_message.content, sender=sender, sender_name=sender)
@classmethod
def from_data(cls, data: "Data") -> "Message":
"""