diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index ad4f8fb78..69e47b242 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -7,6 +7,9 @@ from langchain.schema import AgentAction, Document from langchain_community.vectorstores import VectorStore from langchain_core.messages import AIMessage from langchain_core.runnables.base import Runnable +from loguru import logger +from pydantic import BaseModel + from langflow.graph.graph.base import Graph from langflow.graph.vertex.base import Vertex from langflow.interface.custom.custom_component import CustomComponent @@ -17,8 +20,6 @@ from langflow.interface.run import ( ) from langflow.services.deps import get_session_service from langflow.services.session.service import SessionService -from loguru import logger -from pydantic import BaseModel def fix_memory_inputs(langchain_object): @@ -146,7 +147,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"): result = await runnable.ainvoke(inputs) else: - raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}") + raise ValueError( + f"Runnable {runnable} does not support inputs of type {type(inputs)}" + ) # Check if the result is a list of AIMessages if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result): result = [r.content for r in result] @@ -155,7 +158,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): return result -async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict): +async def process_inputs_dict( + built_object: Union[Chain, VectorStore, Runnable], inputs: dict +): if isinstance(built_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -190,7 +195,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]): return await process_runnable(built_object, inputs) -async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]): +async def generate_result( + built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]] +): if isinstance(inputs, dict): result = await process_inputs_dict(built_object, inputs) elif isinstance(inputs, List) and isinstance(built_object, Runnable): @@ -222,7 +229,9 @@ async def process_graph_cached( if clear_cache: session_service.clear_session(session_id) if session_id is None: - session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph) + session_id = session_service.generate_key( + session_id=session_id, data_graph=data_graph + ) # Load the graph using SessionService session = await session_service.load_session(session_id, data_graph) graph, artifacts = session if session else (None, None) @@ -258,14 +267,34 @@ async def build_graph_and_generate_result( return Result(result=result, session_id=session_id) -def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]: +async def run_graph( + graph: Union["Graph", dict], + session_id: str, + inputs: Optional[Union[dict, List[dict]]] = None, + artifacts: Optional[Dict[str, Any]] = None, + session_service: Optional[SessionService] = None, +): + """Run the graph and generate the result""" + if isinstance(graph, dict): + graph = Graph.from_payload(graph) + outputs = await graph.run(inputs) + if session_id and session_service: + session_service.update_session(session_id, (graph, artifacts)) + return outputs + + +def validate_input( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> List[Dict[str, Any]]: if not isinstance(graph_data, dict) or not isinstance(tweaks, dict): raise ValueError("graph_data and tweaks should be dictionaries") nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes") if not isinstance(nodes, list): - raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key") + raise ValueError( + "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key" + ) return nodes @@ -274,7 +303,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: template_data = node.get("data", {}).get("node", {}).get("template") if not isinstance(template_data, dict): - logger.warning(f"Template data for node {node.get('id')} should be a dictionary") + logger.warning( + f"Template data for node {node.get('id')} should be a dictionary" + ) return for tweak_name, tweak_value in node_tweaks.items(): @@ -289,7 +320,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None: vertex.params[tweak_name] = tweak_value -def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: +def process_tweaks( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> Dict[str, Any]: """ This function is used to tweak the graph data using the node id and the tweaks dict. @@ -310,7 +343,9 @@ def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] if node_tweaks := tweaks.get(node_id): apply_tweaks(node, node_tweaks) else: - logger.warning("Each node should be a dictionary with an 'id' key of type str") + logger.warning( + "Each node should be a dictionary with an 'id' key of type str" + ) return graph_data @@ -322,6 +357,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]): if node_tweaks := tweaks.get(node_id): apply_tweaks_on_vertex(vertex, node_tweaks) else: - logger.warning("Each node should be a Vertex with an 'id' attribute of type str") + logger.warning( + "Each node should be a Vertex with an 'id' attribute of type str" + ) return graph