refac: implemented buid_json using graph methods
This commit is contained in:
parent
b4963572b0
commit
230f0d95e9
3 changed files with 28 additions and 15 deletions
|
|
@ -53,13 +53,10 @@ class Graph:
|
|||
def get_node(self, node_id: str) -> Union[None, Node]:
|
||||
return next((node for node in self.nodes if node.id == node_id), None)
|
||||
|
||||
def get_connected_nodes(self, node_id: str) -> List[Node]:
|
||||
connected_nodes: List[Node] = []
|
||||
for edge in self.edges:
|
||||
if edge.source.id == node_id:
|
||||
connected_nodes.append(edge.target)
|
||||
elif edge.target.id == node_id:
|
||||
connected_nodes.append(edge.source)
|
||||
def get_nodes_with_target(self, node: Node) -> List[Node]:
|
||||
connected_nodes: List[Node] = [
|
||||
edge.source for edge in self.edges if edge.target == node
|
||||
]
|
||||
return connected_nodes
|
||||
|
||||
def get_node_neighbors(self, node: Node) -> Dict[str, int]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
import contextlib
|
||||
import re
|
||||
from typing import Dict
|
||||
|
||||
from langflow.utils.graph import Graph, Node
|
||||
|
||||
|
||||
def extract_input_variables(nodes):
|
||||
|
|
@ -35,14 +38,23 @@ def get_root_node(graph):
|
|||
return next((node for node in graph.nodes if node not in incoming_edges), None)
|
||||
|
||||
|
||||
def build_json(root, graph):
|
||||
edge_ids = [edge.source for edge in graph.edges if edge.target == root]
|
||||
local_nodes = [node for node in graph.nodes if node in edge_ids]
|
||||
|
||||
def build_json(root: Node, graph: Graph) -> Dict:
|
||||
if "node" not in root.data:
|
||||
return build_json(local_nodes[0], graph)
|
||||
# If the root node has no "node" key, then it has only one child,
|
||||
# which is the target of the single outgoing edge
|
||||
edge = root.edges[0]
|
||||
local_nodes = [edge.target]
|
||||
else:
|
||||
# Otherwise, find all children whose type matches the type
|
||||
# specified in the template
|
||||
module_type = root.data["node"]["template"]["_type"]
|
||||
local_nodes = graph.get_nodes_with_target(root)
|
||||
|
||||
final_dict = root.data["node"]["template"].copy()
|
||||
if len(local_nodes) == 1:
|
||||
return build_json(local_nodes[0], graph)
|
||||
# Build a dictionary from the template
|
||||
template = root.data["node"]["template"]
|
||||
final_dict = template.copy()
|
||||
|
||||
for key, value in final_dict.items():
|
||||
if key == "_type":
|
||||
|
|
@ -51,10 +63,13 @@ def build_json(root, graph):
|
|||
module_type = value["type"]
|
||||
|
||||
if "value" in value and value["value"] is not None:
|
||||
# If the value is specified, use it
|
||||
value = value["value"]
|
||||
elif "dict" in module_type:
|
||||
# If the value is a dictionary, create an empty dictionary
|
||||
value = {}
|
||||
else:
|
||||
# Otherwise, recursively build the child nodes
|
||||
children = []
|
||||
for local_node in local_nodes:
|
||||
module_types = [local_node.data["type"]]
|
||||
|
|
@ -68,4 +83,5 @@ def build_json(root, graph):
|
|||
values = [build_json(child, graph) for child in children]
|
||||
value = list(values) if value["list"] else next(iter(values), None)
|
||||
final_dict[key] = value
|
||||
|
||||
return final_dict
|
||||
|
|
|
|||
|
|
@ -17,14 +17,14 @@ def get_graph(basic=True):
|
|||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
def test_get_connected_nodes():
|
||||
def test_get_nodes_with_target():
|
||||
"""Test getting connected nodes"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
assert root is not None
|
||||
connected_nodes = graph.get_connected_nodes(root)
|
||||
connected_nodes = graph.get_nodes_with_target(root)
|
||||
assert connected_nodes is not None
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue