feat: Add MemoryComponent for retrieving stored chat messages

This commit is contained in:
Rodrigo 2024-06-13 00:41:35 -03:00
commit 77dc6b3d4f
3 changed files with 110 additions and 83 deletions

View file

@ -0,0 +1,85 @@
from typing import List
from langflow.custom import Component
from langflow.inputs import DropdownInput, StrInput, IntInput
from langflow.template import Output
from langflow.memory import get_messages
from langflow.schema import Data
from langflow.field_typing import Text
class MemoryComponent(Component):
display_name = "Memory"
description = "Retrieves stored chat messages."
icon = "history"
inputs = [
DropdownInput(
name="sender",
display_name="Sender Type",
options=["Machine", "User", "Machine and User"],
value="Machine and User",
info="Type of sender.",
advanced=True,
),
StrInput(
name="sender_name",
display_name="Sender Name",
info="Name of the sender.",
advanced=True,
),
IntInput(
name="n_messages",
display_name="Number of Messages",
value=100,
info="Number of messages to retrieve.",
advanced=True,
),
StrInput(
name="session_id",
display_name="Session ID",
info="Session ID of the chat history.",
advanced=True,
),
DropdownInput(
name="order",
display_name="Order",
options=["Ascending", "Descending"],
value="Descending",
info="Order of the messages.",
advanced=True,
),
]
outputs = [
Output(display_name="Messages", name="messages", method="retrieve_messages"),
Output(display_name="Text", name="messages_text", method="retrieve_messages_as_text"),
]
def retrieve_messages(self) -> List[Data]:
sender = self.sender
sender_name = self.sender_name
session_id = self.session_id
n_messages = self.n_messages
order = "DESC" if self.order == "Descending" else "ASC"
if sender == "Machine and User":
sender = None
messages = get_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
limit=n_messages,
order=order,
)
self.status = messages
return messages
def retrieve_messages_as_text(self) -> Text:
messages = self.retrieve_messages()
messages_text = "\n".join(
[f"{message.data.get('sender_name')}: {message.data.get('text')}" for message in messages]
)
self.status = messages_text
return Text(messages_text)

View file

@ -1,58 +0,0 @@
from typing import List, Optional
from langflow.custom import CustomComponent
from langflow.memory import get_messages
from langflow.schema import Data
class MessageHistoryComponent(CustomComponent):
display_name = "Memory"
description = "Retrieves stored chat messages."
def build_config(self):
return {
"sender": {
"options": ["Machine", "User", "Machine and User"],
"display_name": "Sender Type",
"advanced": True,
},
"sender_name": {"display_name": "Sender Name", "advanced": True},
"n_messages": {
"display_name": "Number of Messages",
"info": "Number of messages to retrieve.",
"advanced": True,
},
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
"advanced": True,
},
"order": {
"options": ["Ascending", "Descending"],
"display_name": "Order",
"info": "Order of the messages.",
"advanced": True,
},
}
def build(
self,
sender: Optional[str] = "Machine and User",
sender_name: Optional[str] = None,
session_id: Optional[str] = None,
n_messages: int = 100,
order: Optional[str] = "Descending",
) -> List[Data]:
order = "DESC" if order == "Descending" else "ASC"
if sender == "Machine and User":
sender = None
messages = get_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
limit=n_messages,
order=order,
)
self.status = messages
return messages

View file

@ -69,14 +69,14 @@ export default function GenericNode({
const [nodeName, setNodeName] = useState(data.node!.display_name);
const [inputDescription, setInputDescription] = useState(false);
const [nodeDescription, setNodeDescription] = useState(
data.node?.description!
data.node?.description!,
);
const [isOutdated, setIsOutdated] = useState(false);
const buildStatus = useFlowStore(
(state) => state.flowBuildStatus[data.id]?.status
(state) => state.flowBuildStatus[data.id]?.status,
);
const lastRunTime = useFlowStore(
(state) => state.flowBuildStatus[data.id]?.timestamp
(state) => state.flowBuildStatus[data.id]?.timestamp,
);
const [validationStatus, setValidationStatus] =
useState<VertexBuildTypeAPI | null>(null);
@ -93,7 +93,7 @@ export default function GenericNode({
data.node!,
setNode,
setIsOutdated,
updateNodeInternals
updateNodeInternals,
);
const name = nodeIconsLucide[data.type] ? data.type : types[data.type];
@ -120,12 +120,12 @@ export default function GenericNode({
selected: boolean,
showNode: boolean,
buildStatus: BuildStatus | undefined,
validationStatus: VertexBuildTypeAPI | null
validationStatus: VertexBuildTypeAPI | null,
) => {
const specificClassFromBuildStatus = getSpecificClassFromBuildStatus(
buildStatus,
validationStatus,
isDark
isDark,
);
const baseBorderClass = getBaseBorderClass(selected);
@ -134,7 +134,7 @@ export default function GenericNode({
baseBorderClass,
nodeSizeClass,
"generic-node-div group/node",
specificClassFromBuildStatus
specificClassFromBuildStatus,
);
return names;
};
@ -179,7 +179,7 @@ export default function GenericNode({
showNode,
isEmoji,
nodeIconFragment,
checkNodeIconFragment
checkNodeIconFragment,
);
function countHandles(): void {
@ -356,7 +356,7 @@ export default function GenericNode({
selected,
showNode,
buildStatus,
validationStatus
validationStatus,
)}
>
{data.node?.beta && showNode && (
@ -503,7 +503,7 @@ export default function GenericNode({
}
title={getFieldTitle(
data.node?.template!,
templateField
templateField,
)}
info={data.node?.template[templateField].info}
name={templateField}
@ -531,7 +531,7 @@ export default function GenericNode({
proxy={data.node?.template[templateField].proxy}
showNode={showNode}
/>
)
),
)}
{/* <ParameterComponent
index={0}
@ -634,7 +634,7 @@ export default function GenericNode({
? "pb-8"
: "pb-8 pt-5"
: "",
"relative"
"relative",
)}
>
{/* increase height!! */}
@ -691,7 +691,7 @@ export default function GenericNode({
!data.node?.description) &&
nameEditable
? "font-light italic"
: ""
: "",
)}
onDoubleClick={(e) => {
setInputDescription(true);
@ -760,13 +760,13 @@ export default function GenericNode({
}
title={getFieldTitle(
data.node?.template!,
templateField
templateField,
)}
info={data.node?.template[templateField].info}
name={templateField}
tooltipTitle={
data.node?.template[templateField].input_types?.join(
"\n"
"\n",
) ?? data.node?.template[templateField].type
}
required={data.node!.template[templateField].required}
@ -793,7 +793,7 @@ export default function GenericNode({
<div
className={classNames(
Object.keys(data.node!.template).length < 1 ? "hidden" : "",
"flex-max-width justify-center"
"flex-max-width justify-center",
)}
>
{" "}
@ -804,9 +804,9 @@ export default function GenericNode({
renderOutputParameter(
output,
data.node!.outputs?.findIndex(
(out) => out.name === output.name
) ?? idx
)
(out) => out.name === output.name,
) ?? idx,
),
)}
<div
className={cn(showHiddenOutputs ? "" : "h-0 overflow-hidden")}
@ -817,9 +817,9 @@ export default function GenericNode({
renderOutputParameter(
output,
data.node!.outputs?.findIndex(
(out) => out.name === output.name
) ?? idx
)
(out) => out.name === output.name,
) ?? idx,
),
)}
</div>
</div>
@ -830,7 +830,7 @@ export default function GenericNode({
(shownOutputs && shownOutputs.length > 0) ||
showHiddenOutputs
? "bottom-5"
: "bottom-1.5"
: "bottom-1.5",
)}
>
<Button
@ -843,8 +843,8 @@ export default function GenericNode({
name={"ChevronDown"}
strokeWidth={1.5}
className={cn(
"h-5 w-5 pt-px text-muted-foreground transition-all group-hover:text-medium-indigo group-hover/node:opacity-100",
showHiddenOutputs ? "rotate-180 transform" : ""
"h-5 w-5 pt-px text-muted-foreground group-hover:text-medium-indigo group-hover/node:opacity-100",
showHiddenOutputs ? "rotate-180 transform" : "",
)}
/>
</Button>