refac: implemented buid_json using graph methods

This commit is contained in:
Gabriel Almeida 2023-03-24 12:22:27 -03:00
commit 230f0d95e9
3 changed files with 28 additions and 15 deletions

View file

@ -53,13 +53,10 @@ class Graph:
def get_node(self, node_id: str) -> Union[None, Node]:
return next((node for node in self.nodes if node.id == node_id), None)
def get_connected_nodes(self, node_id: str) -> List[Node]:
connected_nodes: List[Node] = []
for edge in self.edges:
if edge.source.id == node_id:
connected_nodes.append(edge.target)
elif edge.target.id == node_id:
connected_nodes.append(edge.source)
def get_nodes_with_target(self, node: Node) -> List[Node]:
connected_nodes: List[Node] = [
edge.source for edge in self.edges if edge.target == node
]
return connected_nodes
def get_node_neighbors(self, node: Node) -> Dict[str, int]:

View file

@ -1,5 +1,8 @@
import contextlib
import re
from typing import Dict
from langflow.utils.graph import Graph, Node
def extract_input_variables(nodes):
@ -35,14 +38,23 @@ def get_root_node(graph):
return next((node for node in graph.nodes if node not in incoming_edges), None)
def build_json(root, graph):
edge_ids = [edge.source for edge in graph.edges if edge.target == root]
local_nodes = [node for node in graph.nodes if node in edge_ids]
def build_json(root: Node, graph: Graph) -> Dict:
if "node" not in root.data:
return build_json(local_nodes[0], graph)
# If the root node has no "node" key, then it has only one child,
# which is the target of the single outgoing edge
edge = root.edges[0]
local_nodes = [edge.target]
else:
# Otherwise, find all children whose type matches the type
# specified in the template
module_type = root.data["node"]["template"]["_type"]
local_nodes = graph.get_nodes_with_target(root)
final_dict = root.data["node"]["template"].copy()
if len(local_nodes) == 1:
return build_json(local_nodes[0], graph)
# Build a dictionary from the template
template = root.data["node"]["template"]
final_dict = template.copy()
for key, value in final_dict.items():
if key == "_type":
@ -51,10 +63,13 @@ def build_json(root, graph):
module_type = value["type"]
if "value" in value and value["value"] is not None:
# If the value is specified, use it
value = value["value"]
elif "dict" in module_type:
# If the value is a dictionary, create an empty dictionary
value = {}
else:
# Otherwise, recursively build the child nodes
children = []
for local_node in local_nodes:
module_types = [local_node.data["type"]]
@ -68,4 +83,5 @@ def build_json(root, graph):
values = [build_json(child, graph) for child in children]
value = list(values) if value["list"] else next(iter(values), None)
final_dict[key] = value
return final_dict

View file

@ -17,14 +17,14 @@ def get_graph(basic=True):
return Graph(nodes, edges)
def test_get_connected_nodes():
def test_get_nodes_with_target():
"""Test getting connected nodes"""
graph = get_graph()
assert isinstance(graph, Graph)
# Get root node
root = get_root_node(graph)
assert root is not None
connected_nodes = graph.get_connected_nodes(root)
connected_nodes = graph.get_nodes_with_target(root)
assert connected_nodes is not None