feat: Graph build tests passing

This commit is contained in:
Gabriel Almeida 2023-03-25 20:22:50 -03:00
commit cdb01f2a39
5 changed files with 247 additions and 19 deletions

View file

@ -0,0 +1,63 @@
# This module is used to import any langchain class by name.
import importlib
from typing import Any
from langchain import PromptTemplate
from langchain.agents import Agent
from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM
from langchain.tools import BaseTool
from langflow.utils.util import get_tools_dict
def import_module(module_path: str) -> Any:
"""Import module from module path"""
return importlib.import_module(module_path)
def import_by_type(_type: str, name: str) -> Any:
"""Import class by type and name"""
func_dict = {
"agents": import_agent,
"prompts": import_prompt,
"llms": import_llm,
"tools": import_tool,
"chains": import_chain,
}
return func_dict[_type](name)
def import_class(class_path: str) -> Any:
"""Import class from class path"""
module_path, class_name = class_path.rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)
def import_prompt(prompt: str) -> PromptTemplate:
"""Import prompt from prompt name"""
if prompt == "ZeroShotPrompt":
return import_class("langchain.prompts.PromptTemplate")
return import_class(f"langchain.prompts.{prompt}")
def import_agent(agent: str) -> Agent:
"""Import agent from agent name"""
return import_class(f"langchain.agents.{agent}")
def import_llm(llm: str) -> BaseLLM:
"""Import llm from llm name"""
return import_class(f"langchain.llms.{llm}")
def import_tool(tool: str) -> BaseTool:
"""Import tool from tool name"""
return get_tools_dict(tool)
def import_chain(chain: str) -> Chain:
"""Import chain from chain name"""
return import_class(f"langchain.chains.{chain}")

View file

@ -6,16 +6,20 @@ from langchain.agents.load_tools import get_all_tool_names
from langchain.chains.conversation import memory as memories
def list_type(object_type: str):
"""List all components"""
def get_type_dict():
return {
"chains": list_chain_types,
"agents": list_agents,
"prompts": list_prompts,
"llms": list_llms,
"tools": list_tools,
"memories": list_memories,
}.get(object_type, lambda: "Invalid type")()
# "memories": list_memories,
}
def list_type(object_type: str):
"""List all components"""
return get_type_dict().get(object_type, lambda: "Invalid type")()
def list_agents():
@ -48,7 +52,7 @@ def list_tools():
tool_params = util.get_tool_params(util.get_tools_dict(tool))
if tool_params and tool_params["name"] in allowed_components.TOOLS:
tools.append(tool_params["name"])
tools.append("BaseTool")
return tools

View file

@ -18,6 +18,7 @@ from langchain.agents.load_tools import (
_EXTRA_OPTIONAL_TOOLS,
)
from langflow.utils.graph import Graph
from langchain.agents import agent as agent_module
def load_flow_from_json(path: str):
@ -94,6 +95,19 @@ def load_agent_executor_from_config(
)
def load_agent_executor(
agent_class: agent_module.Agent, allowed_tools, llm_chain, **kwargs
):
"""Load agent executor from agent class, tools and chain"""
tool_names = [tool.name for tool in allowed_tools]
agent = agent_class(allowed_tools=tool_names, llm_chain=llm_chain)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=allowed_tools,
**kwargs,
)
def load_tools_from_config(tool_list: list[dict]) -> list:
"""Load tools based on a config list.

View file

@ -1,4 +1,11 @@
from typing import Dict, List, Union
from langflow.interface import listing
from langflow.interface.importing import import_by_type
from langflow.utils import payload, util
LANGCHAIN_TYPES_DICT = {
k: list_function() for k, list_function in listing.get_type_dict().items()
}
class Node:
@ -14,23 +21,141 @@ class Node:
# {'type': 'LLMChain', 'node': {'template': {'_type': 'llm_chain', 'memory': {'type': 'BaseMemory', 'required': False, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False, 'value': None}, 'verbose': {'type': 'bool', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': False, 'multiline': False, 'value': False}, 'prompt': {'type': 'BasePromptTemplate', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'llm': {'type': 'BaseLanguageModel', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'output_key': {'type': 'str', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': True, 'multiline': False, 'value': 'text'}}, 'description': 'Chain to run queries against LLMs.', 'base_classes': ['Chain']}, 'id': 'dndnode_1', 'value': None}
# base_classes are the classes that the node can be cast to
self.output = self.data["node"]["base_classes"]
template_dict = {
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
self.required_inputs = [
template_dict[key]["type"]
for key, value in template_dict.items()
template_dicts[key]["type"]
for key, value in template_dicts.items()
if value["required"]
]
self.optional_inputs = [
template_dict[key]["type"]
for key, value in template_dict.items()
template_dicts[key]["type"]
for key, value in template_dicts.items()
if not value["required"]
]
template_dict = self.data["node"]["template"]
self.module_type = (
self.data["type"] if "Tool" not in self.output else template_dict["_type"]
)
def _build_params(self) -> Dict:
# Some params are required, some are optional
# but most importantly, some params are python base classes
# like str and others are LangChain objects like LLMChain, BasePromptTemplate
# so we need to be able to distinguish between the two
# The dicts with "type" == "str" are the ones that are python base classes
# and most likely have a "value" key
# So for each key besides "_type" in the template dict, we have a dict
# with a "type" key. If the type is not "str", then we need to get the
# edge that connects to that node and get the Node with the required data
# and use that as the value for the param
# If the type is "str", then we need to get the value of the "value" key
# and use that as the value for the param
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
params = {}
for key, value in template_dict.items():
if key == "_type":
continue
if value["type"] not in ["str", "bool"]:
# Get the edge that connects to this node
edge = next(
(
edge
for edge in self.edges
if edge.target == self and edge.matched_type in value["type"]
),
None,
)
# Get the output of the node that the edge connects to
# if the value['list'] is True, then there will be more
# than one time setting to params[key]
# so we need to append to a list if it exists
# or create a new list if it doesn't
if edge is None and value["required"]:
raise ValueError(
f"Required input {key} for module {self.module_type} is not connected"
)
if value["list"]:
if key in params:
params[key].append(edge.source)
else:
params[key] = [edge.source]
else:
if not value["required"] and edge is None:
continue
params[key] = edge.source
else:
if not value["required"] and not value.get("value"):
continue
params[key] = value["value"]
# Add _type to params
self.params = params
def build(self):
from langflow.interface.loading import load_agent_executor
# The params dict is used to build the module
# it contains values and keys that point to nodes which
# have their own params dict
# When build is called, we iterate through the params dict
# and if the value is a node, we call build on that node
# and use the output of that build as the value for the param
# if the value is not a node, then we use the value as the param
# and continue
# Another aspect is that the module_type is the class that we need to import
# and instantiate with these built params
# Build each node in the params dict
for key, value in self.params.items():
# Check if Node or list of Nodes
if isinstance(value, Node):
self.params[key] = value.build()
elif isinstance(value, list) and all(
isinstance(node, Node) for node in value
):
self.params[key] = [node.build() for node in value] # type: ignore
# Get the class from LANGCHAIN_TYPES_DICT
# and instantiate it with the params
# and return the instance
instance = None
for key, value in LANGCHAIN_TYPES_DICT.items():
if key == "tools":
value = util.get_tools_dict()
if self.module_type in value:
class_object = import_by_type(_type=key, name=self.module_type)
if key == "agents":
# We need to initialize it differently
allowed_tools = self.params["allowed_tools"]
llm_chain = self.params["llm_chain"]
instance = load_agent_executor(
class_object, allowed_tools, llm_chain
)
elif key == "tools":
instance = class_object(**self.params)
elif self.module_type == "ZeroShotPrompt":
from langchain.agents import ZeroShotAgent
instance = ZeroShotAgent.create_prompt(**self.params, tools=[])
else:
instance = class_object(**self.params)
break
return instance
def add_edge(self, edge: "Edge") -> None:
self.edges.append(edge)
@ -52,19 +177,33 @@ class Edge:
def validate_edge(self) -> None:
# Validate that the outputs of the source node are valid inputs for the target node
self.coming_out = self.source.output
self.going_in = self.target.required_inputs + self.target.optional_inputs
self.source_types = self.source.output
self.target_reqs = self.target.required_inputs + self.target.optional_inputs
# Both lists contain strings and sometimes a string contains the value we are looking for
# e.g. comgin_out=["Chain"] and going_in=["LLMChain"]
# so we need to check if any of the strings in coming_out is in going_in
# e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(
output in going_in
for output in self.coming_out
for going_in in self.going_in
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
# Get what type of input the target node is expecting
self.matched_type = next(
(
output
for output in self.source_types
for target_req in self.target_reqs
if output in target_req
),
None,
)
def __repr__(self) -> str:
return f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}, coming_out={self.coming_out}, going_in={self.going_in})"
return (
f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}"
f", matched_type={self.matched_type})"
)
class Graph:
@ -84,6 +223,9 @@ class Graph:
edge.source.add_edge(edge)
edge.target.add_edge(edge)
for node in self.nodes:
node._build_params()
def get_node(self, node_id: str) -> Union[None, Node]:
return next((node for node in self.nodes if node.id == node_id), None)
@ -93,6 +235,11 @@ class Graph:
]
return connected_nodes
def build(self) -> List[Node]:
# Get root node
root_node = payload.get_root_node(self)
return root_node.build()
def get_node_neighbors(self, node: Node) -> Dict[str, int]:
neighbors: Dict[Node, int] = {}
for edge in self.edges:

View file

@ -47,7 +47,7 @@ def build_json(root: Node, graph: Graph) -> Dict:
else:
# Otherwise, find all children whose type matches the type
# specified in the template
module_type = root.data["node"]["template"]["_type"]
module_type = root.module_type
local_nodes = graph.get_nodes_with_target(root)
if len(local_nodes) == 1: