From 4b14aef8a73428046651b05f5400298d8987fbc6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 4 Mar 2024 18:52:59 -0300 Subject: [PATCH] Refactor process.py for readability and maintainability --- src/backend/langflow/processing/process.py | 42 ++++++++++++++++------ 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 408db0a53..e2af56d21 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -126,7 +126,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"): result = await runnable.ainvoke(inputs) else: - raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}") + raise ValueError( + f"Runnable {runnable} does not support inputs of type {type(inputs)}" + ) # Check if the result is a list of AIMessages if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result): result = [r.content for r in result] @@ -135,7 +137,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): return result -async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict): +async def process_inputs_dict( + built_object: Union[Chain, VectorStore, Runnable], inputs: dict +): if isinstance(built_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -170,7 +174,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]): return await process_runnable(built_object, inputs) -async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]): +async def generate_result( + built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]] +): if isinstance(inputs, dict): result = await process_inputs_dict(built_object, inputs) elif isinstance(inputs, List) and isinstance(built_object, Runnable): @@ -208,24 +214,30 @@ async def run_graph( else: graph_data = graph._graph_data if not session_id and session_service is not None: - session_id = session_service.generate_key(session_id=flow_id, data_graph=graph_data) + session_id = session_service.generate_key( + session_id=flow_id, data_graph=graph_data + ) if inputs is None: inputs = {} - outputs = await graph.run(inputs, stream=stream) + outputs = await graph.run(inputs, stream=stream, session_id=session_id) if session_id and session_service: session_service.update_session(session_id, (graph, artifacts)) return outputs, session_id -def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]: +def validate_input( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> List[Dict[str, Any]]: if not isinstance(graph_data, dict) or not isinstance(tweaks, dict): raise ValueError("graph_data and tweaks should be dictionaries") nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes") if not isinstance(nodes, list): - raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key") + raise ValueError( + "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key" + ) return nodes @@ -234,7 +246,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: template_data = node.get("data", {}).get("node", {}).get("template") if not isinstance(template_data, dict): - logger.warning(f"Template data for node {node.get('id')} should be a dictionary") + logger.warning( + f"Template data for node {node.get('id')} should be a dictionary" + ) return for tweak_name, tweak_value in node_tweaks.items(): @@ -249,7 +263,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None: vertex.params[tweak_name] = tweak_value -def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: +def process_tweaks( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> Dict[str, Any]: """ This function is used to tweak the graph data using the node id and the tweaks dict. @@ -270,7 +286,9 @@ def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] if node_tweaks := tweaks.get(node_id): apply_tweaks(node, node_tweaks) else: - logger.warning("Each node should be a dictionary with an 'id' key of type str") + logger.warning( + "Each node should be a dictionary with an 'id' key of type str" + ) return graph_data @@ -282,6 +300,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]): if node_tweaks := tweaks.get(node_id): apply_tweaks_on_vertex(vertex, node_tweaks) else: - logger.warning("Each node should be a Vertex with an 'id' attribute of type str") + logger.warning( + "Each node should be a Vertex with an 'id' attribute of type str" + ) return graph