Add stream parameter to run_flow_with_caching and Graph.run methods

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 23:04:21 -03:00
commit 639c54e3ee
4 changed files with 18 additions and 4 deletions

View file

@ -228,6 +228,7 @@ async def run_flow_with_caching(
flow_id: str,
inputs: Optional[Union[List[dict], dict]] = None,
tweaks: Optional[dict] = None,
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
api_key_user: User = Depends(api_key_security),
session_service: SessionService = Depends(get_session_service),
@ -246,6 +247,7 @@ async def run_flow_with_caching(
inputs=inputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
)
else:
@ -270,6 +272,7 @@ async def run_flow_with_caching(
inputs=inputs,
artifacts={},
session_service=session_service,
stream=stream,
)
return RunResponse(outputs=task_result, session_id=session_id)

View file

@ -73,7 +73,7 @@ class Graph:
if getattr(vertex, attribute):
getattr(self, f"_{attribute}_vertices").append(vertex.id)
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
async def _run(self, inputs: Dict[str, str], stream: bool) -> List["ResultData"]:
"""Runs the graph with the given inputs."""
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
@ -91,10 +91,14 @@ class Graph:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
if not stream and hasattr(vertex, "consume_async_generator"):
await vertex.consume_async_generator()
outputs.append(vertex.result)
return outputs
async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]:
async def run(
self, inputs: Dict[str, Union[str, list[str]]], stream: bool
) -> List["ResultData"]:
"""Runs the graph with the given inputs."""
# inputs is {"message": "Hello, world!"}
@ -106,7 +110,9 @@ class Graph:
if not isinstance(inputs_values, list):
inputs_values = [inputs_values]
for input_value in inputs_values:
run_outputs = await self._run({INPUT_FIELD_NAME: input_value})
run_outputs = await self._run(
{INPUT_FIELD_NAME: input_value}, stream=stream
)
logger.debug(f"Run outputs: {run_outputs}")
outputs.extend(run_outputs)
return outputs

View file

@ -451,6 +451,10 @@ class ChatVertex(StatelessVertex):
self._validate_built_object()
self._built = True
async def consume_async_generator(self):
async for _ in self.stream():
pass
class RoutingVertex(StatelessVertex):
def __init__(self, data: Dict, graph):

View file

@ -271,6 +271,7 @@ async def run_graph(
graph: Union["Graph", dict],
flow_id: str,
session_id: str,
stream: bool,
inputs: Optional[Union[dict, List[dict]]] = None,
artifacts: Optional[Dict[str, Any]] = None,
session_service: Optional[SessionService] = None,
@ -286,7 +287,7 @@ async def run_graph(
session_id=flow_id, data_graph=graph_data
)
outputs = await graph.run(inputs)
outputs = await graph.run(inputs, stream=stream)
if session_id and session_service:
session_service.update_session(session_id, (graph, artifacts))
return outputs, session_id