From 639c54e3eefa1aae137ee3e89eb9466dccd557ec Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 27 Feb 2024 23:04:21 -0300 Subject: [PATCH] Add stream parameter to run_flow_with_caching and Graph.run methods --- src/backend/langflow/api/v1/endpoints.py | 3 +++ src/backend/langflow/graph/graph/base.py | 12 +++++++++--- src/backend/langflow/graph/vertex/types.py | 4 ++++ src/backend/langflow/processing/process.py | 3 ++- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 8da1f0e53..116c63b2c 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -228,6 +228,7 @@ async def run_flow_with_caching( flow_id: str, inputs: Optional[Union[List[dict], dict]] = None, tweaks: Optional[dict] = None, + stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 api_key_user: User = Depends(api_key_security), session_service: SessionService = Depends(get_session_service), @@ -246,6 +247,7 @@ async def run_flow_with_caching( inputs=inputs, artifacts=artifacts, session_service=session_service, + stream=stream, ) else: @@ -270,6 +272,7 @@ async def run_flow_with_caching( inputs=inputs, artifacts={}, session_service=session_service, + stream=stream, ) return RunResponse(outputs=task_result, session_id=session_id) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 3f5e376a5..051fc6b3b 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -73,7 +73,7 @@ class Graph: if getattr(vertex, attribute): getattr(self, f"_{attribute}_vertices").append(vertex.id) - async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]: + async def _run(self, inputs: Dict[str, str], stream: bool) -> List["ResultData"]: """Runs the graph with the given inputs.""" for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) @@ -91,10 +91,14 @@ class Graph: vertex = self.get_vertex(vertex_id) if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") + if not stream and hasattr(vertex, "consume_async_generator"): + await vertex.consume_async_generator() outputs.append(vertex.result) return outputs - async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]: + async def run( + self, inputs: Dict[str, Union[str, list[str]]], stream: bool + ) -> List["ResultData"]: """Runs the graph with the given inputs.""" # inputs is {"message": "Hello, world!"} @@ -106,7 +110,9 @@ class Graph: if not isinstance(inputs_values, list): inputs_values = [inputs_values] for input_value in inputs_values: - run_outputs = await self._run({INPUT_FIELD_NAME: input_value}) + run_outputs = await self._run( + {INPUT_FIELD_NAME: input_value}, stream=stream + ) logger.debug(f"Run outputs: {run_outputs}") outputs.extend(run_outputs) return outputs diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 0b2be3ab0..721a7ccc8 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -451,6 +451,10 @@ class ChatVertex(StatelessVertex): self._validate_built_object() self._built = True + async def consume_async_generator(self): + async for _ in self.stream(): + pass + class RoutingVertex(StatelessVertex): def __init__(self, data: Dict, graph): diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 0ab6fcadd..d7cf09a6f 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -271,6 +271,7 @@ async def run_graph( graph: Union["Graph", dict], flow_id: str, session_id: str, + stream: bool, inputs: Optional[Union[dict, List[dict]]] = None, artifacts: Optional[Dict[str, Any]] = None, session_service: Optional[SessionService] = None, @@ -286,7 +287,7 @@ async def run_graph( session_id=flow_id, data_graph=graph_data ) - outputs = await graph.run(inputs) + outputs = await graph.run(inputs, stream=stream) if session_id and session_service: session_service.update_session(session_id, (graph, artifacts)) return outputs, session_id