Add input and output definitions to Graph class
This commit is contained in:
parent
04de488ede
commit
073e4b7ccf
1 changed files with 45 additions and 1 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from collections import defaultdict, deque
|
||||
from typing import Dict, Generator, List, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from loguru import logger
|
||||
|
|
@ -19,6 +19,9 @@ from langflow.graph.vertex.types import (
|
|||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.utils import payload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.schema import ResultData
|
||||
|
||||
|
||||
class Graph:
|
||||
"""A class representing a graph of vertices and edges."""
|
||||
|
|
@ -37,6 +40,8 @@ class Graph:
|
|||
self._runs = 0
|
||||
self._updates = 0
|
||||
self.flow_id = flow_id
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
|
||||
self.top_level_vertices = []
|
||||
for vertex in self._vertices:
|
||||
|
|
@ -49,6 +54,45 @@ class Graph:
|
|||
self.inactive_vertices = set()
|
||||
self._build_graph()
|
||||
self.build_graph_maps()
|
||||
self.define_inputs_and_outputs()
|
||||
|
||||
def define_inputs_and_outputs(self):
|
||||
"""
|
||||
Defines the input and output vertices of the graph.
|
||||
"""
|
||||
for vertex in self.vertices:
|
||||
if vertex.is_input:
|
||||
self._inputs.append(vertex.id)
|
||||
if vertex.is_output:
|
||||
self._outputs.append(vertex.id)
|
||||
|
||||
def run(self, inputs: Dict[str, str]) -> List["ResultData"]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
# inputs is {"message": "Hello, world!"}
|
||||
# we need to go through self.inputs and update the self._raw_params
|
||||
# of the vertices that are inputs
|
||||
|
||||
for vertex_id in self.inputs:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
vertex.update_raw_params(inputs)
|
||||
try:
|
||||
self.build()
|
||||
self.increment_run_count()
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(f"Error running graph: {exc}") from exc
|
||||
|
||||
# Now we get the outputs from the self.outputs
|
||||
outputs = []
|
||||
for vertex_id in self.outputs:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
outputs.append(vertex.result)
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue