🐛 fix(base.py): import process_flow function to fix NameError in Graph class
✨ feat(utils.py): add process_flow function to recursively process and ungroup nodes in a flow object
This commit is contained in:
parent
6ec51a0d22
commit
72f64caa41
2 changed files with 200 additions and 0 deletions
|
|
@ -2,6 +2,7 @@ from typing import Dict, Generator, List, Type, Union
|
|||
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.utils import process_flow
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import (
|
||||
FileToolVertex,
|
||||
|
|
@ -39,6 +40,7 @@ class Graph:
|
|||
"""
|
||||
if "data" in payload:
|
||||
payload = payload["data"]
|
||||
payload = process_flow(payload)
|
||||
try:
|
||||
nodes = payload["nodes"]
|
||||
edges = payload["edges"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
import copy
|
||||
|
||||
|
||||
def find_last_node(data):
|
||||
"""
|
||||
This function receives a flow and returns the last node.
|
||||
"""
|
||||
nodes, edges = data["nodes"], data["edges"]
|
||||
return next((n for n in nodes if all(e["source"] != n["id"] for e in edges)), None)
|
||||
|
||||
|
||||
def ungroup_node(group_node_data, base_flow):
|
||||
template, flow = (
|
||||
group_node_data["node"]["template"],
|
||||
group_node_data["node"]["flow"],
|
||||
)
|
||||
g_nodes = flow["data"]["nodes"]
|
||||
g_edges = flow["data"]["edges"]
|
||||
|
||||
# Redirect edges to the correct proxy node
|
||||
updated_edges = get_updated_edges(base_flow, g_nodes, group_node_data["id"])
|
||||
|
||||
# Update template values
|
||||
update_template(template, g_nodes)
|
||||
|
||||
nodes = [
|
||||
n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]
|
||||
] + g_nodes
|
||||
edges = (
|
||||
[
|
||||
e
|
||||
for e in base_flow["edges"]
|
||||
if e["target"] != group_node_data["id"]
|
||||
and e["source"] != group_node_data["id"]
|
||||
]
|
||||
+ g_edges
|
||||
+ updated_edges
|
||||
)
|
||||
|
||||
base_flow["nodes"] = nodes
|
||||
base_flow["edges"] = edges
|
||||
|
||||
|
||||
def process_flow(flow_object):
|
||||
cloned_flow = copy.deepcopy(flow_object)
|
||||
|
||||
def process_node(node):
|
||||
if (
|
||||
node.get("data")
|
||||
and node["data"].get("node")
|
||||
and node["data"]["node"].get("flow")
|
||||
):
|
||||
process_flow(node["data"]["node"]["flow"]["data"])
|
||||
ungroup_node(node["data"], cloned_flow)
|
||||
|
||||
for node in cloned_flow["nodes"]:
|
||||
process_node(node)
|
||||
|
||||
return cloned_flow
|
||||
|
||||
|
||||
def update_template(template, g_nodes):
|
||||
"""
|
||||
Updates the template of a node in a graph with the given template.
|
||||
|
||||
Args:
|
||||
template (dict): The new template to update the node with.
|
||||
g_nodes (list): The list of nodes in the graph.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for key in template.keys():
|
||||
field, id_ = template[key]["proxy"]
|
||||
node_index = next((i for i, n in enumerate(g_nodes) if n["id"] == id_), -1)
|
||||
if node_index != -1:
|
||||
display_name = None
|
||||
show = g_nodes[node_index]["data"]["node"]["template"][field]["show"]
|
||||
advanced = g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"advanced"
|
||||
]
|
||||
if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]:
|
||||
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"display_name"
|
||||
]
|
||||
else:
|
||||
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"name"
|
||||
]
|
||||
|
||||
g_nodes[node_index]["data"]["node"]["template"][field] = template[key]
|
||||
g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show
|
||||
g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"advanced"
|
||||
] = advanced
|
||||
g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"display_name"
|
||||
] = display_name
|
||||
|
||||
|
||||
def update_target_handle(new_edge, g_nodes, group_node_id):
|
||||
"""
|
||||
Updates the target handle of a given edge if it is a proxy node.
|
||||
|
||||
Args:
|
||||
new_edge (dict): The edge to update.
|
||||
g_nodes (list): The list of nodes in the graph.
|
||||
group_node_id (str): The ID of the group node.
|
||||
|
||||
Returns:
|
||||
dict: The updated edge.
|
||||
"""
|
||||
target_handle = new_edge["data"]["targetHandle"]
|
||||
if target_handle.get("proxy"):
|
||||
proxy_id = target_handle["proxy"]["id"]
|
||||
if node := next((n for n in g_nodes if n["id"] == proxy_id), None):
|
||||
set_new_target_handle(proxy_id, new_edge, target_handle, node)
|
||||
return new_edge
|
||||
|
||||
|
||||
def set_new_target_handle(proxy_id, new_edge, target_handle, node):
|
||||
"""
|
||||
Sets a new target handle for a given edge.
|
||||
|
||||
Args:
|
||||
proxy_id (str): The ID of the proxy.
|
||||
new_edge (dict): The new edge to be created.
|
||||
target_handle (dict): The target handle of the edge.
|
||||
node (dict): The node containing the edge.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
new_edge["target"] = proxy_id
|
||||
_type = target_handle.get("type")
|
||||
if _type is None:
|
||||
raise KeyError("The 'type' key must be present in target_handle.")
|
||||
|
||||
field = target_handle["proxy"]["field"]
|
||||
new_target_handle = {
|
||||
"fieldName": field,
|
||||
"type": _type,
|
||||
"id": proxy_id,
|
||||
}
|
||||
if node["data"]["node"]["flow"]:
|
||||
new_target_handle["proxy"] = {
|
||||
"field": node["data"]["node"]["template"][field]["proxy"]["field"],
|
||||
"id": node["data"]["node"]["template"][field]["proxy"]["id"],
|
||||
}
|
||||
if input_types := target_handle.get("inputTypes"):
|
||||
new_target_handle["inputTypes"] = input_types
|
||||
new_edge["data"]["targetHandle"] = new_target_handle
|
||||
|
||||
|
||||
def update_source_handle(new_edge, flow_data):
|
||||
"""
|
||||
Updates the source handle of a given edge to the last node in the flow data.
|
||||
|
||||
Args:
|
||||
new_edge (dict): The edge to update.
|
||||
flow_data (dict): The flow data containing the nodes and edges.
|
||||
|
||||
Returns:
|
||||
dict: The updated edge with the new source handle.
|
||||
"""
|
||||
last_node = copy.deepcopy(find_last_node(flow_data))
|
||||
new_edge["source"] = last_node["id"]
|
||||
new_source_handle = new_edge["data"]["sourceHandle"]
|
||||
new_source_handle["id"] = last_node["id"]
|
||||
new_edge["data"]["sourceHandle"] = new_source_handle
|
||||
return new_edge
|
||||
|
||||
|
||||
def get_updated_edges(base_flow, g_nodes, group_node_id):
|
||||
"""
|
||||
Given a base flow, a list of graph nodes and a group node id, returns a list of updated edges.
|
||||
An updated edge is an edge that has its target or source handle updated based on the group node id.
|
||||
|
||||
Args:
|
||||
base_flow (dict): The base flow containing a list of edges.
|
||||
g_nodes (list): A list of graph nodes.
|
||||
group_node_id (str): The id of the group node.
|
||||
|
||||
Returns:
|
||||
list: A list of updated edges.
|
||||
"""
|
||||
updated_edges = []
|
||||
for edge in base_flow["edges"]:
|
||||
new_edge = copy.deepcopy(edge)
|
||||
if new_edge["target"] == group_node_id:
|
||||
new_edge = update_target_handle(new_edge, g_nodes, group_node_id)
|
||||
|
||||
if new_edge["source"] == group_node_id:
|
||||
new_edge = update_source_handle(new_edge, g_nodes)
|
||||
|
||||
if new_edge["target"] == group_node_id or new_edge["source"] == group_node_id:
|
||||
updated_edges.append(new_edge)
|
||||
return updated_edges
|
||||
Loading…
Add table
Add a link
Reference in a new issue