diff --git a/src/backend/langflow/interface/prompts/custom.py b/src/backend/langflow/interface/prompts/custom.py index 286210271..ef16f1474 100644 --- a/src/backend/langflow/interface/prompts/custom.py +++ b/src/backend/langflow/interface/prompts/custom.py @@ -71,7 +71,3 @@ Human: {input} CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = { "SeriesCharacterPrompt": SeriesCharacterPrompt } - -if __name__ == "__main__": - prompt = SeriesCharacterPrompt(character="Harry Potter", series="Harry Potter") - print(prompt.template) diff --git a/src/backend/langflow/template/frontend_node/base.py b/src/backend/langflow/template/frontend_node/base.py index 6d00cead0..074a85499 100644 --- a/src/backend/langflow/template/frontend_node/base.py +++ b/src/backend/langflow/template/frontend_node/base.py @@ -27,6 +27,9 @@ class FrontendNode(BaseModel): def add_extra_fields(self) -> None: pass + def add_extra_base_classes(self) -> None: + pass + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: """Formats a given field based on its attributes and value.""" diff --git a/src/backend/langflow/template/frontend_node/chains.py b/src/backend/langflow/template/frontend_node/chains.py index cfcda2ef4..cb06c90f0 100644 --- a/src/backend/langflow/template/frontend_node/chains.py +++ b/src/backend/langflow/template/frontend_node/chains.py @@ -2,10 +2,24 @@ from typing import Optional from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode +from langflow.template.frontend_node.constants import QA_CHAIN_TYPES from langflow.template.template.base import Template class ChainFrontendNode(FrontendNode): + def add_extra_fields(self) -> None: + if self.template.type_name == "ConversationalRetrievalChain": + # add memory + self.template.add_field( + TemplateField( + field_type="BaseChatMemory", + required=False, + show=True, + name="memory", + advanced=False, + ) + ) + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) @@ -155,3 +169,41 @@ class MidJourneyPromptChainNode(FrontendNode): "ConversationChain", "MidJourneyPromptChain", ] + + +class CombineDocsChainNode(FrontendNode): + name: str = "CombineDocsChain" + template: Template = Template( + type_name="load_qa_chain", + fields=[ + TemplateField( + field_type="str", + required=True, + is_list=True, + show=True, + multiline=False, + options=QA_CHAIN_TYPES, + value=QA_CHAIN_TYPES[0], + name="chain_type", + advanced=False, + ), + TemplateField( + field_type="BaseLanguageModel", + required=True, + show=True, + name="llm", + display_name="LLM", + advanced=False, + ), + ], + ) + description: str = """Construct a zero shot agent from an LLM and tools.""" + base_classes: list[str] = ["BaseCombineDocumentsChain", "function"] + + def to_dict(self): + return super().to_dict() + + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + # do nothing and don't return anything + pass diff --git a/src/backend/langflow/template/frontend_node/memories.py b/src/backend/langflow/template/frontend_node/memories.py index 91d892627..20c3c9272 100644 --- a/src/backend/langflow/template/frontend_node/memories.py +++ b/src/backend/langflow/template/frontend_node/memories.py @@ -5,6 +5,20 @@ from langflow.template.frontend_node.base import FrontendNode class MemoryFrontendNode(FrontendNode): + #! Needs testing + def add_extra_fields(self) -> None: + # add return_messages field + self.template.add_field( + TemplateField( + field_type="bool", + required=False, + show=True, + name="return_messages", + advanced=False, + value=False, + ) + ) + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) @@ -18,3 +32,7 @@ class MemoryFrontendNode(FrontendNode): field.value = 10 field.display_name = "Memory Size" field.password = False + if field.name == "return_messages": + field.required = False + field.show = True + field.advanced = False