Add support for input type in Graph class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-22 22:37:07 -03:00
commit 36c4b9ee4d
3 changed files with 19 additions and 2 deletions

View file

@ -1,7 +1,7 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_serializer
@ -258,6 +258,10 @@ class VerticesBuiltResponse(BaseModel):
class InputValueRequest(BaseModel):
components: Optional[List[str]] = []
input_value: Optional[str] = None
type: Optional[Literal["chat", "text", "json", "any"]] = Field(
"any",
description="Defines on which components the input value should be applied. 'any' applies to all input components.",
)
# add an example
model_config = ConfigDict(
@ -269,6 +273,8 @@ class InputValueRequest(BaseModel):
},
{"components": ["Component Name"], "input_value": "input_value"},
{"input_value": "input_value"},
{"type": "chat", "input_value": "input_value"},
{"type": "json", "input_value": '{"key": "value"}'},
]
},
extra="forbid",

View file

@ -198,6 +198,7 @@ class Graph:
self,
inputs: Dict[str, str],
input_components: list[str],
input_type: str,
outputs: list[str],
stream: bool,
session_id: str,
@ -224,8 +225,13 @@ class Graph:
raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string")
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
# If the vertex is not in the input_components list
if input_components and (vertex_id not in input_components or vertex.display_name not in input_components):
continue
# If the input_type is not any and the input_type is not in the vertex id
# Example: input_type = "chat" and vertex.id = "OpenAI-19ddn"
elif input_type != "any" and input_type not in vertex.id.lower():
continue
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
vertex.update_raw_params(inputs, overwrite=True)
@ -259,6 +265,7 @@ class Graph:
self,
inputs: list[Dict[str, str]],
inputs_components: Optional[list[list[str]]] = None,
types: Optional[list[str]] = None,
outputs: Optional[list[str]] = None,
session_id: Optional[str] = None,
stream: bool = False,
@ -285,10 +292,11 @@ class Graph:
inputs = [inputs]
elif not inputs:
inputs = [{}]
for run_inputs, components in zip(inputs, inputs_components):
for run_inputs, components, input_type in zip(inputs, inputs_components, types):
run_outputs = await self._run(
inputs=run_inputs,
input_components=components,
input_type=input_type,
outputs=outputs or [],
stream=stream,
session_id=session_id or "",

View file

@ -217,16 +217,19 @@ async def run_graph(
raise ValueError("session_id or session_service must be provided")
components = []
inputs_list = []
types = []
for input_value_request in inputs:
if input_value_request.input_value is None:
logger.warning("InputValueRequest input_value cannot be None, defaulting to an empty string.")
input_value_request.input_value = ""
components.append(input_value_request.components or [])
inputs_list.append({INPUT_FIELD_NAME: input_value_request.input_value})
types.append(input_value_request.type)
run_outputs = await graph.run(
inputs_list,
components,
types,
outputs or [],
stream=stream,
session_id=session_id_str or "",