Add input and output definitions to Graph class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 11:37:53 -03:00
commit 073e4b7ccf

View file

@ -1,5 +1,5 @@
from collections import defaultdict, deque
from typing import Dict, Generator, List, Optional, Type, Union
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
from langchain.chains.base import Chain
from loguru import logger
@ -19,6 +19,9 @@ from langflow.graph.vertex.types import (
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.utils import payload
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
class Graph:
"""A class representing a graph of vertices and edges."""
@ -37,6 +40,8 @@ class Graph:
self._runs = 0
self._updates = 0
self.flow_id = flow_id
self._inputs = []
self._outputs = []
self.top_level_vertices = []
for vertex in self._vertices:
@ -49,6 +54,45 @@ class Graph:
self.inactive_vertices = set()
self._build_graph()
self.build_graph_maps()
self.define_inputs_and_outputs()
def define_inputs_and_outputs(self):
"""
Defines the input and output vertices of the graph.
"""
for vertex in self.vertices:
if vertex.is_input:
self._inputs.append(vertex.id)
if vertex.is_output:
self._outputs.append(vertex.id)
def run(self, inputs: Dict[str, str]) -> List["ResultData"]:
"""Runs the graph with the given inputs."""
# inputs is {"message": "Hello, world!"}
# we need to go through self.inputs and update the self._raw_params
# of the vertices that are inputs
for vertex_id in self.inputs:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
vertex.update_raw_params(inputs)
try:
self.build()
self.increment_run_count()
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error running graph: {exc}") from exc
# Now we get the outputs from the self.outputs
outputs = []
for vertex_id in self.outputs:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
outputs.append(vertex.result)
return outputs
@property
def metadata(self):