🔨 refactor(custom.py): remove unused code

 feat(frontend_node): add extra fields to MemoryFrontendNode and ChainFrontendNode
The unused code in custom.py has been removed. The MemoryFrontendNode and ChainFrontendNode classes have been updated to include additional fields that are required for their respective templates. The MemoryFrontendNode now has a return_messages field, and the ChainFrontendNode now has a memory field. These fields are optional and can be toggled on or off as required.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-07 19:46:51 -03:00
commit b5ab95c4cc
4 changed files with 73 additions and 4 deletions

View file

@ -71,7 +71,3 @@ Human: {input}
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {
"SeriesCharacterPrompt": SeriesCharacterPrompt
}
if __name__ == "__main__":
prompt = SeriesCharacterPrompt(character="Harry Potter", series="Harry Potter")
print(prompt.template)

View file

@ -27,6 +27,9 @@ class FrontendNode(BaseModel):
def add_extra_fields(self) -> None:
pass
def add_extra_base_classes(self) -> None:
pass
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
"""Formats a given field based on its attributes and value."""

View file

@ -2,10 +2,24 @@ from typing import Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import QA_CHAIN_TYPES
from langflow.template.template.base import Template
class ChainFrontendNode(FrontendNode):
def add_extra_fields(self) -> None:
if self.template.type_name == "ConversationalRetrievalChain":
# add memory
self.template.add_field(
TemplateField(
field_type="BaseChatMemory",
required=False,
show=True,
name="memory",
advanced=False,
)
)
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
@ -155,3 +169,41 @@ class MidJourneyPromptChainNode(FrontendNode):
"ConversationChain",
"MidJourneyPromptChain",
]
class CombineDocsChainNode(FrontendNode):
name: str = "CombineDocsChain"
template: Template = Template(
type_name="load_qa_chain",
fields=[
TemplateField(
field_type="str",
required=True,
is_list=True,
show=True,
multiline=False,
options=QA_CHAIN_TYPES,
value=QA_CHAIN_TYPES[0],
name="chain_type",
advanced=False,
),
TemplateField(
field_type="BaseLanguageModel",
required=True,
show=True,
name="llm",
display_name="LLM",
advanced=False,
),
],
)
description: str = """Construct a zero shot agent from an LLM and tools."""
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
def to_dict(self):
return super().to_dict()
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
# do nothing and don't return anything
pass

View file

@ -5,6 +5,20 @@ from langflow.template.frontend_node.base import FrontendNode
class MemoryFrontendNode(FrontendNode):
#! Needs testing
def add_extra_fields(self) -> None:
# add return_messages field
self.template.add_field(
TemplateField(
field_type="bool",
required=False,
show=True,
name="return_messages",
advanced=False,
value=False,
)
)
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
@ -18,3 +32,7 @@ class MemoryFrontendNode(FrontendNode):
field.value = 10
field.display_name = "Memory Size"
field.password = False
if field.name == "return_messages":
field.required = False
field.show = True
field.advanced = False