From df43dbc6bcc0776ab6e3d82f0498adeb898b479f Mon Sep 17 00:00:00 2001 From: zhenjianpeng Date: Thu, 29 Jun 2023 10:26:27 +0800 Subject: [PATCH] add custom nodes for signature, fix some problem --- src/backend/langflow/interface/memories/base.py | 3 +++ src/backend/langflow/template/frontend_node/memories.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/src/backend/langflow/interface/memories/base.py b/src/backend/langflow/interface/memories/base.py index f0d8f88f5..a211517f5 100644 --- a/src/backend/langflow/interface/memories/base.py +++ b/src/backend/langflow/interface/memories/base.py @@ -7,6 +7,7 @@ from langflow.template.frontend_node.base import FrontendNode from langflow.template.frontend_node.memories import MemoryFrontendNode from langflow.utils.logger import logger from langflow.utils.util import build_template_from_class +from langflow.custom.customs import get_custom_nodes class MemoryCreator(LangChainTypeCreator): @@ -26,6 +27,8 @@ class MemoryCreator(LangChainTypeCreator): def get_signature(self, name: str) -> Optional[Dict]: """Get the signature of a memory.""" try: + if name in get_custom_nodes(self.type_name).keys(): + return get_custom_nodes(self.type_name)[name] return build_template_from_class(name, memory_type_to_cls_dict) except ValueError as exc: raise ValueError("Memory not found") from exc diff --git a/src/backend/langflow/template/frontend_node/memories.py b/src/backend/langflow/template/frontend_node/memories.py index e2f533e7f..1ba0737aa 100644 --- a/src/backend/langflow/template/frontend_node/memories.py +++ b/src/backend/langflow/template/frontend_node/memories.py @@ -8,6 +8,11 @@ from langflow.template.template.base import Template class MemoryFrontendNode(FrontendNode): #! Needs testing def add_extra_fields(self) -> None: + # chat history should have another way to add common field? + # prevent adding incorect field in ChatMessageHistory + if "BaseChatMessageHistory" in self.base_classes: + pass + # add return_messages field self.template.add_field( TemplateField( @@ -65,6 +70,10 @@ class MemoryFrontendNode(FrontendNode): field.value = "" if field.name == "memory_key": field.value = "chat_history" + if field.name == "chat_memory": + field.show = True + field.advanced = False + field.required = False class PostgresChatMessageHistoryFrontendNode(MemoryFrontendNode):