diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 760f73723..e4ae33d27 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -11,7 +11,6 @@ from langflow.interface.run import ( from langflow.utils.logger import logger from langflow.graph import Graph - from typing import Any, Dict, List, Optional, Tuple, Union @@ -179,23 +178,61 @@ def load_flow_from_json( return graph -def process_tweaks(graph_data: Dict, tweaks: Dict): - """This function is used to tweak the graph data using the node id and the tweaks dict""" - # the tweaks dict is a dict of dicts - # the key is the node id and the value is a dict of the tweaks - # the dict of tweaks contains the name of a certain parameter and the value to be tweaked +def validate_input( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> List[Dict[str, Any]]: + if not isinstance(graph_data, dict) or not isinstance(tweaks, dict): + raise ValueError("graph_data and tweaks should be dictionaries") + + nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes") + + if not isinstance(nodes, list): + raise ValueError( + "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key" + ) + + return nodes + + +def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: + template_data = node.get("data", {}).get("node", {}).get("template") + + if not isinstance(template_data, dict): + logger.warning( + f"Template data for node {node.get('id')} should be a dictionary" + ) + return + + for tweak_name, tweak_value in node_tweaks.items(): + if tweak_name and tweak_value and tweak_name in template_data: + template_data[tweak_name]["value"] = tweak_value + + +def process_tweaks( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> Dict[str, Any]: + """ + This function is used to tweak the graph data using the node id and the tweaks dict. + + :param graph_data: The dictionary containing the graph data. It must contain a 'data' key with + 'nodes' as its child or directly contain 'nodes' key. Each node should have an 'id' and 'data'. + :param tweaks: A dictionary where the key is the node id and the value is a dictionary of the tweaks. + The inner dictionary contains the name of a certain parameter as the key and the value to be tweaked. + + :return: The modified graph_data dictionary. + + :raises ValueError: If the input is not in the expected format. + """ + nodes = validate_input(graph_data, tweaks) - # We need to process the graph data to add the tweaks - if "data" not in graph_data and "nodes" in graph_data: - nodes = graph_data["nodes"] - else: - nodes = graph_data["data"]["nodes"] for node in nodes: - node_id = node["id"] - if node_id in tweaks: - node_tweaks = tweaks[node_id] - template_data = node["data"]["node"]["template"] - for tweak_name, tweake_value in node_tweaks.items(): - if tweak_name in template_data: - template_data[tweak_name]["value"] = tweake_value + if isinstance(node, dict) and isinstance(node.get("id"), str): + node_id = node["id"] + if node_tweaks := tweaks.get(node_id): + apply_tweaks(node, node_tweaks) + else: + logger.warning( + "Each node should be a dictionary with an 'id' key of type str" + ) + return graph_data