From ef4fe40ce4ab7f2a31d692d386aa5f35e20f69b6 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Fri, 31 Mar 2023 14:07:23 -0300 Subject: [PATCH] feat: toolkit node implementation --- src/backend/langflow/graph/graph.py | 32 ++++++++++---- src/backend/langflow/graph/nodes.py | 68 +++++++++++++++++++++++++++-- 2 files changed, 88 insertions(+), 12 deletions(-) diff --git a/src/backend/langflow/graph/graph.py b/src/backend/langflow/graph/graph.py index e7110e977..cd4632b61 100644 --- a/src/backend/langflow/graph/graph.py +++ b/src/backend/langflow/graph/graph.py @@ -1,15 +1,25 @@ from typing import Dict, List, Union -from langflow.utils import payload -from langflow.interface.tools.constants import ALL_TOOLS_NAMES -from langflow.graph.base import Node, Edge +from langflow.graph.base import Edge, Node from langflow.graph.nodes import ( AgentNode, ChainNode, + FileToolNode, + LLMNode, PromptNode, ToolkitNode, ToolNode, + WrapperNode, ) +from langflow.interface.agents.base import agent_creator +from langflow.interface.chains.base import chain_creator +from langflow.interface.llms.base import llm_creator +from langflow.interface.prompts.base import prompt_creator +from langflow.interface.toolkits.base import toolkits_creator +from langflow.interface.tools.base import tool_creator +from langflow.interface.tools.constants import ALL_TOOLS_NAMES, FILE_TOOLS +from langflow.interface.wrappers.base import wrapper_creator +from langflow.utils import payload class Graph: @@ -84,16 +94,22 @@ class Graph: node_type: str = node_data["type"] # type: ignore node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore - if node_type in {"ZeroShotPrompt", "PromptTemplate"}: + if node_type in prompt_creator.to_list(): nodes.append(PromptNode(node)) - elif "agent" in node_type.lower(): + elif node_type in agent_creator.to_list(): nodes.append(AgentNode(node)) - elif "chain" in node_type.lower(): + elif node_type in chain_creator.to_list(): nodes.append(ChainNode(node)) - elif "tool" in node_type.lower() or node_lc_type in ALL_TOOLS_NAMES: + elif node_type in tool_creator.to_list() or node_lc_type in ALL_TOOLS_NAMES: + if node_type in FILE_TOOLS: + nodes.append(FileToolNode(node)) nodes.append(ToolNode(node)) - elif "toolkit" in node_type.lower(): + elif node_type in toolkits_creator.to_list(): nodes.append(ToolkitNode(node)) + elif node_type in wrapper_creator.to_list(): + nodes.append(WrapperNode(node)) + elif node_type in llm_creator.to_list(): + nodes.append(LLMNode(node)) else: nodes.append(Node(node)) return nodes diff --git a/src/backend/langflow/graph/nodes.py b/src/backend/langflow/graph/nodes.py index 43350b3d9..245e3f2f5 100644 --- a/src/backend/langflow/graph/nodes.py +++ b/src/backend/langflow/graph/nodes.py @@ -1,12 +1,14 @@ +import json from copy import deepcopy from typing import Any, Dict, List, Optional, Union from langflow.graph.base import Node +from langflow.interface.toolkits.base import toolkits_creator class AgentNode(Node): def __init__(self, data: Dict): - super().__init__(data) + super().__init__(data, base_type="agents") self.tools: List[ToolNode] = [] self.chains: List[ChainNode] = [] @@ -35,7 +37,7 @@ class AgentNode(Node): class ToolNode(Node): def __init__(self, data: Dict): - super().__init__(data) + super().__init__(data, base_type="tools") def build(self, force: bool = False) -> Any: if not self._built or force: @@ -45,7 +47,7 @@ class ToolNode(Node): class PromptNode(Node): def __init__(self, data: Dict): - super().__init__(data) + super().__init__(data, base_type="prompts") def build( self, @@ -68,7 +70,7 @@ class PromptNode(Node): class ChainNode(Node): def __init__(self, data: Dict): - super().__init__(data) + super().__init__(data, base_type="chains") def build( self, @@ -87,6 +89,52 @@ class ChainNode(Node): class ToolkitNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="toolkits") + + def build(self, force: bool = False) -> Any: + if not self._built or force: + if toolkits_creator.has_create_function(self.node_type): + self.find_llm() + self._build() + # Now that the toolkit is built, we need to find the llm + # and add it to the self.params + + # go through the edges and find the llm + + return deepcopy(self._built_object) + + def find_llm(self, node=None, edges_visited=[]) -> None: + if node is None: + node = self + # Move recursively through the edges + # the targets of this node edges are this node + # If we find an LLMNode, we add it to the params + if len(node.edges) == 1: + return + for edge in node.edges: + source = edge.source + if source in edges_visited: + continue + edges_visited.append(source) + if isinstance(source, LLMNode): + self.params["llm"] = source.build() + break + else: + self.find_llm(source, edges_visited) + + +class LLMNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="llms") + + def build(self, force: bool = False) -> Any: + if not self._built or force: + self._build() + return deepcopy(self._built_object) + + +class FileToolNode(ToolNode): def __init__(self, data: Dict): super().__init__(data) @@ -94,3 +142,15 @@ class ToolkitNode(Node): if not self._built or force: self._build() return deepcopy(self._built_object) + + +class WrapperNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="wrappers") + + def build(self, force: bool = False) -> Any: + if not self._built or force: + if "headers" in self.params: + self.params["headers"] = eval(self.params["headers"]) + self._build() + return deepcopy(self._built_object)