From b5ab95c4cce84af19b99bd26e0b4907fdb2ddaa4 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 7 Jun 2023 19:46:51 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20refactor(custom.py):=20remove=20?= =?UTF-8?q?unused=20code=20=E2=9C=A8=20feat(frontend=5Fnode):=20add=20extr?= =?UTF-8?q?a=20fields=20to=20MemoryFrontendNode=20and=20ChainFrontendNode?= =?UTF-8?q?=20The=20unused=20code=20in=20custom.py=20has=20been=20removed.?= =?UTF-8?q?=20The=20MemoryFrontendNode=20and=20ChainFrontendNode=20classes?= =?UTF-8?q?=20have=20been=20updated=20to=20include=20additional=20fields?= =?UTF-8?q?=20that=20are=20required=20for=20their=20respective=20templates?= =?UTF-8?q?.=20The=20MemoryFrontendNode=20now=20has=20a=20return=5Fmessage?= =?UTF-8?q?s=20field,=20and=20the=20ChainFrontendNode=20now=20has=20a=20me?= =?UTF-8?q?mory=20field.=20These=20fields=20are=20optional=20and=20can=20b?= =?UTF-8?q?e=20toggled=20on=20or=20off=20as=20required.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/prompts/custom.py | 4 -- .../langflow/template/frontend_node/base.py | 3 ++ .../langflow/template/frontend_node/chains.py | 52 +++++++++++++++++++ .../template/frontend_node/memories.py | 18 +++++++ 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/interface/prompts/custom.py b/src/backend/langflow/interface/prompts/custom.py index 286210271..ef16f1474 100644 --- a/src/backend/langflow/interface/prompts/custom.py +++ b/src/backend/langflow/interface/prompts/custom.py @@ -71,7 +71,3 @@ Human: {input} CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = { "SeriesCharacterPrompt": SeriesCharacterPrompt } - -if __name__ == "__main__": - prompt = SeriesCharacterPrompt(character="Harry Potter", series="Harry Potter") - print(prompt.template) diff --git a/src/backend/langflow/template/frontend_node/base.py b/src/backend/langflow/template/frontend_node/base.py index 6d00cead0..074a85499 100644 --- a/src/backend/langflow/template/frontend_node/base.py +++ b/src/backend/langflow/template/frontend_node/base.py @@ -27,6 +27,9 @@ class FrontendNode(BaseModel): def add_extra_fields(self) -> None: pass + def add_extra_base_classes(self) -> None: + pass + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: """Formats a given field based on its attributes and value.""" diff --git a/src/backend/langflow/template/frontend_node/chains.py b/src/backend/langflow/template/frontend_node/chains.py index cfcda2ef4..cb06c90f0 100644 --- a/src/backend/langflow/template/frontend_node/chains.py +++ b/src/backend/langflow/template/frontend_node/chains.py @@ -2,10 +2,24 @@ from typing import Optional from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode +from langflow.template.frontend_node.constants import QA_CHAIN_TYPES from langflow.template.template.base import Template class ChainFrontendNode(FrontendNode): + def add_extra_fields(self) -> None: + if self.template.type_name == "ConversationalRetrievalChain": + # add memory + self.template.add_field( + TemplateField( + field_type="BaseChatMemory", + required=False, + show=True, + name="memory", + advanced=False, + ) + ) + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) @@ -155,3 +169,41 @@ class MidJourneyPromptChainNode(FrontendNode): "ConversationChain", "MidJourneyPromptChain", ] + + +class CombineDocsChainNode(FrontendNode): + name: str = "CombineDocsChain" + template: Template = Template( + type_name="load_qa_chain", + fields=[ + TemplateField( + field_type="str", + required=True, + is_list=True, + show=True, + multiline=False, + options=QA_CHAIN_TYPES, + value=QA_CHAIN_TYPES[0], + name="chain_type", + advanced=False, + ), + TemplateField( + field_type="BaseLanguageModel", + required=True, + show=True, + name="llm", + display_name="LLM", + advanced=False, + ), + ], + ) + description: str = """Construct a zero shot agent from an LLM and tools.""" + base_classes: list[str] = ["BaseCombineDocumentsChain", "function"] + + def to_dict(self): + return super().to_dict() + + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + # do nothing and don't return anything + pass diff --git a/src/backend/langflow/template/frontend_node/memories.py b/src/backend/langflow/template/frontend_node/memories.py index 91d892627..20c3c9272 100644 --- a/src/backend/langflow/template/frontend_node/memories.py +++ b/src/backend/langflow/template/frontend_node/memories.py @@ -5,6 +5,20 @@ from langflow.template.frontend_node.base import FrontendNode class MemoryFrontendNode(FrontendNode): + #! Needs testing + def add_extra_fields(self) -> None: + # add return_messages field + self.template.add_field( + TemplateField( + field_type="bool", + required=False, + show=True, + name="return_messages", + advanced=False, + value=False, + ) + ) + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) @@ -18,3 +32,7 @@ class MemoryFrontendNode(FrontendNode): field.value = 10 field.display_name = "Memory Size" field.password = False + if field.name == "return_messages": + field.required = False + field.show = True + field.advanced = False