Update vertex raw parameters and add session ID
This commit is contained in:
parent
16c8d5be0c
commit
c6045809f2
2 changed files with 15 additions and 5 deletions
|
|
@ -151,14 +151,20 @@ class Graph:
|
|||
getattr(self, f"_{attribute}_vertices").append(vertex.id)
|
||||
|
||||
async def _run(
|
||||
self, inputs: Dict[str, str], stream: bool
|
||||
self, inputs: Dict[str, str], stream: bool, session_id: str
|
||||
) -> List[Optional["ResultData"]]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
for vertex_id in self._is_input_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
vertex.update_raw_params(inputs)
|
||||
vertex.update_raw_params(inputs, overwrite=True)
|
||||
# Update all the vertices with the session_id
|
||||
for vertex_id in self._has_session_id_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
vertex.update_raw_params({"session_id": session_id})
|
||||
try:
|
||||
await self.process()
|
||||
self.increment_run_count()
|
||||
|
|
@ -181,7 +187,7 @@ class Graph:
|
|||
return outputs
|
||||
|
||||
async def run(
|
||||
self, inputs: Dict[str, Union[str, list[str]]], stream: bool
|
||||
self, inputs: Dict[str, Union[str, list[str]]], stream: bool, session_id: str
|
||||
) -> List[Optional["ResultData"]]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
|
|
@ -195,7 +201,7 @@ class Graph:
|
|||
inputs_values = [inputs_values]
|
||||
for input_value in inputs_values:
|
||||
run_outputs = await self._run(
|
||||
{INPUT_FIELD_NAME: input_value}, stream=stream
|
||||
{INPUT_FIELD_NAME: input_value}, stream=stream, session_id=session_id
|
||||
)
|
||||
logger.debug(f"Run outputs: {run_outputs}")
|
||||
outputs.extend(run_outputs)
|
||||
|
|
|
|||
|
|
@ -383,7 +383,7 @@ class Vertex:
|
|||
self.params = params
|
||||
self._raw_params = params.copy()
|
||||
|
||||
def update_raw_params(self, new_params: Dict[str, str]):
|
||||
def update_raw_params(self, new_params: Dict[str, str], overwrite: bool = False):
|
||||
"""
|
||||
Update the raw parameters of the vertex with the given new parameters.
|
||||
|
||||
|
|
@ -398,6 +398,10 @@ class Vertex:
|
|||
return
|
||||
if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params):
|
||||
return
|
||||
if not overwrite:
|
||||
for key in new_params.copy():
|
||||
if key not in self._raw_params:
|
||||
new_params.pop(key)
|
||||
self._raw_params.update(new_params)
|
||||
self.updated_raw_params = True
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue