Add RunOutputs class to schema.py

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-11 18:21:31 -03:00
commit 0bd4517372
2 changed files with 13 additions and 7 deletions

View file

@ -9,7 +9,7 @@ from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import process_flow
from langflow.graph.schema import INPUT_FIELD_NAME, InterfaceComponentTypes
from langflow.graph.schema import INPUT_FIELD_NAME, InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import (
ChatVertex,
@ -188,7 +188,7 @@ class Graph:
outputs: Optional[list[str]] = None,
session_id: Optional[str] = None,
stream: bool = False,
) -> List[List[Optional["ResultData"]]]:
) -> List[RunOutputs]:
"""Runs the graph with the given inputs."""
# inputs is {"message": "Hello, world!"}
@ -214,16 +214,17 @@ class Graph:
input_value = _input_value
else:
raise ValueError(f"Invalid input value: {input_value}. Expected string")
run_inputs = {INPUT_FIELD_NAME: input_value}
run_outputs = await self._run(
inputs={INPUT_FIELD_NAME: input_value},
inputs=run_inputs,
input_components=components,
outputs=outputs or [],
stream=stream,
session_id=session_id or "",
)
logger.debug(f"Run outputs: {run_outputs}")
vertex_outputs.append(run_outputs)
run_output_object = RunOutputs(inputs=run_inputs, outputs=run_outputs)
logger.debug(f"Run outputs: {run_output_object}")
vertex_outputs.append(run_output_object)
return vertex_outputs
# vertices_layers is a list of lists ordered by the order the vertices

View file

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Optional
from typing import Any, List, Optional
from pydantic import BaseModel, Field, field_serializer
@ -50,3 +50,8 @@ OUTPUT_COMPONENTS = [
]
INPUT_FIELD_NAME = "input_value"
class RunOutputs(BaseModel):
inputs: dict = Field(default_factory=dict)
outputs: List[Optional[ResultData]] = Field(default_factory=list)