diff --git a/src/backend/langflow/template/frontend_node/chains.py b/src/backend/langflow/template/frontend_node/chains.py index 5ff211f68..6bd676522 100644 --- a/src/backend/langflow/template/frontend_node/chains.py +++ b/src/backend/langflow/template/frontend_node/chains.py @@ -7,6 +7,11 @@ from langflow.template.template.base import Template class ChainFrontendNode(FrontendNode): + output_type: str = "Chain" + + def add_extra_base_classes(self) -> None: + self.base_classes.append("Text") + def add_extra_fields(self) -> None: if self.template.type_name == "ConversationalRetrievalChain": # add memory @@ -48,16 +53,16 @@ class ChainFrontendNode(FrontendNode): @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) - key = field.name or "" - if "name" == "RetrievalQA" and key == "memory": + + if "name" == "RetrievalQA" and field.name == "memory": field.show = False field.required = False field.advanced = False - if "key" in key: + if "key" in field.name: field.password = False field.show = False - if key in ["input_key", "output_key"]: + if field.name in ["input_key", "output_key"]: field.required = True field.show = True field.advanced = True @@ -71,26 +76,26 @@ class ChainFrontendNode(FrontendNode): # field.value = field.value.template # Separated for possible future changes - if key == "prompt" and field.value is None: + if field.name == "prompt" and field.value is None: field.required = True field.show = True field.advanced = False - if key == "memory": + if field.name == "memory": # field.required = False field.show = True field.advanced = False - if key == "verbose": + if field.name == "verbose": field.required = False field.show = False field.advanced = True - if key == "llm": + if field.name == "llm": field.required = True field.show = True field.advanced = False - field.field_type = "BaseLanguageModel" # temporary fix + field.field_type = "BaseLanguageModel" field.is_list = False - if key == "return_source_documents": + if field.name == "return_source_documents": field.required = False field.show = True field.advanced = True @@ -98,6 +103,7 @@ class ChainFrontendNode(FrontendNode): class SeriesCharacterChainNode(FrontendNode): + output_type: str = "Chain" name: str = "SeriesCharacterChain" template: Template = Template( type_name="SeriesCharacterChain", @@ -135,20 +141,19 @@ class SeriesCharacterChainNode(FrontendNode): ), ], ) - description: str = ( - "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa - ) + description: str = "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa base_classes: list[str] = [ "LLMChain", "BaseCustomChain", "Chain", "ConversationChain", "SeriesCharacterChain", - "Callable", + "function", ] class TimeTravelGuideChainNode(FrontendNode): + output_type: str = "Chain" name: str = "TimeTravelGuideChain" template: Template = Template( type_name="TimeTravelGuideChain", @@ -184,6 +189,7 @@ class TimeTravelGuideChainNode(FrontendNode): class MidJourneyPromptChainNode(FrontendNode): + output_type: str = "Chain" name: str = "MidJourneyPromptChain" template: Template = Template( type_name="MidJourneyPromptChain", @@ -219,6 +225,7 @@ class MidJourneyPromptChainNode(FrontendNode): class CombineDocsChainNode(FrontendNode): + output_type: str = "Chain" name: str = "CombineDocsChain" template: Template = Template( type_name="load_qa_chain", @@ -245,7 +252,10 @@ class CombineDocsChainNode(FrontendNode): ], ) description: str = """Load question answering chain.""" - base_classes: list[str] = ["BaseCombineDocumentsChain", "Callable"] + base_classes: list[str] = ["BaseCombineDocumentsChain", "function"] + + def to_dict(self): + return super().to_dict() @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: