From cdb01f2a39195ca6099e57a854e0c9319a45b144 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Sat, 25 Mar 2023 20:22:50 -0300 Subject: [PATCH] feat: Graph build tests passing --- src/backend/langflow/interface/importing.py | 63 +++++++ src/backend/langflow/interface/listing.py | 14 +- src/backend/langflow/interface/loading.py | 14 ++ src/backend/langflow/utils/graph.py | 173 ++++++++++++++++++-- src/backend/langflow/utils/payload.py | 2 +- 5 files changed, 247 insertions(+), 19 deletions(-) create mode 100644 src/backend/langflow/interface/importing.py diff --git a/src/backend/langflow/interface/importing.py b/src/backend/langflow/interface/importing.py new file mode 100644 index 000000000..8bf156f3f --- /dev/null +++ b/src/backend/langflow/interface/importing.py @@ -0,0 +1,63 @@ +# This module is used to import any langchain class by name. + +import importlib +from typing import Any + +from langchain import PromptTemplate +from langchain.agents import Agent +from langchain.chains.base import Chain +from langchain.llms.base import BaseLLM +from langchain.tools import BaseTool +from langflow.utils.util import get_tools_dict + + +def import_module(module_path: str) -> Any: + """Import module from module path""" + return importlib.import_module(module_path) + + +def import_by_type(_type: str, name: str) -> Any: + """Import class by type and name""" + func_dict = { + "agents": import_agent, + "prompts": import_prompt, + "llms": import_llm, + "tools": import_tool, + "chains": import_chain, + } + return func_dict[_type](name) + + +def import_class(class_path: str) -> Any: + """Import class from class path""" + module_path, class_name = class_path.rsplit(".", 1) + module = import_module(module_path) + return getattr(module, class_name) + + +def import_prompt(prompt: str) -> PromptTemplate: + """Import prompt from prompt name""" + if prompt == "ZeroShotPrompt": + return import_class("langchain.prompts.PromptTemplate") + return 
import_class(f"langchain.prompts.{prompt}") + + +def import_agent(agent: str) -> Agent: + """Import agent from agent name""" + return import_class(f"langchain.agents.{agent}") + + +def import_llm(llm: str) -> BaseLLM: + """Import llm from llm name""" + return import_class(f"langchain.llms.{llm}") + + +def import_tool(tool: str) -> BaseTool: + """Import tool from tool name""" + + return get_tools_dict(tool) + + +def import_chain(chain: str) -> Chain: + """Import chain from chain name""" + return import_class(f"langchain.chains.{chain}") diff --git a/src/backend/langflow/interface/listing.py b/src/backend/langflow/interface/listing.py index 1df2d25e7..268e6c128 100644 --- a/src/backend/langflow/interface/listing.py +++ b/src/backend/langflow/interface/listing.py @@ -6,16 +6,20 @@ from langchain.agents.load_tools import get_all_tool_names from langchain.chains.conversation import memory as memories -def list_type(object_type: str): - """List all components""" +def get_type_dict(): return { "chains": list_chain_types, "agents": list_agents, "prompts": list_prompts, "llms": list_llms, "tools": list_tools, - "memories": list_memories, - }.get(object_type, lambda: "Invalid type")() + # "memories": list_memories, + } + + +def list_type(object_type: str): + """List all components""" + return get_type_dict().get(object_type, lambda: "Invalid type")() def list_agents(): @@ -48,7 +52,7 @@ def list_tools(): tool_params = util.get_tool_params(util.get_tools_dict(tool)) if tool_params and tool_params["name"] in allowed_components.TOOLS: tools.append(tool_params["name"]) - + tools.append("BaseTool") return tools diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 3c2f0bd4f..59cda2b8a 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -18,6 +18,7 @@ from langchain.agents.load_tools import ( _EXTRA_OPTIONAL_TOOLS, ) from langflow.utils.graph import Graph +from langchain.agents 
import agent as agent_module def load_flow_from_json(path: str): @@ -94,6 +95,19 @@ def load_agent_executor_from_config( ) +def load_agent_executor( + agent_class: agent_module.Agent, allowed_tools, llm_chain, **kwargs +): + """Load agent executor from agent class, tools and chain""" + tool_names = [tool.name for tool in allowed_tools] + agent = agent_class(allowed_tools=tool_names, llm_chain=llm_chain) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=allowed_tools, + **kwargs, + ) + + def load_tools_from_config(tool_list: list[dict]) -> list: """Load tools based on a config list. diff --git a/src/backend/langflow/utils/graph.py b/src/backend/langflow/utils/graph.py index f7df6cc7a..0589ee8f5 100644 --- a/src/backend/langflow/utils/graph.py +++ b/src/backend/langflow/utils/graph.py @@ -1,4 +1,11 @@ from typing import Dict, List, Union +from langflow.interface import listing +from langflow.interface.importing import import_by_type +from langflow.utils import payload, util + +LANGCHAIN_TYPES_DICT = { + k: list_function() for k, list_function in listing.get_type_dict().items() +} class Node: @@ -14,23 +21,141 @@ class Node: # {'type': 'LLMChain', 'node': {'template': {'_type': 'llm_chain', 'memory': {'type': 'BaseMemory', 'required': False, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False, 'value': None}, 'verbose': {'type': 'bool', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': False, 'multiline': False, 'value': False}, 'prompt': {'type': 'BasePromptTemplate', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'llm': {'type': 'BaseLanguageModel', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'output_key': {'type': 'str', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': True, 'multiline': False, 'value': 'text'}}, 'description': 'Chain 
to run queries against LLMs.', 'base_classes': ['Chain']}, 'id': 'dndnode_1', 'value': None} # base_classes are the classes that the node can be cast to self.output = self.data["node"]["base_classes"] - template_dict = { + template_dicts = { key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict) } self.required_inputs = [ - template_dict[key]["type"] - for key, value in template_dict.items() + template_dicts[key]["type"] + for key, value in template_dicts.items() if value["required"] ] self.optional_inputs = [ - template_dict[key]["type"] - for key, value in template_dict.items() + template_dicts[key]["type"] + for key, value in template_dicts.items() if not value["required"] ] + template_dict = self.data["node"]["template"] + self.module_type = ( + self.data["type"] if "Tool" not in self.output else template_dict["_type"] + ) + + def _build_params(self) -> Dict: + # Some params are required, some are optional + # but most importantly, some params are python base classes + # like str and others are LangChain objects like LLMChain, BasePromptTemplate + # so we need to be able to distinguish between the two + + # The dicts with "type" == "str" are the ones that are python base classes + # and most likely have a "value" key + + # So for each key besides "_type" in the template dict, we have a dict + # with a "type" key. 
If the type is not "str", then we need to get the + # edge that connects to that node and get the Node with the required data + # and use that as the value for the param + # If the type is "str", then we need to get the value of the "value" key + # and use that as the value for the param + template_dict = { + key: value + for key, value in self.data["node"]["template"].items() + if isinstance(value, dict) + } + params = {} + for key, value in template_dict.items(): + if key == "_type": + continue + if value["type"] not in ["str", "bool"]: + # Get the edge that connects to this node + edge = next( + ( + edge + for edge in self.edges + if edge.target == self and edge.matched_type in value["type"] + ), + None, + ) + # Get the output of the node that the edge connects to + # if the value['list'] is True, then there will be more + # than one time setting to params[key] + # so we need to append to a list if it exists + # or create a new list if it doesn't + if edge is None and value["required"]: + raise ValueError( + f"Required input {key} for module {self.module_type} is not connected" + ) + if value["list"]: + if key in params: + params[key].append(edge.source) + else: + params[key] = [edge.source] + else: + if not value["required"] and edge is None: + continue + + params[key] = edge.source + else: + if not value["required"] and not value.get("value"): + continue + params[key] = value["value"] + + # Add _type to params + self.params = params + + def build(self): + from langflow.interface.loading import load_agent_executor + + # The params dict is used to build the module + # it contains values and keys that point to nodes which + # have their own params dict + # When build is called, we iterate through the params dict + # and if the value is a node, we call build on that node + # and use the output of that build as the value for the param + # if the value is not a node, then we use the value as the param + # and continue + # Another aspect is that the module_type is 
the class that we need to import + # and instantiate with these built params + + # Build each node in the params dict + for key, value in self.params.items(): + # Check if Node or list of Nodes + if isinstance(value, Node): + self.params[key] = value.build() + + elif isinstance(value, list) and all( + isinstance(node, Node) for node in value + ): + self.params[key] = [node.build() for node in value] # type: ignore + + # Get the class from LANGCHAIN_TYPES_DICT + # and instantiate it with the params + # and return the instance + instance = None + for key, value in LANGCHAIN_TYPES_DICT.items(): + if key == "tools": + value = util.get_tools_dict() + if self.module_type in value: + class_object = import_by_type(_type=key, name=self.module_type) + if key == "agents": + # We need to initialize it differently + allowed_tools = self.params["allowed_tools"] + llm_chain = self.params["llm_chain"] + instance = load_agent_executor( + class_object, allowed_tools, llm_chain + ) + elif key == "tools": + instance = class_object(**self.params) + elif self.module_type == "ZeroShotPrompt": + from langchain.agents import ZeroShotAgent + + instance = ZeroShotAgent.create_prompt(**self.params, tools=[]) + else: + instance = class_object(**self.params) + break + return instance + def add_edge(self, edge: "Edge") -> None: self.edges.append(edge) @@ -52,19 +177,33 @@ class Edge: def validate_edge(self) -> None: # Validate that the outputs of the source node are valid inputs for the target node - self.coming_out = self.source.output - self.going_in = self.target.required_inputs + self.target.optional_inputs + self.source_types = self.source.output + self.target_reqs = self.target.required_inputs + self.target.optional_inputs # Both lists contain strings and sometimes a string contains the value we are looking for - # e.g. comgin_out=["Chain"] and going_in=["LLMChain"] - # so we need to check if any of the strings in coming_out is in going_in + # e.g. 
source_types=["Chain"] and target_reqs=["LLMChain"] + # so we need to check if any of the strings in source_types is in target_reqs self.valid = any( - output in going_in - for output in self.coming_out - for going_in in self.going_in + output in target_req + for output in self.source_types + for target_req in self.target_reqs + ) + # Get what type of input the target node is expecting + + self.matched_type = next( + ( + output + for output in self.source_types + for target_req in self.target_reqs + if output in target_req + ), + None, ) def __repr__(self) -> str: - return f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}, coming_out={self.coming_out}, going_in={self.going_in})" + return ( + f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}" + f", matched_type={self.matched_type})" + ) class Graph: @@ -84,6 +223,9 @@ class Graph: edge.source.add_edge(edge) edge.target.add_edge(edge) + for node in self.nodes: + node._build_params() + def get_node(self, node_id: str) -> Union[None, Node]: return next((node for node in self.nodes if node.id == node_id), None) @@ -93,6 +235,11 @@ class Graph: ] return connected_nodes + def build(self) -> List[Node]: + # Get root node + root_node = payload.get_root_node(self) + return root_node.build() + def get_node_neighbors(self, node: Node) -> Dict[str, int]: neighbors: Dict[Node, int] = {} for edge in self.edges: diff --git a/src/backend/langflow/utils/payload.py b/src/backend/langflow/utils/payload.py index 6ee7152ce..5fe3d8c56 100644 --- a/src/backend/langflow/utils/payload.py +++ b/src/backend/langflow/utils/payload.py @@ -47,7 +47,7 @@ def build_json(root: Node, graph: Graph) -> Dict: else: # Otherwise, find all children whose type matches the type # specified in the template - module_type = root.data["node"]["template"]["_type"] + module_type = root.module_type local_nodes = graph.get_nodes_with_target(root) if len(local_nodes) == 1: