feat: add dynamic state model creation and update (#3271)

* feat: add initial implementation of dynamic state model creation and output getter in graph state module

* feat: implement _reset_all_output_values method to initialize component outputs in custom_component class

* feat: add state model management with lazy initialization and dynamic instance getter in custom_component class

* feat: Refactor Component class to use public method get_output_by_method

Refactor the Component class in the custom_component module to change the visibility of the method `_get_output_by_method` to public by renaming it to `get_output_by_method`. This change improves the accessibility and clarity of the method for external use.

* feat: add output setter utility to manage output values in state model properties

* feat: implement validation for methods' classes in output getter/setter utilities in state model to ensure proper structure

* feat: add state model creation from graph in state_model.py

* feat: enhance Graph class with lazy loading for state model creation from graph

* feat: add unit tests for state model creation and validation in test_state_model.py

* feat: add unit tests for state model creation and validation in test_state_model.py

* feat: add functional test for graph state update and validation in test_graph_state_model.py

* fix: update _instance_getter function to accept a parameter in component.py for state model instance retrieval

* refactor: rename test to clarify purpose in test_state_model.py for functional state update validation

* chore: import Finish constant in test_graph_state_model.py for improved clarity and usage in state model tests

* refactor: add optional validation in output getter/setter methods for improved method integrity in state model handling

* refactor: enhance state model creation with optional validation and error handling for output methods in model.py

* refactor: serialize and deserialize GraphStateModel in test_graph_state_model.py

* refactor: improve error message and add verbose mode for graph start in test_state_model.py

* refactor: remove verbose flag from graph.start in TestCreateStateModel for consistency in test_state_model.py

* refactor: disable validation when creating GraphStateModel in state_model.py for improved flexibility

* refactor: add validation documentation for method attributes in model.py to enhance code clarity and usability

* refactor: expand docstring for build_output_getter in model.py to clarify usage and validation details

* refactor: add detailed docstring for build_output_setter in model.py to improve clarity on functionality and usage scenarios

* refactor: add comprehensive docstring for create_state_model in model.py to clarify functionality and usage examples

* refactor: enhance docstring for create_state_model_from_graph in state_model.py to clarify functionality and provide examples

* test: add JSON schema validation in graph state model tests for improved structure and correctness verification

* refactor: Improve graph_state_model.json_schema unit test readability and structure.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-12 21:53:57 -03:00 committed by GitHub
commit c5d9cbae49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 654 additions and 4 deletions

View file

@ -7,6 +7,7 @@ import nanoid # type: ignore
import yaml
from pydantic import BaseModel
from langflow.graph.state.model import create_state_model
from langflow.helpers.custom import format_type
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
@ -35,7 +36,7 @@ class Component(CustomComponent):
def __init__(self, **kwargs):
# if key starts with _ it is a config
# else it is an input
self._reset_all_output_values()
inputs = {}
config = {}
for key, value in kwargs.items():
@ -50,6 +51,7 @@ class Component(CustomComponent):
self._parameters = inputs or {}
self._edges: list[EdgeData] = []
self._components: list[Component] = []
self._state_model = None
self.set_attributes(self._parameters)
self._output_logs = {}
config = config or {}
@ -70,6 +72,30 @@ class Component(CustomComponent):
self._set_output_types()
self.set_class_code()
def _reset_all_output_values(self):
for output in self.outputs:
setattr(output, "value", UNDEFINED)
def _build_state_model(self):
if self._state_model:
return self._state_model
name = self.name or self.__class__.__name__
model_name = f"{name}StateModel"
fields = {}
for output in self.outputs:
fields[output.name] = getattr(self, output.method)
self._state_model = create_state_model(model_name=model_name, **fields)
return self._state_model
def get_state_model_instance_getter(self):
state_model = self._build_state_model()
def _instance_getter(_):
return state_model()
_instance_getter.__annotations__["return"] = state_model
return _instance_getter
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
@ -247,7 +273,7 @@ class Component(CustomComponent):
output.add_types(return_types)
output.set_selected()
def _get_output_by_method(self, method: Callable):
def get_output_by_method(self, method: Callable):
# method is a callable and output.method is a string
# we need to find the output that has the same method
output = next((output for output in self.outputs if output.method == method.__name__), None)
@ -268,7 +294,7 @@ class Component(CustomComponent):
method_is_output = (
hasattr(method, "__self__")
and isinstance(method.__self__, Component)
and method.__self__._get_output_by_method(method)
and method.__self__.get_output_by_method(method)
)
return method_is_output
@ -298,7 +324,7 @@ class Component(CustomComponent):
def _connect_to_component(self, key, value, _input):
component = value.__self__
self._components.append(component)
output = component._get_output_by_method(value)
output = component.get_output_by_method(value)
self._add_edge(component, key, output, _input)
def _add_edge(self, component, key, output, _input):

View file

@ -19,6 +19,7 @@ from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.state_model import create_state_model_from_graph
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
@ -62,6 +63,7 @@ class Graph:
log_config = {"disable": False}
configure(**log_config)
self._start = start
self._state_model = None
self._end = end
self._prepared = False
self._runs = 0
@ -109,6 +111,12 @@ class Graph:
if (start is not None and end is None) or (start is None and end is not None):
raise ValueError("You must provide both input and output components")
@property
def state_model(self):
if not self._state_model:
self._state_model = create_state_model_from_graph(self)
return self._state_model
def __add__(self, other):
if not isinstance(other, Graph):
raise TypeError("Can only add Graph objects")

View file

@ -0,0 +1,67 @@
import re
from langflow.graph.state.model import create_state_model
from langflow.helpers.base_model import BaseModel
def camel_to_snake(camel_str: str) -> str:
snake_str = re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
return snake_str
def create_state_model_from_graph(graph: BaseModel) -> type[BaseModel]:
"""
Create a Pydantic state model from a graph representation.
This function generates a Pydantic model that represents the state of an entire graph.
It creates getter methods for each vertex in the graph, allowing access to the state
of individual components within the graph structure.
Args:
graph (BaseModel): The graph object from which to create the state model.
This should be a Pydantic model representing the graph structure,
with a 'vertices' attribute containing all graph vertices.
Returns:
type[BaseModel]: A dynamically created Pydantic model class representing
the state of the entire graph. This model will have properties
corresponding to each vertex in the graph, with names converted from
the vertex IDs to snake case.
Raises:
ValueError: If any vertex in the graph does not have a properly initialized
component instance (i.e., if vertex._custom_component is None).
Notes:
- Each vertex in the graph must have a '_custom_component' attribute.
- The '_custom_component' must have a 'get_state_model_instance_getter' method.
- Vertex IDs are converted from camel case to snake case for the resulting model's field names.
- The resulting model uses the 'create_state_model' function with validation disabled.
Example:
>>> class Vertex(BaseModel):
... id: str
... _custom_component: Any
>>> class Graph(BaseModel):
... vertices: List[Vertex]
>>> # Assume proper setup of vertices and components
>>> graph = Graph(vertices=[...])
>>> GraphStateModel = create_state_model_from_graph(graph)
>>> graph_state = GraphStateModel()
>>> # Access component states, e.g.:
>>> print(graph_state.some_component_name)
"""
for vertex in graph.vertices:
if hasattr(vertex, "_custom_component") and vertex._custom_component is None:
raise ValueError(f"Vertex {vertex.id} does not have a component instance.")
state_model_getters = [
vertex._custom_component.get_state_model_instance_getter()
for vertex in graph.vertices
if hasattr(vertex, "_custom_component") and hasattr(vertex._custom_component, "get_state_model_instance_getter")
]
fields = {
camel_to_snake(vertex.id): state_model_getter
for vertex, state_model_getter in zip(graph.vertices, state_model_getters)
}
return create_state_model(model_name="GraphStateModel", validate=False, **fields)

View file

@ -0,0 +1,237 @@
from typing import Any, Callable, get_type_hints
from pydantic import ConfigDict, computed_field, create_model
from pydantic.fields import FieldInfo
def __validate_method(method: Callable) -> None:
"""
Validates a method by checking if it has the required attributes.
This function ensures that the given method belongs to a class with the necessary
structure for output handling. It checks for the presence of a __self__ attribute
on the method and a get_output_by_method attribute on the method's class.
Args:
method (Callable): The method to be validated.
Raises:
ValueError: If the method does not have a __self__ attribute or if the method's
class does not have a get_output_by_method attribute.
Example:
>>> class ValidClass:
... def get_output_by_method(self):
... pass
... def valid_method(self):
... pass
>>> __validate_method(ValidClass().valid_method) # This will pass
>>> __validate_method(lambda x: x) # This will raise a ValueError
"""
if not hasattr(method, "__self__"):
raise ValueError(f"Method {method} does not have a __self__ attribute.")
if not hasattr(method.__self__, "get_output_by_method"):
raise ValueError(f"Method's class {method.__self__} must have a get_output_by_method attribute.")
def build_output_getter(method: Callable, validate: bool = True) -> Callable:
"""
Builds an output getter function for a given method in a graph component.
This function creates a new callable that, when invoked, retrieves the output
of the specified method using the get_output_by_method of the method's class.
It's used in creating dynamic state models for graph components.
Args:
method (Callable): The method for which to build the output getter.
validate (bool, optional): Whether to validate the method before building
the getter. Defaults to True.
Returns:
Callable: The output getter function. When called, this function returns
the value of the output associated with the original method.
Raises:
ValueError: If the method has no return type annotation or if validation fails.
Notes:
- The getter function returns UNDEFINED if the output has not been set.
- When validate is True, the method must belong to a class with a
'get_output_by_method' attribute.
- This function is typically used internally by create_state_model.
Example:
>>> class ChatComponent:
... def get_output_by_method(self, method):
... return type('Output', (), {'value': "Hello, World!"})()
... def get_message(self) -> str:
... pass
>>> component = ChatComponent()
>>> getter = build_output_getter(component.get_message)
>>> print(getter(None)) # This will print "Hello, World!"
"""
def output_getter(_):
if validate:
__validate_method(method)
methods_class = method.__self__
output = methods_class.get_output_by_method(method)
return output.value
return_type = get_type_hints(method).get("return", None)
if return_type is None:
raise ValueError(f"Method {method.__name__} has no return type annotation.")
output_getter.__annotations__["return"] = return_type
return output_getter
def build_output_setter(method: Callable, validate: bool = True) -> Callable:
"""
Build an output setter function for a given method in a graph component.
This function creates a new callable that, when invoked, sets the output
of the specified method using the get_output_by_method of the method's class.
It's used in creating dynamic state models for graph components, allowing
for the modification of component states.
Args:
method (Callable): The method for which the output setter is being built.
validate (bool, optional): Flag indicating whether to validate the method
before building the setter. Defaults to True.
Returns:
Callable: The output setter function. When called with a value, this function
sets the output associated with the original method to that value.
Raises:
ValueError: If validation fails when validate is True.
Notes:
- When validate is True, the method must belong to a class with a
'get_output_by_method' attribute.
- This function is typically used internally by create_state_model.
- The setter allows for dynamic updating of component states in a graph.
Example:
>>> class ChatComponent:
... def get_output_by_method(self, method):
... return type('Output', (), {'value': None})()
... def set_message(self):
... pass
>>> component = ChatComponent()
>>> setter = build_output_setter(component.set_message)
>>> setter(component, "New message")
>>> print(component.get_output_by_method(component.set_message).value) # Prints "New message"
"""
def output_setter(self, value):
if validate:
__validate_method(method)
methods_class = method.__self__
output = methods_class.get_output_by_method(method)
output.value = value
return output_setter
def create_state_model(model_name: str = "State", validate: bool = True, **kwargs) -> type:
"""
Create a dynamic Pydantic state model based on the provided keyword arguments.
This function generates a Pydantic model class with fields corresponding to the
provided keyword arguments. It can handle various types of field definitions,
including callable methods (which are converted to properties), FieldInfo objects,
and type-default value tuples.
Args:
model_name (str, optional): The name of the model. Defaults to "State".
validate (bool, optional): Whether to validate the methods when converting
them to properties. Defaults to True.
**kwargs: Keyword arguments representing the fields of the model. Each argument
can be a callable method, a FieldInfo object, or a tuple of (type, default).
Returns:
type: The dynamically created Pydantic state model class.
Raises:
ValueError: If the provided field value is invalid or cannot be processed.
Examples:
>>> from langflow.components.inputs import ChatInput
>>> from langflow.components.outputs.ChatOutput import ChatOutput
>>> from pydantic import Field
>>>
>>> chat_input = ChatInput()
>>> chat_output = ChatOutput()
>>>
>>> # Create a model with a method from a component
>>> StateModel = create_state_model(method_one=chat_input.message_response)
>>> state = StateModel()
>>> assert state.method_one is UNDEFINED
>>> chat_input.set_output_value("message", "test")
>>> assert state.method_one == "test"
>>>
>>> # Create a model with multiple components and a Pydantic Field
>>> NewStateModel = create_state_model(
... model_name="NewStateModel",
... first_method=chat_input.message_response,
... second_method=chat_output.message_response,
... my_attribute=Field(None)
... )
>>> new_state = NewStateModel()
>>> new_state.first_method = "test"
>>> new_state.my_attribute = 123
>>> assert new_state.first_method == "test"
>>> assert new_state.my_attribute == 123
>>>
>>> # Create a model with tuple-based field definitions
>>> TupleStateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123))
>>> tuple_state = TupleStateModel()
>>> assert tuple_state.field_one == "default"
>>> assert tuple_state.field_two == 123
Notes:
- The function handles empty keyword arguments gracefully.
- For tuple-based field definitions, the first element must be a valid Python type.
- Unsupported value types in keyword arguments will raise a ValueError.
- Callable methods must have proper return type annotations and belong to a class
with a 'get_output_by_method' attribute when validate is True.
"""
fields = {}
for name, value in kwargs.items():
# Extract the return type from the method's type annotations
if callable(value):
# Define the field with the return type
try:
__validate_method(value)
getter = build_output_getter(value, validate)
setter = build_output_setter(value, validate)
property_method = property(getter, setter)
except ValueError as e:
# If the method is not valid,assume it is already a getter
if "get_output_by_method" not in str(e) and "__self__" not in str(e) or validate:
raise e
property_method = value
fields[name] = computed_field(property_method)
elif isinstance(value, FieldInfo):
field_tuple = (value.annotation or Any, value)
fields[name] = field_tuple
elif isinstance(value, tuple) and len(value) == 2:
# Fields are defined by one of the following tuple forms:
# (<type>, <default value>)
# (<type>, Field(...))
# typing.Annotated[<type>, Field(...)]
if not isinstance(value[0], type):
raise ValueError(f"Invalid type for field {name}: {type(value[0])}")
fields[name] = (value[0], value[1])
else:
raise ValueError(f"Invalid value type {type(value)} for field {name}")
# Create the model dynamically
config_dict = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)
model = create_model(model_name, __config__=config_dict, **fields)
return model

View file

@ -0,0 +1,139 @@
import pytest
from pydantic import Field
from langflow.components.inputs import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.state.model import create_state_model
from langflow.template.field.base import UNDEFINED
@pytest.fixture
def client():
pass
@pytest.fixture
def chat_input_component():
return ChatInput()
@pytest.fixture
def chat_output_component():
return ChatOutput()
class TestCreateStateModel:
# Successfully create a model with valid method return type annotations
def test_create_model_with_valid_return_type_annotations(self, chat_input_component):
StateModel = create_state_model(method_one=chat_input_component.message_response)
state_instance = StateModel()
assert state_instance.method_one is UNDEFINED
chat_input_component.set_output_value("message", "test")
assert state_instance.method_one == "test"
def test_create_model_and_assign_values_fails(self, chat_input_component):
StateModel = create_state_model(method_one=chat_input_component.message_response)
state_instance = StateModel()
state_instance.method_one = "test"
assert state_instance.method_one == "test"
def test_create_with_multiple_components(self, chat_input_component, chat_output_component):
NewStateModel = create_state_model(
model_name="NewStateModel",
first_method=chat_input_component.message_response,
second_method=chat_output_component.message_response,
)
state_instance = NewStateModel()
assert state_instance.first_method is UNDEFINED
assert state_instance.second_method is UNDEFINED
state_instance.first_method = "test"
state_instance.second_method = 123
assert state_instance.first_method == "test"
assert state_instance.second_method == 123
def test_create_with_pydantic_field(self, chat_input_component):
StateModel = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None))
state_instance = StateModel()
state_instance.method_one = "test"
state_instance.my_attribute = "test"
assert state_instance.method_one == "test"
assert state_instance.my_attribute == "test"
# my_attribute should be of type Any
state_instance.my_attribute = 123
assert state_instance.my_attribute == 123
# Creates a model with fields based on provided keyword arguments
def test_create_model_with_fields_from_kwargs(self):
StateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123))
state_instance = StateModel()
assert state_instance.field_one == "default"
assert state_instance.field_two == 123
# Raises ValueError for invalid field type in tuple-based definitions
def test_raise_valueerror_for_invalid_field_type_in_tuple(self):
with pytest.raises(ValueError, match="Invalid type for field invalid_field"):
create_state_model(invalid_field=("not_a_type", "default"))
# Raises ValueError for unsupported value types in keyword arguments
def test_raise_valueerror_for_unsupported_value_types(self):
with pytest.raises(ValueError, match="Invalid value type <class 'int'> for field invalid_field"):
create_state_model(invalid_field=123)
# Handles empty keyword arguments gracefully
def test_handle_empty_kwargs_gracefully(self):
StateModel = create_state_model()
state_instance = StateModel()
assert state_instance is not None
# Ensures model name defaults to "State" if not provided
def test_default_model_name_to_state(self):
StateModel = create_state_model()
assert StateModel.__name__ == "State"
OtherNameModel = create_state_model(model_name="OtherName")
assert OtherNameModel.__name__ == "OtherName"
# Validates that callable values are properly type-annotated
def test_create_model_with_invalid_callable(self):
class MockComponent:
def method_one(self) -> str:
return "test"
def method_two(self) -> int:
return 123
mock_component = MockComponent()
with pytest.raises(ValueError, match="get_output_by_method"):
create_state_model(method_one=mock_component.method_one, method_two=mock_component.method_two)
def test_graph_functional_start_state_update(self):
chat_input = ChatInput(_id="chat_input")
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
ChatStateModel = create_state_model(model_name="ChatState", message=chat_output.message_response)
chat_state_model = ChatStateModel()
assert chat_state_model.__class__.__name__ == "ChatState"
assert chat_state_model.message is UNDEFINED
graph = Graph(chat_input, chat_output)
graph.prepare()
# Now iterate through the graph
# and check that the graph is running
# correctly
ids = ["chat_input", "chat_output"]
results = []
for result in graph.start():
results.append(result)
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
assert results[-1] == Finish()
assert chat_state_model.__class__.__name__ == "ChatState"
assert chat_state_model.message.get_text() == "test"

