feat: refactor memories (#2621)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ccd6d1c666
commit
c88e9af121
9 changed files with 717 additions and 674 deletions
|
|
@ -49,18 +49,6 @@ class ChatComponent(Component):
|
|||
},
|
||||
}
|
||||
|
||||
def store_message(
|
||||
self,
|
||||
message: Message,
|
||||
) -> list[Message]:
|
||||
messages = store_message(
|
||||
message,
|
||||
flow_id=self.graph.flow_id,
|
||||
)
|
||||
|
||||
self.status = messages
|
||||
return messages
|
||||
|
||||
def build_with_data(
|
||||
self,
|
||||
sender: Optional[str] = "User",
|
||||
|
|
@ -86,5 +74,9 @@ class ChatComponent(Component):
|
|||
|
||||
self.status = message_text
|
||||
if session_id and isinstance(message, Message) and isinstance(message.text, str):
|
||||
self.store_message(message)
|
||||
messages = store_message(
|
||||
message,
|
||||
flow_id=self.graph.flow_id,
|
||||
)
|
||||
self.status = messages
|
||||
return message_text # type: ignore
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ class LCChatMemoryComponent(Component):
|
|||
outputs = [
|
||||
Output(
|
||||
display_name="Memory",
|
||||
name="base_memory",
|
||||
method="build_base_memory",
|
||||
name="memory",
|
||||
method="build_message_history",
|
||||
)
|
||||
]
|
||||
|
||||
def _validate_outputs(self):
|
||||
required_output_methods = ["build_base_memory"]
|
||||
required_output_methods = ["build_message_history"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
for method_name in required_output_methods:
|
||||
if method_name not in output_names:
|
||||
|
|
|
|||
|
|
@ -1,18 +1,27 @@
|
|||
from langflow.custom import Component
|
||||
from langflow.helpers.data import data_to_text
|
||||
from langflow.inputs import HandleInput
|
||||
from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output
|
||||
from langflow.memory import get_messages
|
||||
from langflow.memory import get_messages, LCBuiltinChatMemory
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
from langflow.field_typing import BaseChatMemory
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
|
||||
class MemoryComponent(Component):
|
||||
display_name = "Chat Memory"
|
||||
description = "Retrieves stored chat messages."
|
||||
description = "Retrieves stored chat messages from Langflow tables or an external memory."
|
||||
icon = "message-square-more"
|
||||
name = "Memory"
|
||||
|
||||
inputs = [
|
||||
HandleInput(
|
||||
name="memory",
|
||||
display_name="External Memory",
|
||||
input_types=["BaseChatMessageHistory"],
|
||||
info="Retrieve messages from an external memory. If empty, it will use the Langflow tables.",
|
||||
),
|
||||
DropdownInput(
|
||||
name="sender",
|
||||
display_name="Sender Type",
|
||||
|
|
@ -58,8 +67,9 @@ class MemoryComponent(Component):
|
|||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Chat History", name="messages", method="retrieve_messages"),
|
||||
Output(display_name="Messages (Data)", name="messages", method="retrieve_messages"),
|
||||
Output(display_name="Messages (Text)", name="messages_text", method="retrieve_messages_as_text"),
|
||||
Output(display_name="Memory", name="lc_memory", method="build_lc_memory"),
|
||||
]
|
||||
|
||||
def retrieve_messages(self) -> Data:
|
||||
|
|
@ -72,17 +82,38 @@ class MemoryComponent(Component):
|
|||
if sender == "Machine and User":
|
||||
sender = None
|
||||
|
||||
messages = get_messages(
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
limit=n_messages,
|
||||
order=order,
|
||||
)
|
||||
self.status = messages
|
||||
return messages
|
||||
if self.memory:
|
||||
# override session_id
|
||||
self.memory.session_id = session_id
|
||||
|
||||
stored = self.memory.messages
|
||||
if sender:
|
||||
expected_type = "Machine" if sender == "Machine" else "User"
|
||||
stored = [m for m in stored if m.type == expected_type]
|
||||
if order == "ASC":
|
||||
stored = stored[::-1]
|
||||
if n_messages:
|
||||
stored = stored[:n_messages]
|
||||
stored = [Message.from_lc_message(m) for m in stored]
|
||||
else:
|
||||
stored = get_messages(
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
limit=n_messages,
|
||||
order=order,
|
||||
)
|
||||
self.status = stored
|
||||
return stored
|
||||
|
||||
def retrieve_messages_as_text(self) -> Message:
|
||||
messages_text = data_to_text(self.template, self.retrieve_messages())
|
||||
self.status = messages_text
|
||||
return Message(text=messages_text)
|
||||
stored_text = data_to_text(self.template, self.retrieve_messages())
|
||||
self.status = stored_text
|
||||
return Message(text=stored_text)
|
||||
|
||||
def build_lc_memory(self) -> BaseChatMemory:
|
||||
if self.memory:
|
||||
chat_memory = self.memory
|
||||
else:
|
||||
chat_memory = LCBuiltinChatMemory(flow_id=self.graph.flow_id, session_id=self.session_id)
|
||||
return ConversationBufferMemory(chat_memory=chat_memory)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.custom import Component
|
||||
from langflow.inputs import MessageInput, StrInput
|
||||
from langflow.inputs import MessageInput, StrInput, HandleInput
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template import Output
|
||||
from langflow.memory import get_messages, store_message
|
||||
|
|
@ -7,12 +7,18 @@ from langflow.memory import get_messages, store_message
|
|||
|
||||
class StoreMessageComponent(Component):
|
||||
display_name = "Store Message"
|
||||
description = "Stores a chat message or text."
|
||||
description = "Stores a chat message or text into Langflow tables or an external memory."
|
||||
icon = "save"
|
||||
name = "StoreMessage"
|
||||
|
||||
inputs = [
|
||||
MessageInput(name="message", display_name="Message", info="The chat message to be stored.", required=True),
|
||||
HandleInput(
|
||||
name="memory",
|
||||
display_name="External Memory",
|
||||
input_types=["BaseChatMessageHistory"],
|
||||
info="The external memory to store the message. If empty, it will use the Langflow tables.",
|
||||
),
|
||||
StrInput(
|
||||
name="sender",
|
||||
display_name="Sender",
|
||||
|
|
@ -42,7 +48,17 @@ class StoreMessageComponent(Component):
|
|||
message.sender = self.sender or message.sender
|
||||
message.sender_name = self.sender_name or message.sender_name
|
||||
|
||||
store_message(message, flow_id=self.graph.flow_id)
|
||||
stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender)
|
||||
if self.memory:
|
||||
# override session_id
|
||||
self.memory.session_id = message.session_id
|
||||
lc_message = message.to_lc_message()
|
||||
self.memory.add_messages([lc_message])
|
||||
stored = self.memory.messages
|
||||
stored = [Message.from_lc_message(m) for m in stored]
|
||||
if message.sender:
|
||||
stored = [m for m in stored if m.sender == message.sender]
|
||||
else:
|
||||
store_message(message, flow_id=self.graph.flow_id)
|
||||
stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender)
|
||||
self.status = stored
|
||||
return stored
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES
|
||||
from langflow.base.io.chat import ChatComponent
|
||||
from langflow.inputs import BoolInput
|
||||
from langflow.io import DropdownInput, FileInput, MessageTextInput, MultilineInput, Output
|
||||
from langflow.memory import store_message
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
|
|
@ -17,6 +19,12 @@ class ChatInput(ChatComponent):
|
|||
value="",
|
||||
info="Message to be passed as input.",
|
||||
),
|
||||
BoolInput(
|
||||
name="store_message",
|
||||
display_name="Store Messages",
|
||||
info="Store the message in the history.",
|
||||
value=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="sender",
|
||||
display_name="Sender Type",
|
||||
|
|
@ -56,8 +64,12 @@ class ChatInput(ChatComponent):
|
|||
session_id=self.session_id,
|
||||
files=self.files,
|
||||
)
|
||||
|
||||
if self.session_id and isinstance(message, Message) and isinstance(message.text, str):
|
||||
self.store_message(message)
|
||||
store_message(
|
||||
message,
|
||||
flow_id=self.graph.flow_id,
|
||||
)
|
||||
self.message.value = message
|
||||
|
||||
self.status = message
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from langflow.base.io.chat import ChatComponent
|
||||
from langflow.inputs import BoolInput
|
||||
from langflow.io import DropdownInput, MessageTextInput, Output
|
||||
from langflow.memory import store_message
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
|
|
@ -15,6 +17,12 @@ class ChatOutput(ChatComponent):
|
|||
display_name="Text",
|
||||
info="Message to be passed as output.",
|
||||
),
|
||||
BoolInput(
|
||||
name="store_message",
|
||||
display_name="Store Messages",
|
||||
info="Store the message in the history.",
|
||||
value=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="sender",
|
||||
display_name="Sender Type",
|
||||
|
|
@ -49,7 +57,10 @@ class ChatOutput(ChatComponent):
|
|||
session_id=self.session_id,
|
||||
)
|
||||
if self.session_id and isinstance(message, Message) and isinstance(message.text, str):
|
||||
self.store_message(message)
|
||||
store_message(
|
||||
message,
|
||||
flow_id=self.graph.flow_id,
|
||||
)
|
||||
self.message.value = message
|
||||
|
||||
self.status = message
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,5 @@
|
|||
import warnings
|
||||
from typing import List, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -9,6 +10,8 @@ from langflow.schema.message import Message
|
|||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.database.utils import migrate_messages_from_monitor_service_to_database
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.field_typing import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
def get_messages(
|
||||
|
|
@ -19,7 +22,7 @@ def get_messages(
|
|||
order: str | None = "DESC",
|
||||
flow_id: UUID | None = None,
|
||||
limit: int | None = None,
|
||||
):
|
||||
) -> List[Message]:
|
||||
"""
|
||||
Retrieves messages from the monitor service based on the provided filters.
|
||||
|
||||
|
|
@ -136,3 +139,29 @@ def store_message(
|
|||
raise ValueError("All of session_id, sender, and sender_name must be provided.")
|
||||
|
||||
return add_messages([message], flow_id=flow_id)
|
||||
|
||||
|
||||
class LCBuiltinChatMemory(BaseChatMessageHistory):
|
||||
def __init__(
|
||||
self,
|
||||
flow_id: str,
|
||||
session_id: str,
|
||||
) -> None:
|
||||
self.flow_id = flow_id
|
||||
self.session_id = session_id
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]:
|
||||
messages = get_messages(
|
||||
session_id=self.session_id,
|
||||
)
|
||||
return [m.to_lc_message() for m in messages]
|
||||
|
||||
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
for lc_message in messages:
|
||||
message = Message.from_lc_message(lc_message)
|
||||
message.session_id = self.session_id
|
||||
store_message(message, flow_id=self.flow_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
delete_messages(self.session_id)
|
||||
|
|
|
|||
|
|
@ -102,6 +102,19 @@ class Message(Data):
|
|||
|
||||
return AIMessage(content=text) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def from_lc_message(cls, lc_message: BaseMessage) -> "Message":
|
||||
if lc_message.type == "human":
|
||||
sender = "User"
|
||||
elif lc_message.type == "ai":
|
||||
sender = "Machine"
|
||||
elif lc_message.type == "system":
|
||||
sender = "System"
|
||||
else:
|
||||
sender = lc_message.type
|
||||
|
||||
return cls(text=lc_message.content, sender=sender, sender_name=sender)
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, data: "Data") -> "Message":
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue