🐛 fix(base.py): import process_flow function to fix NameError in Graph class

 feat(utils.py): add process_flow function to recursively process and ungroup nodes in a flow object
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-09-21 14:44:10 -03:00
commit 72f64caa41
2 changed files with 200 additions and 0 deletions

View file

@ -2,6 +2,7 @@ from typing import Dict, Generator, List, Type, Union
from langflow.graph.edge.base import Edge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.utils import process_flow
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import (
FileToolVertex,
@ -39,6 +40,7 @@ class Graph:
"""
if "data" in payload:
payload = payload["data"]
payload = process_flow(payload)
try:
nodes = payload["nodes"]
edges = payload["edges"]

View file

@ -0,0 +1,198 @@
import copy
def find_last_node(data):
"""
This function receives a flow and returns the last node.
"""
nodes, edges = data["nodes"], data["edges"]
return next((n for n in nodes if all(e["source"] != n["id"] for e in edges)), None)
def ungroup_node(group_node_data, base_flow):
template, flow = (
group_node_data["node"]["template"],
group_node_data["node"]["flow"],
)
g_nodes = flow["data"]["nodes"]
g_edges = flow["data"]["edges"]
# Redirect edges to the correct proxy node
updated_edges = get_updated_edges(base_flow, g_nodes, group_node_data["id"])
# Update template values
update_template(template, g_nodes)
nodes = [
n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]
] + g_nodes
edges = (
[
e
for e in base_flow["edges"]
if e["target"] != group_node_data["id"]
and e["source"] != group_node_data["id"]
]
+ g_edges
+ updated_edges
)
base_flow["nodes"] = nodes
base_flow["edges"] = edges
def process_flow(flow_object):
cloned_flow = copy.deepcopy(flow_object)
def process_node(node):
if (
node.get("data")
and node["data"].get("node")
and node["data"]["node"].get("flow")
):
process_flow(node["data"]["node"]["flow"]["data"])
ungroup_node(node["data"], cloned_flow)
for node in cloned_flow["nodes"]:
process_node(node)
return cloned_flow
def update_template(template, g_nodes):
"""
Updates the template of a node in a graph with the given template.
Args:
template (dict): The new template to update the node with.
g_nodes (list): The list of nodes in the graph.
Returns:
None
"""
for key in template.keys():
field, id_ = template[key]["proxy"]
node_index = next((i for i, n in enumerate(g_nodes) if n["id"] == id_), -1)
if node_index != -1:
display_name = None
show = g_nodes[node_index]["data"]["node"]["template"][field]["show"]
advanced = g_nodes[node_index]["data"]["node"]["template"][field][
"advanced"
]
if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]:
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
"display_name"
]
else:
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
"name"
]
g_nodes[node_index]["data"]["node"]["template"][field] = template[key]
g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show
g_nodes[node_index]["data"]["node"]["template"][field][
"advanced"
] = advanced
g_nodes[node_index]["data"]["node"]["template"][field][
"display_name"
] = display_name
def update_target_handle(new_edge, g_nodes, group_node_id):
"""
Updates the target handle of a given edge if it is a proxy node.
Args:
new_edge (dict): The edge to update.
g_nodes (list): The list of nodes in the graph.
group_node_id (str): The ID of the group node.
Returns:
dict: The updated edge.
"""
target_handle = new_edge["data"]["targetHandle"]
if target_handle.get("proxy"):
proxy_id = target_handle["proxy"]["id"]
if node := next((n for n in g_nodes if n["id"] == proxy_id), None):
set_new_target_handle(proxy_id, new_edge, target_handle, node)
return new_edge
def set_new_target_handle(proxy_id, new_edge, target_handle, node):
"""
Sets a new target handle for a given edge.
Args:
proxy_id (str): The ID of the proxy.
new_edge (dict): The new edge to be created.
target_handle (dict): The target handle of the edge.
node (dict): The node containing the edge.
Returns:
None
"""
new_edge["target"] = proxy_id
_type = target_handle.get("type")
if _type is None:
raise KeyError("The 'type' key must be present in target_handle.")
field = target_handle["proxy"]["field"]
new_target_handle = {
"fieldName": field,
"type": _type,
"id": proxy_id,
}
if node["data"]["node"]["flow"]:
new_target_handle["proxy"] = {
"field": node["data"]["node"]["template"][field]["proxy"]["field"],
"id": node["data"]["node"]["template"][field]["proxy"]["id"],
}
if input_types := target_handle.get("inputTypes"):
new_target_handle["inputTypes"] = input_types
new_edge["data"]["targetHandle"] = new_target_handle
def update_source_handle(new_edge, flow_data):
"""
Updates the source handle of a given edge to the last node in the flow data.
Args:
new_edge (dict): The edge to update.
flow_data (dict): The flow data containing the nodes and edges.
Returns:
dict: The updated edge with the new source handle.
"""
last_node = copy.deepcopy(find_last_node(flow_data))
new_edge["source"] = last_node["id"]
new_source_handle = new_edge["data"]["sourceHandle"]
new_source_handle["id"] = last_node["id"]
new_edge["data"]["sourceHandle"] = new_source_handle
return new_edge
def get_updated_edges(base_flow, g_nodes, group_node_id):
"""
Given a base flow, a list of graph nodes and a group node id, returns a list of updated edges.
An updated edge is an edge that has its target or source handle updated based on the group node id.
Args:
base_flow (dict): The base flow containing a list of edges.
g_nodes (list): A list of graph nodes.
group_node_id (str): The id of the group node.
Returns:
list: A list of updated edges.
"""
updated_edges = []
for edge in base_flow["edges"]:
new_edge = copy.deepcopy(edge)
if new_edge["target"] == group_node_id:
new_edge = update_target_handle(new_edge, g_nodes, group_node_id)
if new_edge["source"] == group_node_id:
new_edge = update_source_handle(new_edge, g_nodes)
if new_edge["target"] == group_node_id or new_edge["source"] == group_node_id:
updated_edges.append(new_edge)
return updated_edges