Add support for specifying outputs in run_flow_with_caching and run_graph

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-05 12:24:49 -03:00
commit e1ded9c106
3 changed files with 30 additions and 10 deletions

View file

@@ -57,6 +57,7 @@ async def run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[InputValueRequest] = None,
outputs: Optional[List[str]] = None,
tweaks: Optional[dict] = None,
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
@@ -69,6 +70,9 @@ async def run_flow_with_caching(
else:
input_values_dict = {}
if outputs is None:
outputs = []
if session_id:
session_data = await session_service.load_session(
session_id, flow_id=flow_id
@@ -82,6 +86,7 @@ async def run_flow_with_caching(
flow_id=flow_id,
session_id=session_id,
inputs=input_values_dict,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
@@ -107,6 +112,7 @@ async def run_flow_with_caching(
flow_id=flow_id,
session_id=session_id,
inputs=input_values_dict,
outputs=outputs,
artifacts={},
session_service=session_service,
stream=stream,

View file

@@ -151,7 +151,7 @@ class Graph:
getattr(self, f"_{attribute}_vertices").append(vertex.id)
async def _run(
self, inputs: Dict[str, str], stream: bool, session_id: str
self, inputs: Dict[str, str], outputs: list[str], stream: bool, session_id: str
) -> List[Optional["ResultData"]]:
"""Runs the graph with the given inputs."""
for vertex_id in self._is_input_vertices:
@@ -171,7 +171,7 @@
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error running graph: {exc}") from exc
outputs = []
vertex_outputs = []
for vertex_id in self._is_output_vertices:
vertex = self.get_vertex(vertex_id)
if vertex is None:
@@ -183,11 +183,16 @@
and hasattr(vertex, "consume_async_generator")
):
await vertex.consume_async_generator()
outputs.append(vertex.result)
return outputs
if vertex.display_name in outputs or vertex.id in outputs:
vertex_outputs.append(vertex.result)
return vertex_outputs
async def run(
self, inputs: Dict[str, Union[str, list[str]]], stream: bool, session_id: str
self,
inputs: Dict[str, Union[str, list[str]]],
outputs: list[str],
stream: bool,
session_id: str,
) -> List[Optional["ResultData"]]:
"""Runs the graph with the given inputs."""
@@ -195,17 +200,20 @@
# we need to go through self.inputs and update the self._raw_params
# of the vertices that are inputs
# if the value is a list, we need to run multiple times
outputs = []
vertex_outputs = []
inputs_values = inputs.get(INPUT_FIELD_NAME, "")
if not isinstance(inputs_values, list):
inputs_values = [inputs_values]
for input_value in inputs_values:
run_outputs = await self._run(
{INPUT_FIELD_NAME: input_value}, stream=stream, session_id=session_id
{INPUT_FIELD_NAME: input_value},
outputs,
stream=stream,
session_id=session_id,
)
logger.debug(f"Run outputs: {run_outputs}")
outputs.extend(run_outputs)
return outputs
vertex_outputs.append(run_outputs)
return vertex_outputs
# vertices_layers is a list of lists ordered by the order the vertices
# should be built.

View file

@@ -204,6 +204,7 @@ async def run_graph(
stream: bool,
session_id: Optional[str] = None,
inputs: Optional[dict[str, Union[List[str], str]]] = None,
outputs: Optional[List[str]] = None,
artifacts: Optional[Dict[str, Any]] = None,
session_service: Optional[SessionService] = None,
):
@@ -220,7 +221,12 @@ async def run_graph(
if inputs is None:
inputs = {}
outputs = await graph.run(inputs, stream=stream, session_id=session_id)
outputs = await graph.run(
inputs,
outputs,
stream=stream,
session_id=session_id,
)
if session_id and session_service:
session_service.update_session(session_id, (graph, artifacts))
return outputs, session_id