Add extra base class and modify field formatting in ChainFrontendNode

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-25 15:14:46 -03:00
commit 952f8d39e5

View file

@ -7,6 +7,11 @@ from langflow.template.template.base import Template
class ChainFrontendNode(FrontendNode):
output_type: str = "Chain"
def add_extra_base_classes(self) -> None:
self.base_classes.append("Text")
def add_extra_fields(self) -> None:
if self.template.type_name == "ConversationalRetrievalChain":
# add memory
@ -48,16 +53,16 @@ class ChainFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
key = field.name or ""
if "name" == "RetrievalQA" and key == "memory":
if "name" == "RetrievalQA" and field.name == "memory":
field.show = False
field.required = False
field.advanced = False
if "key" in key:
if "key" in field.name:
field.password = False
field.show = False
if key in ["input_key", "output_key"]:
if field.name in ["input_key", "output_key"]:
field.required = True
field.show = True
field.advanced = True
@ -71,26 +76,26 @@ class ChainFrontendNode(FrontendNode):
# field.value = field.value.template
# Separated for possible future changes
if key == "prompt" and field.value is None:
if field.name == "prompt" and field.value is None:
field.required = True
field.show = True
field.advanced = False
if key == "memory":
if field.name == "memory":
# field.required = False
field.show = True
field.advanced = False
if key == "verbose":
if field.name == "verbose":
field.required = False
field.show = False
field.advanced = True
if key == "llm":
if field.name == "llm":
field.required = True
field.show = True
field.advanced = False
field.field_type = "BaseLanguageModel" # temporary fix
field.field_type = "BaseLanguageModel"
field.is_list = False
if key == "return_source_documents":
if field.name == "return_source_documents":
field.required = False
field.show = True
field.advanced = True
@ -98,6 +103,7 @@ class ChainFrontendNode(FrontendNode):
class SeriesCharacterChainNode(FrontendNode):
output_type: str = "Chain"
name: str = "SeriesCharacterChain"
template: Template = Template(
type_name="SeriesCharacterChain",
@ -135,20 +141,19 @@ class SeriesCharacterChainNode(FrontendNode):
),
],
)
description: str = (
"SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa
)
description: str = "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa
base_classes: list[str] = [
"LLMChain",
"BaseCustomChain",
"Chain",
"ConversationChain",
"SeriesCharacterChain",
"Callable",
"function",
]
class TimeTravelGuideChainNode(FrontendNode):
output_type: str = "Chain"
name: str = "TimeTravelGuideChain"
template: Template = Template(
type_name="TimeTravelGuideChain",
@ -184,6 +189,7 @@ class TimeTravelGuideChainNode(FrontendNode):
class MidJourneyPromptChainNode(FrontendNode):
output_type: str = "Chain"
name: str = "MidJourneyPromptChain"
template: Template = Template(
type_name="MidJourneyPromptChain",
@ -219,6 +225,7 @@ class MidJourneyPromptChainNode(FrontendNode):
class CombineDocsChainNode(FrontendNode):
output_type: str = "Chain"
name: str = "CombineDocsChain"
template: Template = Template(
type_name="load_qa_chain",
@ -245,7 +252,10 @@ class CombineDocsChainNode(FrontendNode):
],
)
description: str = """Load question answering chain."""
base_classes: list[str] = ["BaseCombineDocumentsChain", "Callable"]
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
def to_dict(self):
return super().to_dict()
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None: