From c5d9cbae4947f7ca87c0bc3323a95fe3df3d9314 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 12 Aug 2024 21:53:57 -0300 Subject: [PATCH] feat: add dynamic state model creation and update (#3271) * feat: add initial implementation of dynamic state model creation and output getter in graph state module * feat: implement _reset_all_output_values method to initialize component outputs in custom_component class * feat: add state model management with lazy initialization and dynamic instance getter in custom_component class * feat: Refactor Component class to use public method get_output_by_method Refactor the Component class in the custom_component module to change the visibility of the method `_get_output_by_method` to public by renaming it to `get_output_by_method`. This change improves the accessibility and clarity of the method for external use. * feat: add output setter utility to manage output values in state model properties * feat: implement validation for methods' classes in output getter/setter utilities in state model to ensure proper structure * feat: add state model creation from graph in state_model.py * feat: enhance Graph class with lazy loading for state model creation from graph * feat: add unit tests for state model creation and validation in test_state_model.py * feat: add unit tests for state model creation and validation in test_state_model.py * feat: add functional test for graph state update and validation in test_graph_state_model.py * fix: update _instance_getter function to accept a parameter in component.py for state model instance retrieval * refactor: rename test to clarify purpose in test_state_model.py for functional state update validation * chore: import Finish constant in test_graph_state_model.py for improved clarity and usage in state model tests * refactor: add optional validation in output getter/setter methods for improved method integrity in state model handling * refactor: enhance state model creation with optional validation and error handling for output methods in model.py * refactor: serialize and deserialize GraphStateModel in test_graph_state_model.py * refactor: improve error message and add verbose mode for graph start in test_state_model.py * refactor: remove verbose flag from graph.start in TestCreateStateModel for consistency in test_state_model.py * refactor: disable validation when creating GraphStateModel in state_model.py for improved flexibility * refactor: add validation documentation for method attributes in model.py to enhance code clarity and usability * refactor: expand docstring for build_output_getter in model.py to clarify usage and validation details * refactor: add detailed docstring for build_output_setter in model.py to improve clarity on functionality and usage scenarios * refactor: add comprehensive docstring for create_state_model in model.py to clarify functionality and usage examples * refactor: enhance docstring for create_state_model_from_graph in state_model.py to clarify functionality and provide examples * test: add JSON schema validation in graph state model tests for improved structure and correctness verification * refactor: Improve graph_state_model.json_schema unit test readability and structure. --- .../custom/custom_component/component.py | 34 ++- src/backend/base/langflow/graph/graph/base.py | 8 + .../base/langflow/graph/graph/state_model.py | 67 +++++ .../base/langflow/graph/state/__init__.py | 0 .../base/langflow/graph/state/model.py | 237 ++++++++++++++++++ .../graph/graph/state/test_state_model.py | 139 ++++++++++ .../graph/graph/test_graph_state_model.py | 173 +++++++++++++ 7 files changed, 654 insertions(+), 4 deletions(-) create mode 100644 src/backend/base/langflow/graph/graph/state_model.py create mode 100644 src/backend/base/langflow/graph/state/__init__.py create mode 100644 src/backend/base/langflow/graph/state/model.py create mode 100644 src/backend/tests/unit/graph/graph/state/test_state_model.py create mode 100644 src/backend/tests/unit/graph/graph/test_graph_state_model.py diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 997736bf8..3f38ff45e 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -7,6 +7,7 @@ import nanoid # type: ignore import yaml from pydantic import BaseModel +from langflow.graph.state.model import create_state_model from langflow.helpers.custom import format_type from langflow.schema.artifact import get_artifact_type, post_process_raw from langflow.schema.data import Data @@ -35,7 +36,7 @@ class Component(CustomComponent): def __init__(self, **kwargs): # if key starts with _ it is a config # else it is an input - + self._reset_all_output_values() inputs = {} config = {} for key, value in kwargs.items(): @@ -50,6 +51,7 @@ class Component(CustomComponent): self._parameters = inputs or {} self._edges: list[EdgeData] = [] self._components: list[Component] = [] + self._state_model = None self.set_attributes(self._parameters) self._output_logs = {} config = config or {} @@ -70,6 +72,30 @@ class Component(CustomComponent): self._set_output_types() self.set_class_code() + def _reset_all_output_values(self): + for output in self.outputs: + setattr(output, "value", UNDEFINED) + + def _build_state_model(self): + if self._state_model: + return self._state_model + name = self.name or self.__class__.__name__ + model_name = f"{name}StateModel" + fields = {} + for output in self.outputs: + fields[output.name] = getattr(self, output.method) + self._state_model = create_state_model(model_name=model_name, **fields) + return self._state_model + + def get_state_model_instance_getter(self): + state_model = self._build_state_model() + + def _instance_getter(_): + return state_model() + + _instance_getter.__annotations__["return"] = state_model + return _instance_getter + def __deepcopy__(self, memo): if id(self) in memo: return memo[id(self)] @@ -247,7 +273,7 @@ class Component(CustomComponent): output.add_types(return_types) output.set_selected() - def _get_output_by_method(self, method: Callable): + def get_output_by_method(self, method: Callable): # method is a callable and output.method is a string # we need to find the output that has the same method output = next((output for output in self.outputs if output.method == method.__name__), None) @@ -268,7 +294,7 @@ class Component(CustomComponent): method_is_output = ( hasattr(method, "__self__") and isinstance(method.__self__, Component) - and method.__self__._get_output_by_method(method) + and method.__self__.get_output_by_method(method) ) return method_is_output @@ -298,7 +324,7 @@ class Component(CustomComponent): def _connect_to_component(self, key, value, _input): component = value.__self__ self._components.append(component) - output = component._get_output_by_method(value) + output = component.get_output_by_method(value) self._add_edge(component, key, output, _input) def _add_edge(self, component, key, output, _input): diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index a47b11b99..9a93b6053 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -19,6 +19,7 @@ from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult from langflow.graph.graph.state_manager import GraphStateManager +from langflow.graph.graph.state_model import create_state_model_from_graph from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex from langflow.graph.schema import InterfaceComponentTypes, RunOutputs from langflow.graph.vertex.base import Vertex, VertexStates @@ -62,6 +63,7 @@ class Graph: log_config = {"disable": False} configure(**log_config) self._start = start + self._state_model = None self._end = end self._prepared = False self._runs = 0 @@ -109,6 +111,12 @@ class Graph: if (start is not None and end is None) or (start is None and end is not None): raise ValueError("You must provide both input and output components") + @property + def state_model(self): + if not self._state_model: + self._state_model = create_state_model_from_graph(self) + return self._state_model + def __add__(self, other): if not isinstance(other, Graph): raise TypeError("Can only add Graph objects") diff --git a/src/backend/base/langflow/graph/graph/state_model.py b/src/backend/base/langflow/graph/graph/state_model.py new file mode 100644 index 000000000..e747f079a --- /dev/null +++ b/src/backend/base/langflow/graph/graph/state_model.py @@ -0,0 +1,67 @@ +import re + +from langflow.graph.state.model import create_state_model +from langflow.helpers.base_model import BaseModel + + +def camel_to_snake(camel_str: str) -> str: + snake_str = re.sub(r"(? type[BaseModel]: + """ + Create a Pydantic state model from a graph representation. + + This function generates a Pydantic model that represents the state of an entire graph. + It creates getter methods for each vertex in the graph, allowing access to the state + of individual components within the graph structure. + + Args: + graph (BaseModel): The graph object from which to create the state model. + This should be a Pydantic model representing the graph structure, + with a 'vertices' attribute containing all graph vertices. + + Returns: + type[BaseModel]: A dynamically created Pydantic model class representing + the state of the entire graph. This model will have properties + corresponding to each vertex in the graph, with names converted from + the vertex IDs to snake case. + + Raises: + ValueError: If any vertex in the graph does not have a properly initialized + component instance (i.e., if vertex._custom_component is None). + + Notes: + - Each vertex in the graph must have a '_custom_component' attribute. + - The '_custom_component' must have a 'get_state_model_instance_getter' method. + - Vertex IDs are converted from camel case to snake case for the resulting model's field names. + - The resulting model uses the 'create_state_model' function with validation disabled. + + Example: + >>> class Vertex(BaseModel): + ... id: str + ... _custom_component: Any + >>> class Graph(BaseModel): + ... vertices: List[Vertex] + >>> # Assume proper setup of vertices and components + >>> graph = Graph(vertices=[...]) + >>> GraphStateModel = create_state_model_from_graph(graph) + >>> graph_state = GraphStateModel() + >>> # Access component states, e.g.: + >>> print(graph_state.some_component_name) + """ + for vertex in graph.vertices: + if hasattr(vertex, "_custom_component") and vertex._custom_component is None: + raise ValueError(f"Vertex {vertex.id} does not have a component instance.") + + state_model_getters = [ + vertex._custom_component.get_state_model_instance_getter() + for vertex in graph.vertices + if hasattr(vertex, "_custom_component") and hasattr(vertex._custom_component, "get_state_model_instance_getter") + ] + fields = { + camel_to_snake(vertex.id): state_model_getter + for vertex, state_model_getter in zip(graph.vertices, state_model_getters) + } + return create_state_model(model_name="GraphStateModel", validate=False, **fields) diff --git a/src/backend/base/langflow/graph/state/__init__.py b/src/backend/base/langflow/graph/state/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/base/langflow/graph/state/model.py b/src/backend/base/langflow/graph/state/model.py new file mode 100644 index 000000000..a3bac48a8 --- /dev/null +++ b/src/backend/base/langflow/graph/state/model.py @@ -0,0 +1,237 @@ +from typing import Any, Callable, get_type_hints + +from pydantic import ConfigDict, computed_field, create_model +from pydantic.fields import FieldInfo + + +def __validate_method(method: Callable) -> None: + """ + Validates a method by checking if it has the required attributes. + + This function ensures that the given method belongs to a class with the necessary + structure for output handling. It checks for the presence of a __self__ attribute + on the method and a get_output_by_method attribute on the method's class. + + Args: + method (Callable): The method to be validated. + + Raises: + ValueError: If the method does not have a __self__ attribute or if the method's + class does not have a get_output_by_method attribute. + + Example: + >>> class ValidClass: + ... def get_output_by_method(self): + ... pass + ... def valid_method(self): + ... pass + >>> __validate_method(ValidClass().valid_method) # This will pass + >>> __validate_method(lambda x: x) # This will raise a ValueError + """ + if not hasattr(method, "__self__"): + raise ValueError(f"Method {method} does not have a __self__ attribute.") + if not hasattr(method.__self__, "get_output_by_method"): + raise ValueError(f"Method's class {method.__self__} must have a get_output_by_method attribute.") + + +def build_output_getter(method: Callable, validate: bool = True) -> Callable: + """ + Builds an output getter function for a given method in a graph component. + + This function creates a new callable that, when invoked, retrieves the output + of the specified method using the get_output_by_method of the method's class. + It's used in creating dynamic state models for graph components. + + Args: + method (Callable): The method for which to build the output getter. + validate (bool, optional): Whether to validate the method before building + the getter. Defaults to True. + + Returns: + Callable: The output getter function. When called, this function returns + the value of the output associated with the original method. + + Raises: + ValueError: If the method has no return type annotation or if validation fails. + + Notes: + - The getter function returns UNDEFINED if the output has not been set. + - When validate is True, the method must belong to a class with a + 'get_output_by_method' attribute. + - This function is typically used internally by create_state_model. + + Example: + >>> class ChatComponent: + ... def get_output_by_method(self, method): + ... return type('Output', (), {'value': "Hello, World!"})() + ... def get_message(self) -> str: + ... pass + >>> component = ChatComponent() + >>> getter = build_output_getter(component.get_message) + >>> print(getter(None)) # This will print "Hello, World!" + """ + + def output_getter(_): + if validate: + __validate_method(method) + methods_class = method.__self__ + output = methods_class.get_output_by_method(method) + return output.value + + return_type = get_type_hints(method).get("return", None) + + if return_type is None: + raise ValueError(f"Method {method.__name__} has no return type annotation.") + output_getter.__annotations__["return"] = return_type + return output_getter + + +def build_output_setter(method: Callable, validate: bool = True) -> Callable: + """ + Build an output setter function for a given method in a graph component. + + This function creates a new callable that, when invoked, sets the output + of the specified method using the get_output_by_method of the method's class. + It's used in creating dynamic state models for graph components, allowing + for the modification of component states. + + Args: + method (Callable): The method for which the output setter is being built. + validate (bool, optional): Flag indicating whether to validate the method + before building the setter. Defaults to True. + + Returns: + Callable: The output setter function. When called with a value, this function + sets the output associated with the original method to that value. + + Raises: + ValueError: If validation fails when validate is True. + + Notes: + - When validate is True, the method must belong to a class with a + 'get_output_by_method' attribute. + - This function is typically used internally by create_state_model. + - The setter allows for dynamic updating of component states in a graph. + + Example: + >>> class ChatComponent: + ... def get_output_by_method(self, method): + ... return type('Output', (), {'value': None})() + ... def set_message(self): + ... pass + >>> component = ChatComponent() + >>> setter = build_output_setter(component.set_message) + >>> setter(component, "New message") + >>> print(component.get_output_by_method(component.set_message).value) # Prints "New message" + """ + + def output_setter(self, value): + if validate: + __validate_method(method) + methods_class = method.__self__ + output = methods_class.get_output_by_method(method) + output.value = value + + return output_setter + + +def create_state_model(model_name: str = "State", validate: bool = True, **kwargs) -> type: + """ + Create a dynamic Pydantic state model based on the provided keyword arguments. + + This function generates a Pydantic model class with fields corresponding to the + provided keyword arguments. It can handle various types of field definitions, + including callable methods (which are converted to properties), FieldInfo objects, + and type-default value tuples. + + Args: + model_name (str, optional): The name of the model. Defaults to "State". + validate (bool, optional): Whether to validate the methods when converting + them to properties. Defaults to True. + **kwargs: Keyword arguments representing the fields of the model. Each argument + can be a callable method, a FieldInfo object, or a tuple of (type, default). + + Returns: + type: The dynamically created Pydantic state model class. + + Raises: + ValueError: If the provided field value is invalid or cannot be processed. + + Examples: + >>> from langflow.components.inputs import ChatInput + >>> from langflow.components.outputs.ChatOutput import ChatOutput + >>> from pydantic import Field + >>> + >>> chat_input = ChatInput() + >>> chat_output = ChatOutput() + >>> + >>> # Create a model with a method from a component + >>> StateModel = create_state_model(method_one=chat_input.message_response) + >>> state = StateModel() + >>> assert state.method_one is UNDEFINED + >>> chat_input.set_output_value("message", "test") + >>> assert state.method_one == "test" + >>> + >>> # Create a model with multiple components and a Pydantic Field + >>> NewStateModel = create_state_model( + ... model_name="NewStateModel", + ... first_method=chat_input.message_response, + ... second_method=chat_output.message_response, + ... my_attribute=Field(None) + ... ) + >>> new_state = NewStateModel() + >>> new_state.first_method = "test" + >>> new_state.my_attribute = 123 + >>> assert new_state.first_method == "test" + >>> assert new_state.my_attribute == 123 + >>> + >>> # Create a model with tuple-based field definitions + >>> TupleStateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123)) + >>> tuple_state = TupleStateModel() + >>> assert tuple_state.field_one == "default" + >>> assert tuple_state.field_two == 123 + + Notes: + - The function handles empty keyword arguments gracefully. + - For tuple-based field definitions, the first element must be a valid Python type. + - Unsupported value types in keyword arguments will raise a ValueError. + - Callable methods must have proper return type annotations and belong to a class + with a 'get_output_by_method' attribute when validate is True. + """ + fields = {} + + for name, value in kwargs.items(): + # Extract the return type from the method's type annotations + if callable(value): + # Define the field with the return type + try: + __validate_method(value) + getter = build_output_getter(value, validate) + setter = build_output_setter(value, validate) + property_method = property(getter, setter) + except ValueError as e: + # If the method is not valid,assume it is already a getter + if "get_output_by_method" not in str(e) and "__self__" not in str(e) or validate: + raise e + property_method = value + fields[name] = computed_field(property_method) + elif isinstance(value, FieldInfo): + field_tuple = (value.annotation or Any, value) + fields[name] = field_tuple + elif isinstance(value, tuple) and len(value) == 2: + # Fields are defined by one of the following tuple forms: + + # (, ) + # (, Field(...)) + # typing.Annotated[, Field(...)] + if not isinstance(value[0], type): + raise ValueError(f"Invalid type for field {name}: {type(value[0])}") + fields[name] = (value[0], value[1]) + else: + raise ValueError(f"Invalid value type {type(value)} for field {name}") + + # Create the model dynamically + config_dict = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) + model = create_model(model_name, __config__=config_dict, **fields) + + return model diff --git a/src/backend/tests/unit/graph/graph/state/test_state_model.py b/src/backend/tests/unit/graph/graph/state/test_state_model.py new file mode 100644 index 000000000..c7b78842a --- /dev/null +++ b/src/backend/tests/unit/graph/graph/state/test_state_model.py @@ -0,0 +1,139 @@ +import pytest +from pydantic import Field + +from langflow.components.inputs import ChatInput +from langflow.components.outputs.ChatOutput import ChatOutput +from langflow.graph.graph.base import Graph +from langflow.graph.graph.constants import Finish +from langflow.graph.state.model import create_state_model +from langflow.template.field.base import UNDEFINED + + +@pytest.fixture +def client(): + pass + + +@pytest.fixture +def chat_input_component(): + return ChatInput() + + +@pytest.fixture +def chat_output_component(): + return ChatOutput() + + +class TestCreateStateModel: + # Successfully create a model with valid method return type annotations + + def test_create_model_with_valid_return_type_annotations(self, chat_input_component): + StateModel = create_state_model(method_one=chat_input_component.message_response) + + state_instance = StateModel() + assert state_instance.method_one is UNDEFINED + chat_input_component.set_output_value("message", "test") + assert state_instance.method_one == "test" + + def test_create_model_and_assign_values_fails(self, chat_input_component): + StateModel = create_state_model(method_one=chat_input_component.message_response) + + state_instance = StateModel() + state_instance.method_one = "test" + assert state_instance.method_one == "test" + + def test_create_with_multiple_components(self, chat_input_component, chat_output_component): + NewStateModel = create_state_model( + model_name="NewStateModel", + first_method=chat_input_component.message_response, + second_method=chat_output_component.message_response, + ) + state_instance = NewStateModel() + assert state_instance.first_method is UNDEFINED + assert state_instance.second_method is UNDEFINED + state_instance.first_method = "test" + state_instance.second_method = 123 + assert state_instance.first_method == "test" + assert state_instance.second_method == 123 + + def test_create_with_pydantic_field(self, chat_input_component): + StateModel = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None)) + + state_instance = StateModel() + state_instance.method_one = "test" + state_instance.my_attribute = "test" + assert state_instance.method_one == "test" + assert state_instance.my_attribute == "test" + # my_attribute should be of type Any + state_instance.my_attribute = 123 + assert state_instance.my_attribute == 123 + + # Creates a model with fields based on provided keyword arguments + def test_create_model_with_fields_from_kwargs(self): + StateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123)) + state_instance = StateModel() + assert state_instance.field_one == "default" + assert state_instance.field_two == 123 + + # Raises ValueError for invalid field type in tuple-based definitions + def test_raise_valueerror_for_invalid_field_type_in_tuple(self): + with pytest.raises(ValueError, match="Invalid type for field invalid_field"): + create_state_model(invalid_field=("not_a_type", "default")) + + # Raises ValueError for unsupported value types in keyword arguments + def test_raise_valueerror_for_unsupported_value_types(self): + with pytest.raises(ValueError, match="Invalid value type for field invalid_field"): + create_state_model(invalid_field=123) + + # Handles empty keyword arguments gracefully + def test_handle_empty_kwargs_gracefully(self): + StateModel = create_state_model() + state_instance = StateModel() + assert state_instance is not None + + # Ensures model name defaults to "State" if not provided + def test_default_model_name_to_state(self): + StateModel = create_state_model() + assert StateModel.__name__ == "State" + OtherNameModel = create_state_model(model_name="OtherName") + assert OtherNameModel.__name__ == "OtherName" + + # Validates that callable values are properly type-annotated + + def test_create_model_with_invalid_callable(self): + class MockComponent: + def method_one(self) -> str: + return "test" + + def method_two(self) -> int: + return 123 + + mock_component = MockComponent() + with pytest.raises(ValueError, match="get_output_by_method"): + create_state_model(method_one=mock_component.method_one, method_two=mock_component.method_two) + + def test_graph_functional_start_state_update(self): + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + ChatStateModel = create_state_model(model_name="ChatState", message=chat_output.message_response) + chat_state_model = ChatStateModel() + assert chat_state_model.__class__.__name__ == "ChatState" + assert chat_state_model.message is UNDEFINED + + graph = Graph(chat_input, chat_output) + graph.prepare() + # Now iterate through the graph + # and check that the graph is running + # correctly + ids = ["chat_input", "chat_output"] + results = [] + for result in graph.start(): + results.append(result) + + assert len(results) == 3 + assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) + assert results[-1] == Finish() + + assert chat_state_model.__class__.__name__ == "ChatState" + assert chat_state_model.message.get_text() == "test" diff --git a/src/backend/tests/unit/graph/graph/test_graph_state_model.py b/src/backend/tests/unit/graph/graph/test_graph_state_model.py new file mode 100644 index 000000000..e7555596b --- /dev/null +++ b/src/backend/tests/unit/graph/graph/test_graph_state_model.py @@ -0,0 +1,173 @@ +import pytest +from pydantic import BaseModel + +from langflow.components.helpers.Memory import MemoryComponent +from langflow.components.inputs.ChatInput import ChatInput +from langflow.components.models.OpenAIModel import OpenAIModelComponent +from langflow.components.outputs.ChatOutput import ChatOutput +from langflow.components.prompts.Prompt import PromptComponent +from langflow.graph import Graph +from langflow.graph.graph.constants import Finish +from langflow.graph.graph.state_model import create_state_model_from_graph + + +@pytest.fixture +def client(): + pass + + +def test_graph_state_model(): + session_id = "test_session_id" + template = """{context} + +User: {user_message} +AI: """ + memory_component = MemoryComponent(_id="chat_memory") + memory_component.set(session_id=session_id) + chat_input = ChatInput(_id="chat_input") + prompt_component = PromptComponent(_id="prompt") + prompt_component.set( + template=template, user_message=chat_input.message_response, context=memory_component.retrieve_messages_as_text + ) + openai_component = OpenAIModelComponent(_id="openai") + openai_component.set( + input_value=prompt_component.build_prompt, max_tokens=100, temperature=0.1, api_key="test_api_key" + ) + openai_component.get_output("text_output").value = "Mock response" + + chat_output = ChatOutput(_id="chat_output") + chat_output.set(input_value=openai_component.text_response) + + graph = Graph(chat_input, chat_output) + + GraphStateModel = create_state_model_from_graph(graph) + assert GraphStateModel.__name__ == "GraphStateModel" + assert list(GraphStateModel.model_computed_fields.keys()) == [ + "chat_input", + "chat_output", + "openai", + "prompt", + "chat_memory", + ] + + +def test_graph_functional_start_graph_state_update(): + chat_input = ChatInput(_id="chat_input") + chat_input.set(input_value="Test Sender Name") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + + graph = Graph(chat_input, chat_output) + graph.prepare() + # Now iterate through the graph + # and check that the graph is running + # correctly + GraphStateModel = create_state_model_from_graph(graph) + graph_state_model = GraphStateModel() + ids = ["chat_input", "chat_output"] + results = [] + for result in graph.start(): + results.append(result) + + assert len(results) == 3 + assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) + assert results[-1] == Finish() + + assert graph_state_model.__class__.__name__ == "GraphStateModel" + assert graph_state_model.chat_input.message.get_text() == "Test Sender Name" + assert graph_state_model.chat_output.message.get_text() == "test" + + +def test_graph_state_model_serialization(): + chat_input = ChatInput(_id="chat_input") + chat_input.set(input_value="Test Sender Name") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + + graph = Graph(chat_input, chat_output) + graph.prepare() + # Now iterate through the graph + # and check that the graph is running + # correctly + GraphStateModel = create_state_model_from_graph(graph) + graph_state_model = GraphStateModel() + ids = ["chat_input", "chat_output"] + results = [] + for result in graph.start(): + results.append(result) + + assert len(results) == 3 + assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) + assert results[-1] == Finish() + + assert graph_state_model.__class__.__name__ == "GraphStateModel" + assert graph_state_model.chat_input.message.get_text() == "Test Sender Name" + assert graph_state_model.chat_output.message.get_text() == "test" + + serialized_state_model = graph_state_model.model_dump() + assert serialized_state_model["chat_input"]["message"]["text"] == "Test Sender Name" + + +def test_graph_state_model_json_schema(): + chat_input = ChatInput(_id="chat_input") + chat_input.set(input_value="Test Sender Name") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + + graph = Graph(chat_input, chat_output) + graph.prepare() + + GraphStateModel = create_state_model_from_graph(graph) + graph_state_model: BaseModel = GraphStateModel() + json_schema = graph_state_model.model_json_schema(mode="serialization") + + # Test main schema structure + assert json_schema["title"] == "GraphStateModel" + assert json_schema["type"] == "object" + assert set(json_schema["required"]) == {"chat_input", "chat_output"} + + # Test chat_input and chat_output properties + for prop in ["chat_input", "chat_output"]: + assert prop in json_schema["properties"] + assert json_schema["properties"][prop]["allOf"][0]["$ref"].startswith("#/$defs/") + assert json_schema["properties"][prop]["readOnly"] is True + + # Test $defs + assert set(json_schema["$defs"].keys()) == {"ChatInputStateModel", "ChatOutputStateModel", "Image", "Message"} + + # Test ChatInputStateModel and ChatOutputStateModel + for model in ["ChatInputStateModel", "ChatOutputStateModel"]: + assert json_schema["$defs"][model]["type"] == "object" + assert json_schema["$defs"][model]["title"] == model + assert "message" in json_schema["$defs"][model]["properties"] + assert json_schema["$defs"][model]["properties"]["message"]["allOf"][0]["$ref"] == "#/$defs/Message" + assert json_schema["$defs"][model]["properties"]["message"]["readOnly"] is True + assert json_schema["$defs"][model]["required"] == ["message"] + + # Test Message model + message_props = json_schema["$defs"]["Message"]["properties"] + assert set(message_props.keys()) == { + "text_key", + "data", + "default_value", + "text", + "sender", + "sender_name", + "files", + "session_id", + "timestamp", + "flow_id", + } + assert message_props["text_key"]["type"] == "string" + assert message_props["data"]["type"] == "object" + assert "anyOf" in message_props["default_value"] + assert "anyOf" in message_props["files"] + assert message_props["timestamp"]["type"] == "string" + + # Test Image model + image_props = json_schema["$defs"]["Image"]["properties"] + assert set(image_props.keys()) == {"path", "url"} + for prop in ["path", "url"]: + assert "anyOf" in image_props[prop] + assert {"type": "string"} in image_props[prop]["anyOf"] + assert {"type": "null"} in image_props[prop]["anyOf"]