Add extra base class and modify field formatting in ChainFrontendNode
This commit is contained in:
parent
dc56daebc8
commit
952f8d39e5
1 changed files with 25 additions and 15 deletions
|
|
@ -7,6 +7,11 @@ from langflow.template.template.base import Template
|
|||
|
||||
|
||||
class ChainFrontendNode(FrontendNode):
|
||||
output_type: str = "Chain"
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
self.base_classes.append("Text")
|
||||
|
||||
def add_extra_fields(self) -> None:
|
||||
if self.template.type_name == "ConversationalRetrievalChain":
|
||||
# add memory
|
||||
|
|
@ -48,16 +53,16 @@ class ChainFrontendNode(FrontendNode):
|
|||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
key = field.name or ""
|
||||
if "name" == "RetrievalQA" and key == "memory":
|
||||
|
||||
if "name" == "RetrievalQA" and field.name == "memory":
|
||||
field.show = False
|
||||
field.required = False
|
||||
|
||||
field.advanced = False
|
||||
if "key" in key:
|
||||
if "key" in field.name:
|
||||
field.password = False
|
||||
field.show = False
|
||||
if key in ["input_key", "output_key"]:
|
||||
if field.name in ["input_key", "output_key"]:
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
|
|
@ -71,26 +76,26 @@ class ChainFrontendNode(FrontendNode):
|
|||
# field.value = field.value.template
|
||||
|
||||
# Separated for possible future changes
|
||||
if key == "prompt" and field.value is None:
|
||||
if field.name == "prompt" and field.value is None:
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
if key == "memory":
|
||||
if field.name == "memory":
|
||||
# field.required = False
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
if key == "verbose":
|
||||
if field.name == "verbose":
|
||||
field.required = False
|
||||
field.show = False
|
||||
field.advanced = True
|
||||
if key == "llm":
|
||||
if field.name == "llm":
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
field.field_type = "BaseLanguageModel" # temporary fix
|
||||
field.field_type = "BaseLanguageModel"
|
||||
field.is_list = False
|
||||
|
||||
if key == "return_source_documents":
|
||||
if field.name == "return_source_documents":
|
||||
field.required = False
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
|
|
@ -98,6 +103,7 @@ class ChainFrontendNode(FrontendNode):
|
|||
|
||||
|
||||
class SeriesCharacterChainNode(FrontendNode):
|
||||
output_type: str = "Chain"
|
||||
name: str = "SeriesCharacterChain"
|
||||
template: Template = Template(
|
||||
type_name="SeriesCharacterChain",
|
||||
|
|
@ -135,20 +141,19 @@ class SeriesCharacterChainNode(FrontendNode):
|
|||
),
|
||||
],
|
||||
)
|
||||
description: str = (
|
||||
"SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa
|
||||
)
|
||||
description: str = "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa
|
||||
base_classes: list[str] = [
|
||||
"LLMChain",
|
||||
"BaseCustomChain",
|
||||
"Chain",
|
||||
"ConversationChain",
|
||||
"SeriesCharacterChain",
|
||||
"Callable",
|
||||
"function",
|
||||
]
|
||||
|
||||
|
||||
class TimeTravelGuideChainNode(FrontendNode):
|
||||
output_type: str = "Chain"
|
||||
name: str = "TimeTravelGuideChain"
|
||||
template: Template = Template(
|
||||
type_name="TimeTravelGuideChain",
|
||||
|
|
@ -184,6 +189,7 @@ class TimeTravelGuideChainNode(FrontendNode):
|
|||
|
||||
|
||||
class MidJourneyPromptChainNode(FrontendNode):
|
||||
output_type: str = "Chain"
|
||||
name: str = "MidJourneyPromptChain"
|
||||
template: Template = Template(
|
||||
type_name="MidJourneyPromptChain",
|
||||
|
|
@ -219,6 +225,7 @@ class MidJourneyPromptChainNode(FrontendNode):
|
|||
|
||||
|
||||
class CombineDocsChainNode(FrontendNode):
|
||||
output_type: str = "Chain"
|
||||
name: str = "CombineDocsChain"
|
||||
template: Template = Template(
|
||||
type_name="load_qa_chain",
|
||||
|
|
@ -245,7 +252,10 @@ class CombineDocsChainNode(FrontendNode):
|
|||
],
|
||||
)
|
||||
description: str = """Load question answering chain."""
|
||||
base_classes: list[str] = ["BaseCombineDocumentsChain", "Callable"]
|
||||
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue