From c6045809f2a16cdf6cee5c1242b0a5adae62d3a2 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 4 Mar 2024 18:52:50 -0300 Subject: [PATCH] Update vertex raw parameters and add session ID --- src/backend/langflow/graph/graph/base.py | 14 ++++++++++---- src/backend/langflow/graph/vertex/base.py | 6 +++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index a0b91b440..97246aee1 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -151,14 +151,20 @@ class Graph: getattr(self, f"_{attribute}_vertices").append(vertex.id) async def _run( - self, inputs: Dict[str, str], stream: bool + self, inputs: Dict[str, str], stream: bool, session_id: str ) -> List[Optional["ResultData"]]: """Runs the graph with the given inputs.""" for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") - vertex.update_raw_params(inputs) + vertex.update_raw_params(inputs, overwrite=True) + # Update all the vertices with the session_id + for vertex_id in self._has_session_id_vertices: + vertex = self.get_vertex(vertex_id) + if vertex is None: + raise ValueError(f"Vertex {vertex_id} not found") + vertex.update_raw_params({"session_id": session_id}) try: await self.process() self.increment_run_count() @@ -181,7 +187,7 @@ class Graph: return outputs async def run( - self, inputs: Dict[str, Union[str, list[str]]], stream: bool + self, inputs: Dict[str, Union[str, list[str]]], stream: bool, session_id: str ) -> List[Optional["ResultData"]]: """Runs the graph with the given inputs.""" @@ -195,7 +201,7 @@ class Graph: inputs_values = [inputs_values] for input_value in inputs_values: run_outputs = await self._run( - {INPUT_FIELD_NAME: input_value}, stream=stream + {INPUT_FIELD_NAME: input_value}, stream=stream, session_id=session_id ) logger.debug(f"Run outputs: {run_outputs}") outputs.extend(run_outputs) diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 635b6f7f3..f4e4823be 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -383,7 +383,7 @@ class Vertex: self.params = params self._raw_params = params.copy() - def update_raw_params(self, new_params: Dict[str, str]): + def update_raw_params(self, new_params: Dict[str, str], overwrite: bool = False): """ Update the raw parameters of the vertex with the given new parameters. @@ -398,6 +398,10 @@ class Vertex: return if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params): return + if not overwrite: + for key in new_params.copy(): + if key not in self._raw_params: + new_params.pop(key) self._raw_params.update(new_params) self.updated_raw_params = True