Add support for specifying outputs in run_flow_with_caching and run_graph
This commit is contained in:
parent
a4f5ff0daf
commit
e1ded9c106
3 changed files with 30 additions and 10 deletions
|
|
@ -57,6 +57,7 @@ async def run_flow_with_caching(
|
|||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
inputs: Optional[InputValueRequest] = None,
|
||||
outputs: Optional[List[str]] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
|
|
@ -69,6 +70,9 @@ async def run_flow_with_caching(
|
|||
else:
|
||||
input_values_dict = {}
|
||||
|
||||
if outputs is None:
|
||||
outputs = []
|
||||
|
||||
if session_id:
|
||||
session_data = await session_service.load_session(
|
||||
session_id, flow_id=flow_id
|
||||
|
|
@ -82,6 +86,7 @@ async def run_flow_with_caching(
|
|||
flow_id=flow_id,
|
||||
session_id=session_id,
|
||||
inputs=input_values_dict,
|
||||
outputs=outputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
|
|
@ -107,6 +112,7 @@ async def run_flow_with_caching(
|
|||
flow_id=flow_id,
|
||||
session_id=session_id,
|
||||
inputs=input_values_dict,
|
||||
outputs=outputs,
|
||||
artifacts={},
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ class Graph:
|
|||
getattr(self, f"_{attribute}_vertices").append(vertex.id)
|
||||
|
||||
async def _run(
|
||||
self, inputs: Dict[str, str], stream: bool, session_id: str
|
||||
self, inputs: Dict[str, str], outputs: list[str], stream: bool, session_id: str
|
||||
) -> List[Optional["ResultData"]]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
for vertex_id in self._is_input_vertices:
|
||||
|
|
@ -171,7 +171,7 @@ class Graph:
|
|||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(f"Error running graph: {exc}") from exc
|
||||
outputs = []
|
||||
vertex_outputs = []
|
||||
for vertex_id in self._is_output_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
|
|
@ -183,11 +183,16 @@ class Graph:
|
|||
and hasattr(vertex, "consume_async_generator")
|
||||
):
|
||||
await vertex.consume_async_generator()
|
||||
outputs.append(vertex.result)
|
||||
return outputs
|
||||
if vertex.display_name in outputs or vertex.id in outputs:
|
||||
vertex_outputs.append(vertex.result)
|
||||
return vertex_outputs
|
||||
|
||||
async def run(
|
||||
self, inputs: Dict[str, Union[str, list[str]]], stream: bool, session_id: str
|
||||
self,
|
||||
inputs: Dict[str, Union[str, list[str]]],
|
||||
outputs: list[str],
|
||||
stream: bool,
|
||||
session_id: str,
|
||||
) -> List[Optional["ResultData"]]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
|
|
@ -195,17 +200,20 @@ class Graph:
|
|||
# we need to go through self.inputs and update the self._raw_params
|
||||
# of the vertices that are inputs
|
||||
# if the value is a list, we need to run multiple times
|
||||
outputs = []
|
||||
vertex_outputs = []
|
||||
inputs_values = inputs.get(INPUT_FIELD_NAME, "")
|
||||
if not isinstance(inputs_values, list):
|
||||
inputs_values = [inputs_values]
|
||||
for input_value in inputs_values:
|
||||
run_outputs = await self._run(
|
||||
{INPUT_FIELD_NAME: input_value}, stream=stream, session_id=session_id
|
||||
{INPUT_FIELD_NAME: input_value},
|
||||
outputs,
|
||||
stream=stream,
|
||||
session_id=session_id,
|
||||
)
|
||||
logger.debug(f"Run outputs: {run_outputs}")
|
||||
outputs.extend(run_outputs)
|
||||
return outputs
|
||||
vertex_outputs.append(run_outputs)
|
||||
return vertex_outputs
|
||||
|
||||
# vertices_layers is a list of lists ordered by the order the vertices
|
||||
# should be built.
|
||||
|
|
|
|||
|
|
@ -204,6 +204,7 @@ async def run_graph(
|
|||
stream: bool,
|
||||
session_id: Optional[str] = None,
|
||||
inputs: Optional[dict[str, Union[List[str], str]]] = None,
|
||||
outputs: Optional[List[str]] = None,
|
||||
artifacts: Optional[Dict[str, Any]] = None,
|
||||
session_service: Optional[SessionService] = None,
|
||||
):
|
||||
|
|
@ -220,7 +221,12 @@ async def run_graph(
|
|||
if inputs is None:
|
||||
inputs = {}
|
||||
|
||||
outputs = await graph.run(inputs, stream=stream, session_id=session_id)
|
||||
outputs = await graph.run(
|
||||
inputs,
|
||||
outputs,
|
||||
stream=stream,
|
||||
session_id=session_id,
|
||||
)
|
||||
if session_id and session_service:
|
||||
session_service.update_session(session_id, (graph, artifacts))
|
||||
return outputs, session_id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue