Add stream parameter to run_flow_with_caching and Graph.run methods
This commit is contained in:
parent
6c2a35afb1
commit
639c54e3ee
4 changed files with 18 additions and 4 deletions
|
|
@ -228,6 +228,7 @@ async def run_flow_with_caching(
|
|||
flow_id: str,
|
||||
inputs: Optional[Union[List[dict], dict]] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
|
|
@ -246,6 +247,7 @@ async def run_flow_with_caching(
|
|||
inputs=inputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
@ -270,6 +272,7 @@ async def run_flow_with_caching(
|
|||
inputs=inputs,
|
||||
artifacts={},
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return RunResponse(outputs=task_result, session_id=session_id)
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class Graph:
|
|||
if getattr(vertex, attribute):
|
||||
getattr(self, f"_{attribute}_vertices").append(vertex.id)
|
||||
|
||||
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
|
||||
async def _run(self, inputs: Dict[str, str], stream: bool) -> List["ResultData"]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
for vertex_id in self._is_input_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
|
|
@ -91,10 +91,14 @@ class Graph:
|
|||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
if not stream and hasattr(vertex, "consume_async_generator"):
|
||||
await vertex.consume_async_generator()
|
||||
outputs.append(vertex.result)
|
||||
return outputs
|
||||
|
||||
async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]:
|
||||
async def run(
|
||||
self, inputs: Dict[str, Union[str, list[str]]], stream: bool
|
||||
) -> List["ResultData"]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
# inputs is {"message": "Hello, world!"}
|
||||
|
|
@ -106,7 +110,9 @@ class Graph:
|
|||
if not isinstance(inputs_values, list):
|
||||
inputs_values = [inputs_values]
|
||||
for input_value in inputs_values:
|
||||
run_outputs = await self._run({INPUT_FIELD_NAME: input_value})
|
||||
run_outputs = await self._run(
|
||||
{INPUT_FIELD_NAME: input_value}, stream=stream
|
||||
)
|
||||
logger.debug(f"Run outputs: {run_outputs}")
|
||||
outputs.extend(run_outputs)
|
||||
return outputs
|
||||
|
|
|
|||
|
|
@ -451,6 +451,10 @@ class ChatVertex(StatelessVertex):
|
|||
self._validate_built_object()
|
||||
self._built = True
|
||||
|
||||
async def consume_async_generator(self):
|
||||
async for _ in self.stream():
|
||||
pass
|
||||
|
||||
|
||||
class RoutingVertex(StatelessVertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
|
|
|
|||
|
|
@ -271,6 +271,7 @@ async def run_graph(
|
|||
graph: Union["Graph", dict],
|
||||
flow_id: str,
|
||||
session_id: str,
|
||||
stream: bool,
|
||||
inputs: Optional[Union[dict, List[dict]]] = None,
|
||||
artifacts: Optional[Dict[str, Any]] = None,
|
||||
session_service: Optional[SessionService] = None,
|
||||
|
|
@ -286,7 +287,7 @@ async def run_graph(
|
|||
session_id=flow_id, data_graph=graph_data
|
||||
)
|
||||
|
||||
outputs = await graph.run(inputs)
|
||||
outputs = await graph.run(inputs, stream=stream)
|
||||
if session_id and session_service:
|
||||
session_service.update_session(session_id, (graph, artifacts))
|
||||
return outputs, session_id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue