feat: Graph build tests passing
This commit is contained in:
parent
8ccc22086b
commit
cdb01f2a39
5 changed files with 247 additions and 19 deletions
63
src/backend/langflow/interface/importing.py
Normal file
63
src/backend/langflow/interface/importing.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
# This module is used to import any langchain class by name.
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.agents import Agent
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.tools import BaseTool
|
||||
from langflow.utils.util import get_tools_dict
|
||||
|
||||
|
||||
def import_module(module_path: str) -> Any:
|
||||
"""Import module from module path"""
|
||||
return importlib.import_module(module_path)
|
||||
|
||||
|
||||
def import_by_type(_type: str, name: str) -> Any:
|
||||
"""Import class by type and name"""
|
||||
func_dict = {
|
||||
"agents": import_agent,
|
||||
"prompts": import_prompt,
|
||||
"llms": import_llm,
|
||||
"tools": import_tool,
|
||||
"chains": import_chain,
|
||||
}
|
||||
return func_dict[_type](name)
|
||||
|
||||
|
||||
def import_class(class_path: str) -> Any:
|
||||
"""Import class from class path"""
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def import_prompt(prompt: str) -> PromptTemplate:
|
||||
"""Import prompt from prompt name"""
|
||||
if prompt == "ZeroShotPrompt":
|
||||
return import_class("langchain.prompts.PromptTemplate")
|
||||
return import_class(f"langchain.prompts.{prompt}")
|
||||
|
||||
|
||||
def import_agent(agent: str) -> Agent:
|
||||
"""Import agent from agent name"""
|
||||
return import_class(f"langchain.agents.{agent}")
|
||||
|
||||
|
||||
def import_llm(llm: str) -> BaseLLM:
|
||||
"""Import llm from llm name"""
|
||||
return import_class(f"langchain.llms.{llm}")
|
||||
|
||||
|
||||
def import_tool(tool: str) -> BaseTool:
|
||||
"""Import tool from tool name"""
|
||||
|
||||
return get_tools_dict(tool)
|
||||
|
||||
|
||||
def import_chain(chain: str) -> Chain:
|
||||
"""Import chain from chain name"""
|
||||
return import_class(f"langchain.chains.{chain}")
|
||||
|
|
@ -6,16 +6,20 @@ from langchain.agents.load_tools import get_all_tool_names
|
|||
from langchain.chains.conversation import memory as memories
|
||||
|
||||
|
||||
def list_type(object_type: str):
|
||||
"""List all components"""
|
||||
def get_type_dict():
|
||||
return {
|
||||
"chains": list_chain_types,
|
||||
"agents": list_agents,
|
||||
"prompts": list_prompts,
|
||||
"llms": list_llms,
|
||||
"tools": list_tools,
|
||||
"memories": list_memories,
|
||||
}.get(object_type, lambda: "Invalid type")()
|
||||
# "memories": list_memories,
|
||||
}
|
||||
|
||||
|
||||
def list_type(object_type: str):
|
||||
"""List all components"""
|
||||
return get_type_dict().get(object_type, lambda: "Invalid type")()
|
||||
|
||||
|
||||
def list_agents():
|
||||
|
|
@ -48,7 +52,7 @@ def list_tools():
|
|||
tool_params = util.get_tool_params(util.get_tools_dict(tool))
|
||||
if tool_params and tool_params["name"] in allowed_components.TOOLS:
|
||||
tools.append(tool_params["name"])
|
||||
|
||||
tools.append("BaseTool")
|
||||
return tools
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from langchain.agents.load_tools import (
|
|||
_EXTRA_OPTIONAL_TOOLS,
|
||||
)
|
||||
from langflow.utils.graph import Graph
|
||||
from langchain.agents import agent as agent_module
|
||||
|
||||
|
||||
def load_flow_from_json(path: str):
|
||||
|
|
@ -94,6 +95,19 @@ def load_agent_executor_from_config(
|
|||
)
|
||||
|
||||
|
||||
def load_agent_executor(
|
||||
agent_class: agent_module.Agent, allowed_tools, llm_chain, **kwargs
|
||||
):
|
||||
"""Load agent executor from agent class, tools and chain"""
|
||||
tool_names = [tool.name for tool in allowed_tools]
|
||||
agent = agent_class(allowed_tools=tool_names, llm_chain=llm_chain)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=allowed_tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_tools_from_config(tool_list: list[dict]) -> list:
|
||||
"""Load tools based on a config list.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from typing import Dict, List, Union
|
||||
from langflow.interface import listing
|
||||
from langflow.interface.importing import import_by_type
|
||||
from langflow.utils import payload, util
|
||||
|
||||
LANGCHAIN_TYPES_DICT = {
|
||||
k: list_function() for k, list_function in listing.get_type_dict().items()
|
||||
}
|
||||
|
||||
|
||||
class Node:
|
||||
|
|
@ -14,23 +21,141 @@ class Node:
|
|||
# {'type': 'LLMChain', 'node': {'template': {'_type': 'llm_chain', 'memory': {'type': 'BaseMemory', 'required': False, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False, 'value': None}, 'verbose': {'type': 'bool', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': False, 'multiline': False, 'value': False}, 'prompt': {'type': 'BasePromptTemplate', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'llm': {'type': 'BaseLanguageModel', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'output_key': {'type': 'str', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': True, 'multiline': False, 'value': 'text'}}, 'description': 'Chain to run queries against LLMs.', 'base_classes': ['Chain']}, 'id': 'dndnode_1', 'value': None}
|
||||
# base_classes are the classes that the node can be cast to
|
||||
self.output = self.data["node"]["base_classes"]
|
||||
template_dict = {
|
||||
template_dicts = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
|
||||
self.required_inputs = [
|
||||
template_dict[key]["type"]
|
||||
for key, value in template_dict.items()
|
||||
template_dicts[key]["type"]
|
||||
for key, value in template_dicts.items()
|
||||
if value["required"]
|
||||
]
|
||||
self.optional_inputs = [
|
||||
template_dict[key]["type"]
|
||||
for key, value in template_dict.items()
|
||||
template_dicts[key]["type"]
|
||||
for key, value in template_dicts.items()
|
||||
if not value["required"]
|
||||
]
|
||||
|
||||
template_dict = self.data["node"]["template"]
|
||||
self.module_type = (
|
||||
self.data["type"] if "Tool" not in self.output else template_dict["_type"]
|
||||
)
|
||||
|
||||
def _build_params(self) -> Dict:
|
||||
# Some params are required, some are optional
|
||||
# but most importantly, some params are python base classes
|
||||
# like str and others are LangChain objects like LLMChain, BasePromptTemplate
|
||||
# so we need to be able to distinguish between the two
|
||||
|
||||
# The dicts with "type" == "str" are the ones that are python base classes
|
||||
# and most likely have a "value" key
|
||||
|
||||
# So for each key besides "_type" in the template dict, we have a dict
|
||||
# with a "type" key. If the type is not "str", then we need to get the
|
||||
# edge that connects to that node and get the Node with the required data
|
||||
# and use that as the value for the param
|
||||
# If the type is "str", then we need to get the value of the "value" key
|
||||
# and use that as the value for the param
|
||||
template_dict = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
params = {}
|
||||
for key, value in template_dict.items():
|
||||
if key == "_type":
|
||||
continue
|
||||
if value["type"] not in ["str", "bool"]:
|
||||
# Get the edge that connects to this node
|
||||
edge = next(
|
||||
(
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.target == self and edge.matched_type in value["type"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Get the output of the node that the edge connects to
|
||||
# if the value['list'] is True, then there will be more
|
||||
# than one time setting to params[key]
|
||||
# so we need to append to a list if it exists
|
||||
# or create a new list if it doesn't
|
||||
if edge is None and value["required"]:
|
||||
raise ValueError(
|
||||
f"Required input {key} for module {self.module_type} is not connected"
|
||||
)
|
||||
if value["list"]:
|
||||
if key in params:
|
||||
params[key].append(edge.source)
|
||||
else:
|
||||
params[key] = [edge.source]
|
||||
else:
|
||||
if not value["required"] and edge is None:
|
||||
continue
|
||||
|
||||
params[key] = edge.source
|
||||
else:
|
||||
if not value["required"] and not value.get("value"):
|
||||
continue
|
||||
params[key] = value["value"]
|
||||
|
||||
# Add _type to params
|
||||
self.params = params
|
||||
|
||||
def build(self):
|
||||
from langflow.interface.loading import load_agent_executor
|
||||
|
||||
# The params dict is used to build the module
|
||||
# it contains values and keys that point to nodes which
|
||||
# have their own params dict
|
||||
# When build is called, we iterate through the params dict
|
||||
# and if the value is a node, we call build on that node
|
||||
# and use the output of that build as the value for the param
|
||||
# if the value is not a node, then we use the value as the param
|
||||
# and continue
|
||||
# Another aspect is that the module_type is the class that we need to import
|
||||
# and instantiate with these built params
|
||||
|
||||
# Build each node in the params dict
|
||||
for key, value in self.params.items():
|
||||
# Check if Node or list of Nodes
|
||||
if isinstance(value, Node):
|
||||
self.params[key] = value.build()
|
||||
|
||||
elif isinstance(value, list) and all(
|
||||
isinstance(node, Node) for node in value
|
||||
):
|
||||
self.params[key] = [node.build() for node in value] # type: ignore
|
||||
|
||||
# Get the class from LANGCHAIN_TYPES_DICT
|
||||
# and instantiate it with the params
|
||||
# and return the instance
|
||||
instance = None
|
||||
for key, value in LANGCHAIN_TYPES_DICT.items():
|
||||
if key == "tools":
|
||||
value = util.get_tools_dict()
|
||||
if self.module_type in value:
|
||||
class_object = import_by_type(_type=key, name=self.module_type)
|
||||
if key == "agents":
|
||||
# We need to initialize it differently
|
||||
allowed_tools = self.params["allowed_tools"]
|
||||
llm_chain = self.params["llm_chain"]
|
||||
instance = load_agent_executor(
|
||||
class_object, allowed_tools, llm_chain
|
||||
)
|
||||
elif key == "tools":
|
||||
instance = class_object(**self.params)
|
||||
elif self.module_type == "ZeroShotPrompt":
|
||||
from langchain.agents import ZeroShotAgent
|
||||
|
||||
instance = ZeroShotAgent.create_prompt(**self.params, tools=[])
|
||||
else:
|
||||
instance = class_object(**self.params)
|
||||
break
|
||||
return instance
|
||||
|
||||
def add_edge(self, edge: "Edge") -> None:
|
||||
self.edges.append(edge)
|
||||
|
||||
|
|
@ -52,19 +177,33 @@ class Edge:
|
|||
|
||||
def validate_edge(self) -> None:
|
||||
# Validate that the outputs of the source node are valid inputs for the target node
|
||||
self.coming_out = self.source.output
|
||||
self.going_in = self.target.required_inputs + self.target.optional_inputs
|
||||
self.source_types = self.source.output
|
||||
self.target_reqs = self.target.required_inputs + self.target.optional_inputs
|
||||
# Both lists contain strings and sometimes a string contains the value we are looking for
|
||||
# e.g. comgin_out=["Chain"] and going_in=["LLMChain"]
|
||||
# so we need to check if any of the strings in coming_out is in going_in
|
||||
# e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
|
||||
# so we need to check if any of the strings in source_types is in target_reqs
|
||||
self.valid = any(
|
||||
output in going_in
|
||||
for output in self.coming_out
|
||||
for going_in in self.going_in
|
||||
output in target_req
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
)
|
||||
# Get what type of input the target node is expecting
|
||||
|
||||
self.matched_type = next(
|
||||
(
|
||||
output
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
if output in target_req
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}, coming_out={self.coming_out}, going_in={self.going_in})"
|
||||
return (
|
||||
f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}"
|
||||
f", matched_type={self.matched_type})"
|
||||
)
|
||||
|
||||
|
||||
class Graph:
|
||||
|
|
@ -84,6 +223,9 @@ class Graph:
|
|||
edge.source.add_edge(edge)
|
||||
edge.target.add_edge(edge)
|
||||
|
||||
for node in self.nodes:
|
||||
node._build_params()
|
||||
|
||||
def get_node(self, node_id: str) -> Union[None, Node]:
|
||||
return next((node for node in self.nodes if node.id == node_id), None)
|
||||
|
||||
|
|
@ -93,6 +235,11 @@ class Graph:
|
|||
]
|
||||
return connected_nodes
|
||||
|
||||
def build(self) -> List[Node]:
|
||||
# Get root node
|
||||
root_node = payload.get_root_node(self)
|
||||
return root_node.build()
|
||||
|
||||
def get_node_neighbors(self, node: Node) -> Dict[str, int]:
|
||||
neighbors: Dict[Node, int] = {}
|
||||
for edge in self.edges:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def build_json(root: Node, graph: Graph) -> Dict:
|
|||
else:
|
||||
# Otherwise, find all children whose type matches the type
|
||||
# specified in the template
|
||||
module_type = root.data["node"]["template"]["_type"]
|
||||
module_type = root.module_type
|
||||
local_nodes = graph.get_nodes_with_target(root)
|
||||
|
||||
if len(local_nodes) == 1:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue