Add support for input type in Graph class
This commit is contained in:
parent
e6f251320b
commit
36c4b9ee4d
3 changed files with 19 additions and 2 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_serializer
|
||||
|
|
@ -258,6 +258,10 @@ class VerticesBuiltResponse(BaseModel):
|
|||
class InputValueRequest(BaseModel):
|
||||
components: Optional[List[str]] = []
|
||||
input_value: Optional[str] = None
|
||||
type: Optional[Literal["chat", "text", "json", "any"]] = Field(
|
||||
"any",
|
||||
description="Defines on which components the input value should be applied. 'any' applies to all input components.",
|
||||
)
|
||||
|
||||
# add an example
|
||||
model_config = ConfigDict(
|
||||
|
|
@ -269,6 +273,8 @@ class InputValueRequest(BaseModel):
|
|||
},
|
||||
{"components": ["Component Name"], "input_value": "input_value"},
|
||||
{"input_value": "input_value"},
|
||||
{"type": "chat", "input_value": "input_value"},
|
||||
{"type": "json", "input_value": '{"key": "value"}'},
|
||||
]
|
||||
},
|
||||
extra="forbid",
|
||||
|
|
|
|||
|
|
@ -198,6 +198,7 @@ class Graph:
|
|||
self,
|
||||
inputs: Dict[str, str],
|
||||
input_components: list[str],
|
||||
input_type: str,
|
||||
outputs: list[str],
|
||||
stream: bool,
|
||||
session_id: str,
|
||||
|
|
@ -224,8 +225,13 @@ class Graph:
|
|||
raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string")
|
||||
for vertex_id in self._is_input_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
# If the vertex is not in the input_components list
|
||||
if input_components and (vertex_id not in input_components or vertex.display_name not in input_components):
|
||||
continue
|
||||
# If the input_type is not any and the input_type is not in the vertex id
|
||||
# Example: input_type = "chat" and vertex.id = "OpenAI-19ddn"
|
||||
elif input_type != "any" and input_type not in vertex.id.lower():
|
||||
continue
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
vertex.update_raw_params(inputs, overwrite=True)
|
||||
|
|
@ -259,6 +265,7 @@ class Graph:
|
|||
self,
|
||||
inputs: list[Dict[str, str]],
|
||||
inputs_components: Optional[list[list[str]]] = None,
|
||||
types: Optional[list[str]] = None,
|
||||
outputs: Optional[list[str]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
|
|
@ -285,10 +292,11 @@ class Graph:
|
|||
inputs = [inputs]
|
||||
elif not inputs:
|
||||
inputs = [{}]
|
||||
for run_inputs, components in zip(inputs, inputs_components):
|
||||
for run_inputs, components, input_type in zip(inputs, inputs_components, types):
|
||||
run_outputs = await self._run(
|
||||
inputs=run_inputs,
|
||||
input_components=components,
|
||||
input_type=input_type,
|
||||
outputs=outputs or [],
|
||||
stream=stream,
|
||||
session_id=session_id or "",
|
||||
|
|
|
|||
|
|
@ -217,16 +217,19 @@ async def run_graph(
|
|||
raise ValueError("session_id or session_service must be provided")
|
||||
components = []
|
||||
inputs_list = []
|
||||
types = []
|
||||
for input_value_request in inputs:
|
||||
if input_value_request.input_value is None:
|
||||
logger.warning("InputValueRequest input_value cannot be None, defaulting to an empty string.")
|
||||
input_value_request.input_value = ""
|
||||
components.append(input_value_request.components or [])
|
||||
inputs_list.append({INPUT_FIELD_NAME: input_value_request.input_value})
|
||||
types.append(input_value_request.type)
|
||||
|
||||
run_outputs = await graph.run(
|
||||
inputs_list,
|
||||
components,
|
||||
types,
|
||||
outputs or [],
|
||||
stream=stream,
|
||||
session_id=session_id_str or "",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue