diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py index 4ae313fd3..084e04d65 100644 --- a/src/backend/langflow/api/base.py +++ b/src/backend/langflow/api/base.py @@ -1,6 +1,7 @@ -from langflow.graph.utils import extract_input_variables_from_prompt from pydantic import BaseModel, validator +from langflow.graph.utils import extract_input_variables_from_prompt + class Code(BaseModel): code: str diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 3a0f9d018..9421e56fa 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -29,10 +29,10 @@ llms: - HuggingFaceHub tools: - # - Search + - Search - PAL-MATH - Calculator - # - Serper Search + - Serper Search - Tool - PythonFunction - JsonSpec diff --git a/src/backend/langflow/graph/graph.py b/src/backend/langflow/graph/graph.py index 76cab071d..84b94a6a5 100644 --- a/src/backend/langflow/graph/graph.py +++ b/src/backend/langflow/graph/graph.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Dict, List, Type, Union from langflow.graph.base import Edge, Node from langflow.graph.nodes import ( @@ -15,13 +15,12 @@ from langflow.graph.nodes import ( from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator from langflow.interface.llms.base import llm_creator +from langflow.interface.memories.base import memory_creator from langflow.interface.prompts.base import prompt_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.tools.base import tool_creator from langflow.interface.tools.constants import FILE_TOOLS -from langflow.interface.tools.util import get_tools_dict from langflow.interface.wrappers.base import wrapper_creator -from langflow.interface.memories.base import memory_creator from langflow.utils import payload @@ -108,6 +107,26 @@ class Graph: edges.append(Edge(source, target)) return edges + def _get_node_class(self, node_type: str, node_lc_type: str) -> Type[Node]: + node_type_map: Dict[str, Type[Node]] = { + **{t: PromptNode for t in prompt_creator.to_list()}, + **{t: AgentNode for t in agent_creator.to_list()}, + **{t: ChainNode for t in chain_creator.to_list()}, + **{t: ToolNode for t in tool_creator.to_list()}, + **{t: ToolkitNode for t in toolkits_creator.to_list()}, + **{t: WrapperNode for t in wrapper_creator.to_list()}, + **{t: LLMNode for t in llm_creator.to_list()}, + **{t: MemoryNode for t in memory_creator.to_list()}, + } + + if node_type in FILE_TOOLS: + return FileToolNode + if node_type in node_type_map: + return node_type_map[node_type] + if node_lc_type in node_type_map: + return node_type_map[node_lc_type] + return Node + def _build_nodes(self) -> List[Node]: nodes: List[Node] = [] for node in self._nodes: @@ -115,38 +134,9 @@ class Graph: node_type: str = node_data["type"] # type: ignore node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore - if node_type in prompt_creator.to_list(): - nodes.append(PromptNode(node)) - elif ( - node_type in agent_creator.to_list() - or node_lc_type in agent_creator.to_list() - ): - nodes.append(AgentNode(node)) - elif node_type in chain_creator.to_list(): - nodes.append(ChainNode(node)) - elif ( - node_type in tool_creator.to_list() - or node_lc_type in get_tools_dict().keys() - ): - if node_type in FILE_TOOLS: - nodes.append(FileToolNode(node)) - nodes.append(ToolNode(node)) - elif node_type in toolkits_creator.to_list(): - nodes.append(ToolkitNode(node)) - elif node_type in wrapper_creator.to_list(): - nodes.append(WrapperNode(node)) - elif ( - node_type in llm_creator.to_list() - or node_lc_type in llm_creator.to_list() - ): - nodes.append(LLMNode(node)) - elif ( - node_type in memory_creator.to_list() - or node_lc_type in memory_creator.to_list() - ): - nodes.append(MemoryNode(node)) - else: - nodes.append(Node(node)) + NodeClass = self._get_node_class(node_type, node_lc_type) + nodes.append(NodeClass(node)) + return nodes def get_children_by_node_type(self, node: Node, node_type: str) -> List[Node]: diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py index 24729cb63..caa30821c 100644 --- a/tests/test_prompts_template.py +++ b/tests/test_prompts_template.py @@ -167,7 +167,7 @@ def test_zero_shot_prompt(client: TestClient): "placeholder": "", "show": True, "multiline": True, - "value": "Answer the following questions as best you can. You have access to the following tools:", # noqa: E501 + "value": "Answer the following questions as best you can. You have access to the following tools:", # noqa: E501 "password": False, "name": "prefix", "type": "str", @@ -189,7 +189,7 @@ def test_zero_shot_prompt(client: TestClient): "placeholder": "", "show": False, "multiline": False, - "value": "Use the following format:\n\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can repeat N times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question", # noqa: E501 + "value": "Use the following format:\n\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can repeat N times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question", # noqa: E501 "password": False, "name": "format_instructions", "type": "str",