diff --git a/src/backend/langflow/graph/graph/utils.py b/src/backend/langflow/graph/graph/utils.py index de6385ba2..fa03519cc 100644 --- a/src/backend/langflow/graph/graph/utils.py +++ b/src/backend/langflow/graph/graph/utils.py @@ -80,8 +80,11 @@ def update_template(template, g_nodes): Returns: None """ - for key in template.keys(): - field, id_ = template[key]["proxy"] + for _, value in template.items(): + if not value.get("proxy"): + continue + proxy_dict = value["proxy"] + field, id_ = proxy_dict["field"], proxy_dict["id"] node_index = next((i for i, n in enumerate(g_nodes) if n["id"] == id_), -1) if node_index != -1: display_name = None @@ -98,7 +101,7 @@ def update_template(template, g_nodes): "name" ] - g_nodes[node_index]["data"]["node"]["template"][field] = template[key] + g_nodes[node_index]["data"]["node"]["template"][field] = value g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show g_nodes[node_index]["data"]["node"]["template"][field][ "advanced" diff --git a/tests/test_graph.py b/tests/test_graph.py index 8c5f92f13..05a60108f 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -36,8 +36,8 @@ from langflow.graph.graph.utils import ( @pytest.fixture def sample_template(): return { - "field1": {"proxy": ["some_field", "node1"]}, - "field2": {"proxy": ["other_field", "node2"]}, + "field1": {"proxy": {"field": "some_field", "id": "node1"}}, + "field2": {"proxy": {"field": "other_field", "id": "node2"}}, } @@ -358,6 +358,7 @@ def test_ungroup_node(grouped_chat_json_flow): def test_process_flow(grouped_chat_json_flow): grouped_chat_data = json.loads(grouped_chat_json_flow).get("data") + processed_flow = process_flow(grouped_chat_data) assert processed_flow is not None assert isinstance(processed_flow, dict) @@ -365,6 +366,41 @@ def test_process_flow(grouped_chat_json_flow): assert "edges" in processed_flow +def test_process_flow_one_group(one_grouped_chat_json_flow): + grouped_chat_data = json.loads(one_grouped_chat_json_flow).get("data") + # There should be only one node + assert len(grouped_chat_data["nodes"]) == 1 + # Get the node, it should be a group node + group_node = grouped_chat_data["nodes"][0] + node_data = group_node["data"]["node"] + assert node_data.get("flow") is not None + template_data = node_data["template"] + assert any("openai_api_key" in key for key in template_data.keys()) + # Get the openai_api_key dict + openai_api_key = next( + (template_data[key] for key in template_data.keys() if "openai_api_key" in key), + None, + ) + assert openai_api_key is not None + assert openai_api_key["value"] == "test" + + processed_flow = process_flow(grouped_chat_data) + assert processed_flow is not None + assert isinstance(processed_flow, dict) + assert "nodes" in processed_flow + assert "edges" in processed_flow + + # Now get the node that has ChatOpenAI in its id + chat_openai_node = next( + (node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None + ) + assert chat_openai_node is not None + assert ( + chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] + == "test" + ) + + def test_update_template(sample_template, sample_nodes): # Making a deep copy to keep original sample_nodes unchanged nodes_copy = copy.deepcopy(sample_nodes)