View file

@ -0,0 +1,173 @@
import pytest
from pydantic import BaseModel
from langflow.components.helpers.Memory import MemoryComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.prompts.Prompt import PromptComponent
from langflow.graph import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.graph.state_model import create_state_model_from_graph
@pytest.fixture
def client():
pass
def test_graph_state_model():
session_id = "test_session_id"
template = """{context}
User: {user_message}
AI: """
memory_component = MemoryComponent(_id="chat_memory")
memory_component.set(session_id=session_id)
chat_input = ChatInput(_id="chat_input")
prompt_component = PromptComponent(_id="prompt")
prompt_component.set(
template=template, user_message=chat_input.message_response, context=memory_component.retrieve_messages_as_text
)
openai_component = OpenAIModelComponent(_id="openai")
openai_component.set(
input_value=prompt_component.build_prompt, max_tokens=100, temperature=0.1, api_key="test_api_key"
)
openai_component.get_output("text_output").value = "Mock response"
chat_output = ChatOutput(_id="chat_output")
chat_output.set(input_value=openai_component.text_response)
graph = Graph(chat_input, chat_output)
GraphStateModel = create_state_model_from_graph(graph)
assert GraphStateModel.__name__ == "GraphStateModel"
assert list(GraphStateModel.model_computed_fields.keys()) == [
"chat_input",
"chat_output",
"openai",
"prompt",
"chat_memory",
]
def test_graph_functional_start_graph_state_update():
chat_input = ChatInput(_id="chat_input")
chat_input.set(input_value="Test Sender Name")
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
graph = Graph(chat_input, chat_output)
graph.prepare()
# Now iterate through the graph
# and check that the graph is running
# correctly
GraphStateModel = create_state_model_from_graph(graph)
graph_state_model = GraphStateModel()
ids = ["chat_input", "chat_output"]
results = []
for result in graph.start():
results.append(result)
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
assert results[-1] == Finish()
assert graph_state_model.__class__.__name__ == "GraphStateModel"
assert graph_state_model.chat_input.message.get_text() == "Test Sender Name"
assert graph_state_model.chat_output.message.get_text() == "test"
def test_graph_state_model_serialization():
chat_input = ChatInput(_id="chat_input")
chat_input.set(input_value="Test Sender Name")
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
graph = Graph(chat_input, chat_output)
graph.prepare()
# Now iterate through the graph
# and check that the graph is running
# correctly
GraphStateModel = create_state_model_from_graph(graph)
graph_state_model = GraphStateModel()
ids = ["chat_input", "chat_output"]
results = []
for result in graph.start():
results.append(result)
assert len(results) == 3
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
assert results[-1] == Finish()
assert graph_state_model.__class__.__name__ == "GraphStateModel"
assert graph_state_model.chat_input.message.get_text() == "Test Sender Name"
assert graph_state_model.chat_output.message.get_text() == "test"
serialized_state_model = graph_state_model.model_dump()
assert serialized_state_model["chat_input"]["message"]["text"] == "Test Sender Name"
def test_graph_state_model_json_schema():
chat_input = ChatInput(_id="chat_input")
chat_input.set(input_value="Test Sender Name")
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
graph = Graph(chat_input, chat_output)
graph.prepare()
GraphStateModel = create_state_model_from_graph(graph)
graph_state_model: BaseModel = GraphStateModel()
json_schema = graph_state_model.model_json_schema(mode="serialization")
# Test main schema structure
assert json_schema["title"] == "GraphStateModel"
assert json_schema["type"] == "object"
assert set(json_schema["required"]) == {"chat_input", "chat_output"}
# Test chat_input and chat_output properties
for prop in ["chat_input", "chat_output"]:
assert prop in json_schema["properties"]
assert json_schema["properties"][prop]["allOf"][0]["$ref"].startswith("#/$defs/")
assert json_schema["properties"][prop]["readOnly"] is True
# Test $defs
assert set(json_schema["$defs"].keys()) == {"ChatInputStateModel", "ChatOutputStateModel", "Image", "Message"}
# Test ChatInputStateModel and ChatOutputStateModel
for model in ["ChatInputStateModel", "ChatOutputStateModel"]:
assert json_schema["$defs"][model]["type"] == "object"
assert json_schema["$defs"][model]["title"] == model
assert "message" in json_schema["$defs"][model]["properties"]
assert json_schema["$defs"][model]["properties"]["message"]["allOf"][0]["$ref"] == "#/$defs/Message"
assert json_schema["$defs"][model]["properties"]["message"]["readOnly"] is True
assert json_schema["$defs"][model]["required"] == ["message"]
# Test Message model
message_props = json_schema["$defs"]["Message"]["properties"]
assert set(message_props.keys()) == {
"text_key",
"data",
"default_value",
"text",
"sender",
"sender_name",
"files",
"session_id",
"timestamp",
"flow_id",
}
assert message_props["text_key"]["type"] == "string"
assert message_props["data"]["type"] == "object"
assert "anyOf" in message_props["default_value"]
assert "anyOf" in message_props["files"]
assert message_props["timestamp"]["type"] == "string"
# Test Image model
image_props = json_schema["$defs"]["Image"]["properties"]
assert set(image_props.keys()) == {"path", "url"}
for prop in ["path", "url"]:
assert "anyOf" in image_props[prop]
assert {"type": "string"} in image_props[prop]["anyOf"]
assert {"type": "null"} in image_props[prop]["anyOf"]