Update vertex raw parameters and add session ID

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-04 18:52:50 -03:00
commit c6045809f2
2 changed files with 15 additions and 5 deletions

View file

@ -151,14 +151,20 @@ class Graph:
getattr(self, f"_{attribute}_vertices").append(vertex.id)
async def _run(
self, inputs: Dict[str, str], stream: bool
self, inputs: Dict[str, str], stream: bool, session_id: str
) -> List[Optional["ResultData"]]:
"""Runs the graph with the given inputs."""
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
vertex.update_raw_params(inputs)
vertex.update_raw_params(inputs, overwrite=True)
# Update all the vertices with the session_id
for vertex_id in self._has_session_id_vertices:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
vertex.update_raw_params({"session_id": session_id})
try:
await self.process()
self.increment_run_count()
@ -181,7 +187,7 @@ class Graph:
return outputs
async def run(
self, inputs: Dict[str, Union[str, list[str]]], stream: bool
self, inputs: Dict[str, Union[str, list[str]]], stream: bool, session_id: str
) -> List[Optional["ResultData"]]:
"""Runs the graph with the given inputs."""
@ -195,7 +201,7 @@ class Graph:
inputs_values = [inputs_values]
for input_value in inputs_values:
run_outputs = await self._run(
{INPUT_FIELD_NAME: input_value}, stream=stream
{INPUT_FIELD_NAME: input_value}, stream=stream, session_id=session_id
)
logger.debug(f"Run outputs: {run_outputs}")
outputs.extend(run_outputs)

View file

@ -383,7 +383,7 @@ class Vertex:
self.params = params
self._raw_params = params.copy()
def update_raw_params(self, new_params: Dict[str, str]):
def update_raw_params(self, new_params: Dict[str, str], overwrite: bool = False):
"""
Update the raw parameters of the vertex with the given new parameters.
@ -398,6 +398,10 @@ class Vertex:
return
if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params):
return
if not overwrite:
for key in new_params.copy():
if key not in self._raw_params:
new_params.pop(key)
self._raw_params.update(new_params)
self.updated_raw_params = True