From 83c28dcabe7be9b67de778bc96275877249e08ec Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 19 Jun 2023 11:59:34 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20refactor(process.py):=20refactor?= =?UTF-8?q?=20process=5Ftweaks=20function=20to=20improve=20readability=20a?= =?UTF-8?q?nd=20maintainability=20=E2=9C=A8=20feat(process.py):=20add=20in?= =?UTF-8?q?put=20validation=20to=20process=5Ftweaks=20function=20The=20pro?= =?UTF-8?q?cess=5Ftweaks=20function=20has=20been=20refactored=20to=20impro?= =?UTF-8?q?ve=20readability=20and=20maintainability.=20The=20apply=5Ftweak?= =?UTF-8?q?s=20function=20has=20been=20added=20to=20apply=20the=20tweaks?= =?UTF-8?q?=20to=20the=20node.=20The=20validate=5Finput=20function=20has?= =?UTF-8?q?=20been=20added=20to=20validate=20the=20input=20parameters.=20T?= =?UTF-8?q?he=20process=5Ftweaks=20function=20now=20raises=20a=20ValueErro?= =?UTF-8?q?r=20if=20the=20input=20is=20not=20in=20the=20expected=20format.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/processing/process.py | 73 ++++++++++++++++------ 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 760f73723..e4ae33d27 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -11,7 +11,6 @@ from langflow.interface.run import ( from langflow.utils.logger import logger from langflow.graph import Graph - from typing import Any, Dict, List, Optional, Tuple, Union @@ -179,23 +178,61 @@ def load_flow_from_json( return graph -def process_tweaks(graph_data: Dict, tweaks: Dict): - """This function is used to tweak the graph data using the node id and the tweaks dict""" - # the tweaks dict is a dict of dicts - # the key is the node id and the value is a dict of the tweaks - # the dict of tweaks contains the name of a certain parameter and the value to be tweaked +def validate_input( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> List[Dict[str, Any]]: + if not isinstance(graph_data, dict) or not isinstance(tweaks, dict): + raise ValueError("graph_data and tweaks should be dictionaries") + + nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes") + + if not isinstance(nodes, list): + raise ValueError( + "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key" + ) + + return nodes + + +def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: + template_data = node.get("data", {}).get("node", {}).get("template") + + if not isinstance(template_data, dict): + logger.warning( + f"Template data for node {node.get('id')} should be a dictionary" + ) + return + + for tweak_name, tweak_value in node_tweaks.items(): + if tweak_name and tweak_value and tweak_name in template_data: + template_data[tweak_name]["value"] = tweak_value + + +def process_tweaks( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> Dict[str, Any]: + """ + This function is used to tweak the graph data using the node id and the tweaks dict. + + :param graph_data: The dictionary containing the graph data. It must contain a 'data' key with + 'nodes' as its child or directly contain 'nodes' key. Each node should have an 'id' and 'data'. + :param tweaks: A dictionary where the key is the node id and the value is a dictionary of the tweaks. + The inner dictionary contains the name of a certain parameter as the key and the value to be tweaked. + + :return: The modified graph_data dictionary. + + :raises ValueError: If the input is not in the expected format. + """ + nodes = validate_input(graph_data, tweaks) - # We need to process the graph data to add the tweaks - if "data" not in graph_data and "nodes" in graph_data: - nodes = graph_data["nodes"] - else: - nodes = graph_data["data"]["nodes"] for node in nodes: - node_id = node["id"] - if node_id in tweaks: - node_tweaks = tweaks[node_id] - template_data = node["data"]["node"]["template"] - for tweak_name, tweake_value in node_tweaks.items(): - if tweak_name in template_data: - template_data[tweak_name]["value"] = tweake_value + if isinstance(node, dict) and isinstance(node.get("id"), str): + node_id = node["id"] + if node_tweaks := tweaks.get(node_id): + apply_tweaks(node, node_tweaks) + else: + logger.warning( + "Each node should be a dictionary with an 'id' key of type str" + ) + return graph_data