langflow/src/backend/langflow/template/frontend_node/chains.py
Rodrigo Nader fd59bd77b3 Refactor SQL agent and related nodes
- Rename `VectorStoreAgent` to `Vector store agent` in `custom.py`
- Replace "Construct a sql agent from an LLM and tools." with "Construct an SQL agent from an LLM and tools." in `custom.py`
- Update descriptions of `SQLAgentNode`, `TimeTravelGuideChainNode`, `CombineDocsChainNode`, and `ToolNode` in related files
- Increase width of `ExtraSidebarComponent` container in `index.tsx` from 52 to 56
2023-06-14 23:54:39 -03:00

209 lines
6.2 KiB
Python

from typing import Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import QA_CHAIN_TYPES
from langflow.template.template.base import Template
class ChainFrontendNode(FrontendNode):
    """Frontend node for chain components.

    Adds a ``memory`` input to ``ConversationalRetrievalChain`` templates and
    adjusts the display flags of a handful of well-known chain fields.
    """

    def add_extra_fields(self) -> None:
        """Append an optional ``memory`` field for ConversationalRetrievalChain.

        Templates of any other type are left untouched.
        """
        if self.template.type_name != "ConversationalRetrievalChain":
            return

        memory_field = TemplateField(
            name="memory",
            field_type="BaseChatMemory",
            required=False,
            show=True,
            advanced=False,
        )
        self.template.add_field(memory_field)

    @staticmethod
    def format_field(field: TemplateField, name: Optional[str] = None) -> None:
        """Apply the base formatting, then the chain-specific display flags.

        Args:
            field: Template field that is adjusted in place.
            name: Forwarded unchanged to ``FrontendNode.format_field``.
        """
        FrontendNode.format_field(field, name)

        # Every field starts as "not advanced"; specific names opt back in.
        field.advanced = False

        field_name = field.name
        if "key" in field_name:
            field.password = False
            field.show = False

        # (required, show, advanced) for fields with a fixed presentation.
        # "input_key" / "output_key" also match the "key" rule above, so the
        # values here make them visible again after they were hidden.
        fixed_flags = {
            "input_key": (True, True, True),
            "output_key": (True, True, True),
            "memory": (False, True, False),
            "verbose": (False, True, True),
            "llm": (True, True, False),
        }

        if field_name == "prompt":
            # Separated for possible future changes: a prompt is only
            # touched when it carries no value.
            flags = (True, True, False) if field.value is None else None
        else:
            flags = fixed_flags.get(field_name)

        if flags is not None:
            # Targets are assigned left to right: required, show, advanced.
            field.required, field.show, field.advanced = flags
class SeriesCharacterChainNode(FrontendNode):
    """Frontend node for ``SeriesCharacterChain``.

    The template exposes three required inputs: ``character``, ``series``
    and ``llm``.
    """

    name: str = "SeriesCharacterChain"

    template: Template = Template(
        type_name="SeriesCharacterChain",
        fields=[
            # Plain single-line text input.
            TemplateField(
                name="character",
                field_type="str",
                required=True,
                show=True,
                advanced=False,
                is_list=False,
                multiline=False,
                placeholder="",
            ),
            # Plain single-line text input.
            TemplateField(
                name="series",
                field_type="str",
                required=True,
                show=True,
                advanced=False,
                is_list=False,
                multiline=False,
                placeholder="",
            ),
            # Language model input.
            TemplateField(
                name="llm",
                display_name="LLM",
                field_type="BaseLanguageModel",
                required=True,
                show=True,
                advanced=False,
                is_list=False,
                multiline=False,
                placeholder="",
            ),
        ],
    )

    description: str = "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."  # noqa

    base_classes: list[str] = [
        "LLMChain",
        "BaseCustomChain",
        "Chain",
        "ConversationChain",
        "SeriesCharacterChain",
        "function",
    ]
class TimeTravelGuideChainNode(FrontendNode):
    """Frontend node for ``TimeTravelGuideChain``.

    The template exposes a required ``llm`` input and an optional ``memory``
    input.
    """

    name: str = "TimeTravelGuideChain"

    template: Template = Template(
        type_name="TimeTravelGuideChain",
        fields=[
            # Language model input (required).
            TemplateField(
                name="llm",
                display_name="LLM",
                field_type="BaseLanguageModel",
                required=True,
                show=True,
                advanced=False,
                is_list=False,
                multiline=False,
                placeholder="",
            ),
            # Chat memory input (optional).
            TemplateField(
                name="memory",
                field_type="BaseChatMemory",
                required=False,
                show=True,
                advanced=False,
            ),
        ],
    )

    description: str = "Time travel guide chain."

    base_classes: list[str] = [
        "LLMChain",
        "BaseCustomChain",
        "TimeTravelGuideChain",
        "Chain",
        "ConversationChain",
    ]
class MidJourneyPromptChainNode(FrontendNode):
    """Frontend node for ``MidJourneyPromptChain``.

    The template exposes a required ``llm`` input and an optional ``memory``
    input.
    """

    name: str = "MidJourneyPromptChain"

    template: Template = Template(
        type_name="MidJourneyPromptChain",
        fields=[
            # Language model input (required).
            TemplateField(
                name="llm",
                display_name="LLM",
                field_type="BaseLanguageModel",
                required=True,
                show=True,
                advanced=False,
                is_list=False,
                multiline=False,
                placeholder="",
            ),
            # Chat memory input (optional).
            TemplateField(
                name="memory",
                field_type="BaseChatMemory",
                required=False,
                show=True,
                advanced=False,
            ),
        ],
    )

    description: str = "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."  # noqa: E501

    base_classes: list[str] = [
        "LLMChain",
        "BaseCustomChain",
        "Chain",
        "ConversationChain",
        "MidJourneyPromptChain",
    ]
class CombineDocsChainNode(FrontendNode):
    """Frontend node for the ``load_qa_chain`` template.

    The template exposes a ``chain_type`` selection, defaulting to the first
    entry of ``QA_CHAIN_TYPES``, and a required ``llm`` input.
    """

    name: str = "CombineDocsChain"

    template: Template = Template(
        type_name="load_qa_chain",
        fields=[
            # One of QA_CHAIN_TYPES; the first entry is the default value.
            TemplateField(
                name="chain_type",
                field_type="str",
                options=QA_CHAIN_TYPES,
                value=QA_CHAIN_TYPES[0],
                required=True,
                show=True,
                advanced=False,
                is_list=True,
                multiline=False,
            ),
            # Language model input (required).
            TemplateField(
                name="llm",
                display_name="LLM",
                field_type="BaseLanguageModel",
                required=True,
                show=True,
                advanced=False,
            ),
        ],
    )

    description: str = "Construct a chain from combined documents."

    base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]

    def to_dict(self):
        """Return the parent class's ``to_dict()`` result unchanged."""
        serialized = super().to_dict()
        return serialized

    @staticmethod
    def format_field(field: TemplateField, name: Optional[str] = None) -> None:
        """Leave ``field`` exactly as declared; this is a deliberate no-op."""
        return None