From 77dc6b3d4fbb3a52ffa503db28f9f9381b139f92 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 13 Jun 2024 00:41:35 -0300 Subject: [PATCH] feat: Add MemoryComponent for retrieving stored chat messages --- .../langflow/components/helpers/Memory.py | 85 +++++++++++++++++++ .../components/helpers/MessageHistory.py | 58 ------------- .../src/CustomNodes/GenericNode/index.tsx | 50 +++++------ 3 files changed, 110 insertions(+), 83 deletions(-) create mode 100644 src/backend/base/langflow/components/helpers/Memory.py delete mode 100644 src/backend/base/langflow/components/helpers/MessageHistory.py diff --git a/src/backend/base/langflow/components/helpers/Memory.py b/src/backend/base/langflow/components/helpers/Memory.py new file mode 100644 index 000000000..480d019be --- /dev/null +++ b/src/backend/base/langflow/components/helpers/Memory.py @@ -0,0 +1,85 @@ +from typing import List + +from langflow.custom import Component +from langflow.inputs import DropdownInput, StrInput, IntInput +from langflow.template import Output +from langflow.memory import get_messages +from langflow.schema import Data +from langflow.field_typing import Text + + +class MemoryComponent(Component): + display_name = "Memory" + description = "Retrieves stored chat messages." + icon = "history" + + inputs = [ + DropdownInput( + name="sender", + display_name="Sender Type", + options=["Machine", "User", "Machine and User"], + value="Machine and User", + info="Type of sender.", + advanced=True, + ), + StrInput( + name="sender_name", + display_name="Sender Name", + info="Name of the sender.", + advanced=True, + ), + IntInput( + name="n_messages", + display_name="Number of Messages", + value=100, + info="Number of messages to retrieve.", + advanced=True, + ), + StrInput( + name="session_id", + display_name="Session ID", + info="Session ID of the chat history.", + advanced=True, + ), + DropdownInput( + name="order", + display_name="Order", + options=["Ascending", "Descending"], + value="Descending", + info="Order of the messages.", + advanced=True, + ), + ] + + outputs = [ + Output(display_name="Messages", name="messages", method="retrieve_messages"), + Output(display_name="Text", name="messages_text", method="retrieve_messages_as_text"), + ] + + def retrieve_messages(self) -> List[Data]: + sender = self.sender + sender_name = self.sender_name + session_id = self.session_id + n_messages = self.n_messages + order = "DESC" if self.order == "Descending" else "ASC" + + if sender == "Machine and User": + sender = None + + messages = get_messages( + sender=sender, + sender_name=sender_name, + session_id=session_id, + limit=n_messages, + order=order, + ) + self.status = messages + return messages + + def retrieve_messages_as_text(self) -> Text: + messages = self.retrieve_messages() + messages_text = "\n".join( + [f"{message.data.get('sender_name')}: {message.data.get('text')}" for message in messages] + ) + self.status = messages_text + return Text(messages_text) diff --git a/src/backend/base/langflow/components/helpers/MessageHistory.py b/src/backend/base/langflow/components/helpers/MessageHistory.py deleted file mode 100644 index 5933c27aa..000000000 --- a/src/backend/base/langflow/components/helpers/MessageHistory.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import List, Optional - -from langflow.custom import CustomComponent -from langflow.memory import get_messages -from langflow.schema import Data - - -class MessageHistoryComponent(CustomComponent): - display_name = "Memory" - description = "Retrieves stored chat messages." - - def build_config(self): - return { - "sender": { - "options": ["Machine", "User", "Machine and User"], - "display_name": "Sender Type", - "advanced": True, - }, - "sender_name": {"display_name": "Sender Name", "advanced": True}, - "n_messages": { - "display_name": "Number of Messages", - "info": "Number of messages to retrieve.", - "advanced": True, - }, - "session_id": { - "display_name": "Session ID", - "info": "Session ID of the chat history.", - "input_types": ["Text"], - "advanced": True, - }, - "order": { - "options": ["Ascending", "Descending"], - "display_name": "Order", - "info": "Order of the messages.", - "advanced": True, - }, - } - - def build( - self, - sender: Optional[str] = "Machine and User", - sender_name: Optional[str] = None, - session_id: Optional[str] = None, - n_messages: int = 100, - order: Optional[str] = "Descending", - ) -> List[Data]: - order = "DESC" if order == "Descending" else "ASC" - if sender == "Machine and User": - sender = None - messages = get_messages( - sender=sender, - sender_name=sender_name, - session_id=session_id, - limit=n_messages, - order=order, - ) - self.status = messages - return messages diff --git a/src/frontend/src/CustomNodes/GenericNode/index.tsx b/src/frontend/src/CustomNodes/GenericNode/index.tsx index 3e640cba6..537d49e85 100644 --- a/src/frontend/src/CustomNodes/GenericNode/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/index.tsx @@ -69,14 +69,14 @@ export default function GenericNode({ const [nodeName, setNodeName] = useState(data.node!.display_name); const [inputDescription, setInputDescription] = useState(false); const [nodeDescription, setNodeDescription] = useState( - data.node?.description! + data.node?.description!, ); const [isOutdated, setIsOutdated] = useState(false); const buildStatus = useFlowStore( - (state) => state.flowBuildStatus[data.id]?.status + (state) => state.flowBuildStatus[data.id]?.status, ); const lastRunTime = useFlowStore( - (state) => state.flowBuildStatus[data.id]?.timestamp + (state) => state.flowBuildStatus[data.id]?.timestamp, ); const [validationStatus, setValidationStatus] = useState(null); @@ -93,7 +93,7 @@ export default function GenericNode({ data.node!, setNode, setIsOutdated, - updateNodeInternals + updateNodeInternals, ); const name = nodeIconsLucide[data.type] ? data.type : types[data.type]; @@ -120,12 +120,12 @@ export default function GenericNode({ selected: boolean, showNode: boolean, buildStatus: BuildStatus | undefined, - validationStatus: VertexBuildTypeAPI | null + validationStatus: VertexBuildTypeAPI | null, ) => { const specificClassFromBuildStatus = getSpecificClassFromBuildStatus( buildStatus, validationStatus, - isDark + isDark, ); const baseBorderClass = getBaseBorderClass(selected); @@ -134,7 +134,7 @@ export default function GenericNode({ baseBorderClass, nodeSizeClass, "generic-node-div group/node", - specificClassFromBuildStatus + specificClassFromBuildStatus, ); return names; }; @@ -179,7 +179,7 @@ export default function GenericNode({ showNode, isEmoji, nodeIconFragment, - checkNodeIconFragment + checkNodeIconFragment, ); function countHandles(): void { @@ -356,7 +356,7 @@ export default function GenericNode({ selected, showNode, buildStatus, - validationStatus + validationStatus, )} > {data.node?.beta && showNode && ( @@ -503,7 +503,7 @@ export default function GenericNode({ } title={getFieldTitle( data.node?.template!, - templateField + templateField, )} info={data.node?.template[templateField].info} name={templateField} @@ -531,7 +531,7 @@ export default function GenericNode({ proxy={data.node?.template[templateField].proxy} showNode={showNode} /> - ) + ), )} {/* {/* increase height!! */} @@ -691,7 +691,7 @@ export default function GenericNode({ !data.node?.description) && nameEditable ? "font-light italic" - : "" + : "", )} onDoubleClick={(e) => { setInputDescription(true); @@ -760,13 +760,13 @@ export default function GenericNode({ } title={getFieldTitle( data.node?.template!, - templateField + templateField, )} info={data.node?.template[templateField].info} name={templateField} tooltipTitle={ data.node?.template[templateField].input_types?.join( - "\n" + "\n", ) ?? data.node?.template[templateField].type } required={data.node!.template[templateField].required} @@ -793,7 +793,7 @@ export default function GenericNode({
{" "} @@ -804,9 +804,9 @@ export default function GenericNode({ renderOutputParameter( output, data.node!.outputs?.findIndex( - (out) => out.name === output.name - ) ?? idx - ) + (out) => out.name === output.name, + ) ?? idx, + ), )}
out.name === output.name - ) ?? idx - ) + (out) => out.name === output.name, + ) ?? idx, + ), )}
@@ -830,7 +830,7 @@ export default function GenericNode({ (shownOutputs && shownOutputs.length > 0) || showHiddenOutputs ? "bottom-5" - : "bottom-1.5" + : "bottom-1.5", )} >