diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 2f94f150a..68e16bed8 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,5 +1,5 @@ from collections import defaultdict, deque -from typing import Dict, Generator, List, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union from langchain.chains.base import Chain from loguru import logger @@ -19,6 +19,9 @@ from langflow.graph.vertex.types import ( from langflow.interface.tools.constants import FILE_TOOLS from langflow.utils import payload +if TYPE_CHECKING: + from langflow.graph.schema import ResultData + class Graph: """A class representing a graph of vertices and edges.""" @@ -37,6 +40,8 @@ class Graph: self._runs = 0 self._updates = 0 self.flow_id = flow_id + self._inputs = [] + self._outputs = [] self.top_level_vertices = [] for vertex in self._vertices: @@ -49,6 +54,45 @@ class Graph: self.inactive_vertices = set() self._build_graph() self.build_graph_maps() + self.define_inputs_and_outputs() + + def define_inputs_and_outputs(self): + """ + Defines the input and output vertices of the graph. + """ + for vertex in self.vertices: + if vertex.is_input: + self._inputs.append(vertex.id) + if vertex.is_output: + self._outputs.append(vertex.id) + + def run(self, inputs: Dict[str, str]) -> List["ResultData"]: + """Runs the graph with the given inputs.""" + + # inputs is {"message": "Hello, world!"} + # we need to go through self.inputs and update the self._raw_params + # of the vertices that are inputs + + for vertex_id in self.inputs: + vertex = self.get_vertex(vertex_id) + if vertex is None: + raise ValueError(f"Vertex {vertex_id} not found") + vertex.update_raw_params(inputs) + try: + self.build() + self.increment_run_count() + except Exception as exc: + logger.exception(exc) + raise ValueError(f"Error running graph: {exc}") from exc + + # Now we get the outputs from the self.outputs + outputs = [] + for vertex_id in self.outputs: + vertex = self.get_vertex(vertex_id) + if vertex is None: + raise ValueError(f"Vertex {vertex_id} not found") + outputs.append(vertex.result) + return outputs @property def metadata(self):