From d157fffee3de4634a7d9a738d473029fcb7bae70 Mon Sep 17 00:00:00 2001 From: zhenjianpeng Date: Wed, 28 Jun 2023 15:20:06 +0800 Subject: [PATCH] adding pg support for external message persistance --- src/backend/langflow/config.yaml | 1 + src/backend/langflow/custom/customs.py | 3 ++ .../langflow/interface/initialize/loading.py | 2 + .../template/frontend_node/memories.py | 40 +++++++++++++++++++ 4 files changed, 46 insertions(+) diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index d8cd4a325..0205cbb31 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -69,6 +69,7 @@ memories: - ConversationBufferMemory - ConversationSummaryMemory - ConversationKGMemory + - PostgresChatMessageHistory prompts: - PromptTemplate - FewShotPromptTemplate diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index fb6c1da16..0f1e44308 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -21,6 +21,9 @@ CUSTOM_NODES = { "utilities": { "SQLDatabase": frontend_node.agents.SQLDatabaseNode(), }, + "memories": { + "PostgresChatMessageHistory": frontend_node.memories.PostgresChatMessageHistoryFrontendNode(), + }, "chains": { "SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(), "TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(), diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 88b981f9d..c527d745a 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -225,6 +225,7 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs) """Load agent executor from agent class, tools and chain""" allowed_tools: Sequence[BaseTool] = params.get("allowed_tools", []) llm_chain = params["llm_chain"] + memory = params["memory"] # if allowed_tools is not a list or set, make it a list if not isinstance(allowed_tools, (list, set)) and isinstance( allowed_tools, BaseTool @@ -237,6 +238,7 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs) return AgentExecutor.from_agent_and_tools( agent=agent, tools=allowed_tools, + memory=memory, **kwargs, ) diff --git a/src/backend/langflow/template/frontend_node/memories.py b/src/backend/langflow/template/frontend_node/memories.py index 4b312c926..e2f533e7f 100644 --- a/src/backend/langflow/template/frontend_node/memories.py +++ b/src/backend/langflow/template/frontend_node/memories.py @@ -2,6 +2,7 @@ from typing import Optional from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode +from langflow.template.template.base import Template class MemoryFrontendNode(FrontendNode): @@ -64,3 +65,42 @@ class MemoryFrontendNode(FrontendNode): field.value = "" if field.name == "memory_key": field.value = "chat_history" + + +class PostgresChatMessageHistoryFrontendNode(MemoryFrontendNode): + name: str = "PostgresChatMessageHistory" + template: Template = Template( + type_name="PostgresChatMessageHistory", + fields=[ + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + multiline=False, + name="session_id", + ), + TemplateField( + field_type="str", + required=True, + show=True, + name="connection_string", + ), + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + multiline=False, + value="message_store", + name="table_name", + ), + ], + ) + description: str = "Memory store with Postgres" + base_classes: list[str] = [ + "PostgresChatMessageHistory", + "BaseChatMessageHistory" + ] \ No newline at end of file