feat: toolkit node implementation

This commit is contained in:
Gabriel Almeida 2023-03-31 14:07:23 -03:00
commit ef4fe40ce4
2 changed files with 88 additions and 12 deletions

View file

@ -1,15 +1,25 @@
from typing import Dict, List, Union
from langflow.utils import payload
from langflow.interface.tools.constants import ALL_TOOLS_NAMES
from langflow.graph.base import Node, Edge
from langflow.graph.base import Edge, Node
from langflow.graph.nodes import (
AgentNode,
ChainNode,
FileToolNode,
LLMNode,
PromptNode,
ToolkitNode,
ToolNode,
WrapperNode,
)
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.llms.base import llm_creator
from langflow.interface.prompts.base import prompt_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.tools.base import tool_creator
from langflow.interface.tools.constants import ALL_TOOLS_NAMES, FILE_TOOLS
from langflow.interface.wrappers.base import wrapper_creator
from langflow.utils import payload
class Graph:
@ -84,16 +94,22 @@ class Graph:
node_type: str = node_data["type"] # type: ignore
node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore
if node_type in {"ZeroShotPrompt", "PromptTemplate"}:
if node_type in prompt_creator.to_list():
nodes.append(PromptNode(node))
elif "agent" in node_type.lower():
elif node_type in agent_creator.to_list():
nodes.append(AgentNode(node))
elif "chain" in node_type.lower():
elif node_type in chain_creator.to_list():
nodes.append(ChainNode(node))
elif "tool" in node_type.lower() or node_lc_type in ALL_TOOLS_NAMES:
elif node_type in tool_creator.to_list() or node_lc_type in ALL_TOOLS_NAMES:
if node_type in FILE_TOOLS:
nodes.append(FileToolNode(node))
nodes.append(ToolNode(node))
elif "toolkit" in node_type.lower():
elif node_type in toolkits_creator.to_list():
nodes.append(ToolkitNode(node))
elif node_type in wrapper_creator.to_list():
nodes.append(WrapperNode(node))
elif node_type in llm_creator.to_list():
nodes.append(LLMNode(node))
else:
nodes.append(Node(node))
return nodes

View file

@ -1,12 +1,14 @@
import json
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from langflow.graph.base import Node
from langflow.interface.toolkits.base import toolkits_creator
class AgentNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
super().__init__(data, base_type="agents")
self.tools: List[ToolNode] = []
self.chains: List[ChainNode] = []
@ -35,7 +37,7 @@ class AgentNode(Node):
class ToolNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
super().__init__(data, base_type="tools")
def build(self, force: bool = False) -> Any:
if not self._built or force:
@ -45,7 +47,7 @@ class ToolNode(Node):
class PromptNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
super().__init__(data, base_type="prompts")
def build(
self,
@ -68,7 +70,7 @@ class PromptNode(Node):
class ChainNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
super().__init__(data, base_type="chains")
def build(
self,
@ -87,6 +89,52 @@ class ChainNode(Node):
class ToolkitNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="toolkits")
def build(self, force: bool = False) -> Any:
if not self._built or force:
if toolkits_creator.has_create_function(self.node_type):
self.find_llm()
self._build()
# Now that the toolkit is built, we need to find the llm
# and add it to the self.params
# go through the edges and find the llm
return deepcopy(self._built_object)
def find_llm(self, node=None, edges_visited=[]) -> None:
if node is None:
node = self
# Move recursively through the edges
# the targets of this node edges are this node
# If we find an LLMNode, we add it to the params
if len(node.edges) == 1:
return
for edge in node.edges:
source = edge.source
if source in edges_visited:
continue
edges_visited.append(source)
if isinstance(source, LLMNode):
self.params["llm"] = source.build()
break
else:
self.find_llm(source, edges_visited)
class LLMNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="llms")
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class FileToolNode(ToolNode):
def __init__(self, data: Dict):
super().__init__(data)
@ -94,3 +142,15 @@ class ToolkitNode(Node):
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class WrapperNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="wrappers")
def build(self, force: bool = False) -> Any:
if not self._built or force:
if "headers" in self.params:
self.params["headers"] = eval(self.params["headers"])
self._build()
return deepcopy(self._built_object)