feat: add dynamic state model creation and update (#3271)
* feat: add initial implementation of dynamic state model creation and output getter in graph state module * feat: implement _reset_all_output_values method to initialize component outputs in custom_component class * feat: add state model management with lazy initialization and dynamic instance getter in custom_component class * feat: Refactor Component class to use public method get_output_by_method Refactor the Component class in the custom_component module to change the visibility of the method `_get_output_by_method` to public by renaming it to `get_output_by_method`. This change improves the accessibility and clarity of the method for external use. * feat: add output setter utility to manage output values in state model properties * feat: implement validation for methods' classes in output getter/setter utilities in state model to ensure proper structure * feat: add state model creation from graph in state_model.py * feat: enhance Graph class with lazy loading for state model creation from graph * feat: add unit tests for state model creation and validation in test_state_model.py * feat: add unit tests for state model creation and validation in test_state_model.py * feat: add functional test for graph state update and validation in test_graph_state_model.py * fix: update _instance_getter function to accept a parameter in component.py for state model instance retrieval * refactor: rename test to clarify purpose in test_state_model.py for functional state update validation * chore: import Finish constant in test_graph_state_model.py for improved clarity and usage in state model tests * refactor: add optional validation in output getter/setter methods for improved method integrity in state model handling * refactor: enhance state model creation with optional validation and error handling for output methods in model.py * refactor: serialize and deserialize GraphStateModel in test_graph_state_model.py * refactor: improve error message and add verbose mode for graph start in test_state_model.py * refactor: remove verbose flag from graph.start in TestCreateStateModel for consistency in test_state_model.py * refactor: disable validation when creating GraphStateModel in state_model.py for improved flexibility * refactor: add validation documentation for method attributes in model.py to enhance code clarity and usability * refactor: expand docstring for build_output_getter in model.py to clarify usage and validation details * refactor: add detailed docstring for build_output_setter in model.py to improve clarity on functionality and usage scenarios * refactor: add comprehensive docstring for create_state_model in model.py to clarify functionality and usage examples * refactor: enhance docstring for create_state_model_from_graph in state_model.py to clarify functionality and provide examples * test: add JSON schema validation in graph state model tests for improved structure and correctness verification * refactor: Improve graph_state_model.json_schema unit test readability and structure.
This commit is contained in:
parent
2ffd723065
commit
c5d9cbae49
7 changed files with 654 additions and 4 deletions
|
|
@ -7,6 +7,7 @@ import nanoid # type: ignore
|
|||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.helpers.custom import format_type
|
||||
from langflow.schema.artifact import get_artifact_type, post_process_raw
|
||||
from langflow.schema.data import Data
|
||||
|
|
@ -35,7 +36,7 @@ class Component(CustomComponent):
|
|||
def __init__(self, **kwargs):
|
||||
# if key starts with _ it is a config
|
||||
# else it is an input
|
||||
|
||||
self._reset_all_output_values()
|
||||
inputs = {}
|
||||
config = {}
|
||||
for key, value in kwargs.items():
|
||||
|
|
@ -50,6 +51,7 @@ class Component(CustomComponent):
|
|||
self._parameters = inputs or {}
|
||||
self._edges: list[EdgeData] = []
|
||||
self._components: list[Component] = []
|
||||
self._state_model = None
|
||||
self.set_attributes(self._parameters)
|
||||
self._output_logs = {}
|
||||
config = config or {}
|
||||
|
|
@ -70,6 +72,30 @@ class Component(CustomComponent):
|
|||
self._set_output_types()
|
||||
self.set_class_code()
|
||||
|
||||
def _reset_all_output_values(self):
|
||||
for output in self.outputs:
|
||||
setattr(output, "value", UNDEFINED)
|
||||
|
||||
def _build_state_model(self):
|
||||
if self._state_model:
|
||||
return self._state_model
|
||||
name = self.name or self.__class__.__name__
|
||||
model_name = f"{name}StateModel"
|
||||
fields = {}
|
||||
for output in self.outputs:
|
||||
fields[output.name] = getattr(self, output.method)
|
||||
self._state_model = create_state_model(model_name=model_name, **fields)
|
||||
return self._state_model
|
||||
|
||||
def get_state_model_instance_getter(self):
|
||||
state_model = self._build_state_model()
|
||||
|
||||
def _instance_getter(_):
|
||||
return state_model()
|
||||
|
||||
_instance_getter.__annotations__["return"] = state_model
|
||||
return _instance_getter
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
if id(self) in memo:
|
||||
return memo[id(self)]
|
||||
|
|
@ -247,7 +273,7 @@ class Component(CustomComponent):
|
|||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
|
||||
def _get_output_by_method(self, method: Callable):
|
||||
def get_output_by_method(self, method: Callable):
|
||||
# method is a callable and output.method is a string
|
||||
# we need to find the output that has the same method
|
||||
output = next((output for output in self.outputs if output.method == method.__name__), None)
|
||||
|
|
@ -268,7 +294,7 @@ class Component(CustomComponent):
|
|||
method_is_output = (
|
||||
hasattr(method, "__self__")
|
||||
and isinstance(method.__self__, Component)
|
||||
and method.__self__._get_output_by_method(method)
|
||||
and method.__self__.get_output_by_method(method)
|
||||
)
|
||||
return method_is_output
|
||||
|
||||
|
|
@ -298,7 +324,7 @@ class Component(CustomComponent):
|
|||
def _connect_to_component(self, key, value, _input):
|
||||
component = value.__self__
|
||||
self._components.append(component)
|
||||
output = component._get_output_by_method(value)
|
||||
output = component.get_output_by_method(value)
|
||||
self._add_edge(component, key, output, _input)
|
||||
|
||||
def _add_edge(self, component, key, output, _input):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
|
|||
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
|
||||
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
|
||||
from langflow.graph.graph.state_manager import GraphStateManager
|
||||
from langflow.graph.graph.state_model import create_state_model_from_graph
|
||||
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
|
||||
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
|
||||
from langflow.graph.vertex.base import Vertex, VertexStates
|
||||
|
|
@ -62,6 +63,7 @@ class Graph:
|
|||
log_config = {"disable": False}
|
||||
configure(**log_config)
|
||||
self._start = start
|
||||
self._state_model = None
|
||||
self._end = end
|
||||
self._prepared = False
|
||||
self._runs = 0
|
||||
|
|
@ -109,6 +111,12 @@ class Graph:
|
|||
if (start is not None and end is None) or (start is None and end is not None):
|
||||
raise ValueError("You must provide both input and output components")
|
||||
|
||||
@property
|
||||
def state_model(self):
|
||||
if not self._state_model:
|
||||
self._state_model = create_state_model_from_graph(self)
|
||||
return self._state_model
|
||||
|
||||
def __add__(self, other):
|
||||
if not isinstance(other, Graph):
|
||||
raise TypeError("Can only add Graph objects")
|
||||
|
|
|
|||
67
src/backend/base/langflow/graph/graph/state_model.py
Normal file
67
src/backend/base/langflow/graph/graph/state_model.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
import re
|
||||
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.helpers.base_model import BaseModel
|
||||
|
||||
|
||||
def camel_to_snake(camel_str: str) -> str:
|
||||
snake_str = re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
|
||||
return snake_str
|
||||
|
||||
|
||||
def create_state_model_from_graph(graph: BaseModel) -> type[BaseModel]:
|
||||
"""
|
||||
Create a Pydantic state model from a graph representation.
|
||||
|
||||
This function generates a Pydantic model that represents the state of an entire graph.
|
||||
It creates getter methods for each vertex in the graph, allowing access to the state
|
||||
of individual components within the graph structure.
|
||||
|
||||
Args:
|
||||
graph (BaseModel): The graph object from which to create the state model.
|
||||
This should be a Pydantic model representing the graph structure,
|
||||
with a 'vertices' attribute containing all graph vertices.
|
||||
|
||||
Returns:
|
||||
type[BaseModel]: A dynamically created Pydantic model class representing
|
||||
the state of the entire graph. This model will have properties
|
||||
corresponding to each vertex in the graph, with names converted from
|
||||
the vertex IDs to snake case.
|
||||
|
||||
Raises:
|
||||
ValueError: If any vertex in the graph does not have a properly initialized
|
||||
component instance (i.e., if vertex._custom_component is None).
|
||||
|
||||
Notes:
|
||||
- Each vertex in the graph must have a '_custom_component' attribute.
|
||||
- The '_custom_component' must have a 'get_state_model_instance_getter' method.
|
||||
- Vertex IDs are converted from camel case to snake case for the resulting model's field names.
|
||||
- The resulting model uses the 'create_state_model' function with validation disabled.
|
||||
|
||||
Example:
|
||||
>>> class Vertex(BaseModel):
|
||||
... id: str
|
||||
... _custom_component: Any
|
||||
>>> class Graph(BaseModel):
|
||||
... vertices: List[Vertex]
|
||||
>>> # Assume proper setup of vertices and components
|
||||
>>> graph = Graph(vertices=[...])
|
||||
>>> GraphStateModel = create_state_model_from_graph(graph)
|
||||
>>> graph_state = GraphStateModel()
|
||||
>>> # Access component states, e.g.:
|
||||
>>> print(graph_state.some_component_name)
|
||||
"""
|
||||
for vertex in graph.vertices:
|
||||
if hasattr(vertex, "_custom_component") and vertex._custom_component is None:
|
||||
raise ValueError(f"Vertex {vertex.id} does not have a component instance.")
|
||||
|
||||
state_model_getters = [
|
||||
vertex._custom_component.get_state_model_instance_getter()
|
||||
for vertex in graph.vertices
|
||||
if hasattr(vertex, "_custom_component") and hasattr(vertex._custom_component, "get_state_model_instance_getter")
|
||||
]
|
||||
fields = {
|
||||
camel_to_snake(vertex.id): state_model_getter
|
||||
for vertex, state_model_getter in zip(graph.vertices, state_model_getters)
|
||||
}
|
||||
return create_state_model(model_name="GraphStateModel", validate=False, **fields)
|
||||
0
src/backend/base/langflow/graph/state/__init__.py
Normal file
0
src/backend/base/langflow/graph/state/__init__.py
Normal file
237
src/backend/base/langflow/graph/state/model.py
Normal file
237
src/backend/base/langflow/graph/state/model.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
from typing import Any, Callable, get_type_hints
|
||||
|
||||
from pydantic import ConfigDict, computed_field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
def __validate_method(method: Callable) -> None:
|
||||
"""
|
||||
Validates a method by checking if it has the required attributes.
|
||||
|
||||
This function ensures that the given method belongs to a class with the necessary
|
||||
structure for output handling. It checks for the presence of a __self__ attribute
|
||||
on the method and a get_output_by_method attribute on the method's class.
|
||||
|
||||
Args:
|
||||
method (Callable): The method to be validated.
|
||||
|
||||
Raises:
|
||||
ValueError: If the method does not have a __self__ attribute or if the method's
|
||||
class does not have a get_output_by_method attribute.
|
||||
|
||||
Example:
|
||||
>>> class ValidClass:
|
||||
... def get_output_by_method(self):
|
||||
... pass
|
||||
... def valid_method(self):
|
||||
... pass
|
||||
>>> __validate_method(ValidClass().valid_method) # This will pass
|
||||
>>> __validate_method(lambda x: x) # This will raise a ValueError
|
||||
"""
|
||||
if not hasattr(method, "__self__"):
|
||||
raise ValueError(f"Method {method} does not have a __self__ attribute.")
|
||||
if not hasattr(method.__self__, "get_output_by_method"):
|
||||
raise ValueError(f"Method's class {method.__self__} must have a get_output_by_method attribute.")
|
||||
|
||||
|
||||
def build_output_getter(method: Callable, validate: bool = True) -> Callable:
|
||||
"""
|
||||
Builds an output getter function for a given method in a graph component.
|
||||
|
||||
This function creates a new callable that, when invoked, retrieves the output
|
||||
of the specified method using the get_output_by_method of the method's class.
|
||||
It's used in creating dynamic state models for graph components.
|
||||
|
||||
Args:
|
||||
method (Callable): The method for which to build the output getter.
|
||||
validate (bool, optional): Whether to validate the method before building
|
||||
the getter. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Callable: The output getter function. When called, this function returns
|
||||
the value of the output associated with the original method.
|
||||
|
||||
Raises:
|
||||
ValueError: If the method has no return type annotation or if validation fails.
|
||||
|
||||
Notes:
|
||||
- The getter function returns UNDEFINED if the output has not been set.
|
||||
- When validate is True, the method must belong to a class with a
|
||||
'get_output_by_method' attribute.
|
||||
- This function is typically used internally by create_state_model.
|
||||
|
||||
Example:
|
||||
>>> class ChatComponent:
|
||||
... def get_output_by_method(self, method):
|
||||
... return type('Output', (), {'value': "Hello, World!"})()
|
||||
... def get_message(self) -> str:
|
||||
... pass
|
||||
>>> component = ChatComponent()
|
||||
>>> getter = build_output_getter(component.get_message)
|
||||
>>> print(getter(None)) # This will print "Hello, World!"
|
||||
"""
|
||||
|
||||
def output_getter(_):
|
||||
if validate:
|
||||
__validate_method(method)
|
||||
methods_class = method.__self__
|
||||
output = methods_class.get_output_by_method(method)
|
||||
return output.value
|
||||
|
||||
return_type = get_type_hints(method).get("return", None)
|
||||
|
||||
if return_type is None:
|
||||
raise ValueError(f"Method {method.__name__} has no return type annotation.")
|
||||
output_getter.__annotations__["return"] = return_type
|
||||
return output_getter
|
||||
|
||||
|
||||
def build_output_setter(method: Callable, validate: bool = True) -> Callable:
|
||||
"""
|
||||
Build an output setter function for a given method in a graph component.
|
||||
|
||||
This function creates a new callable that, when invoked, sets the output
|
||||
of the specified method using the get_output_by_method of the method's class.
|
||||
It's used in creating dynamic state models for graph components, allowing
|
||||
for the modification of component states.
|
||||
|
||||
Args:
|
||||
method (Callable): The method for which the output setter is being built.
|
||||
validate (bool, optional): Flag indicating whether to validate the method
|
||||
before building the setter. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Callable: The output setter function. When called with a value, this function
|
||||
sets the output associated with the original method to that value.
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails when validate is True.
|
||||
|
||||
Notes:
|
||||
- When validate is True, the method must belong to a class with a
|
||||
'get_output_by_method' attribute.
|
||||
- This function is typically used internally by create_state_model.
|
||||
- The setter allows for dynamic updating of component states in a graph.
|
||||
|
||||
Example:
|
||||
>>> class ChatComponent:
|
||||
... def get_output_by_method(self, method):
|
||||
... return type('Output', (), {'value': None})()
|
||||
... def set_message(self):
|
||||
... pass
|
||||
>>> component = ChatComponent()
|
||||
>>> setter = build_output_setter(component.set_message)
|
||||
>>> setter(component, "New message")
|
||||
>>> print(component.get_output_by_method(component.set_message).value) # Prints "New message"
|
||||
"""
|
||||
|
||||
def output_setter(self, value):
|
||||
if validate:
|
||||
__validate_method(method)
|
||||
methods_class = method.__self__
|
||||
output = methods_class.get_output_by_method(method)
|
||||
output.value = value
|
||||
|
||||
return output_setter
|
||||
|
||||
|
||||
def create_state_model(model_name: str = "State", validate: bool = True, **kwargs) -> type:
|
||||
"""
|
||||
Create a dynamic Pydantic state model based on the provided keyword arguments.
|
||||
|
||||
This function generates a Pydantic model class with fields corresponding to the
|
||||
provided keyword arguments. It can handle various types of field definitions,
|
||||
including callable methods (which are converted to properties), FieldInfo objects,
|
||||
and type-default value tuples.
|
||||
|
||||
Args:
|
||||
model_name (str, optional): The name of the model. Defaults to "State".
|
||||
validate (bool, optional): Whether to validate the methods when converting
|
||||
them to properties. Defaults to True.
|
||||
**kwargs: Keyword arguments representing the fields of the model. Each argument
|
||||
can be a callable method, a FieldInfo object, or a tuple of (type, default).
|
||||
|
||||
Returns:
|
||||
type: The dynamically created Pydantic state model class.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided field value is invalid or cannot be processed.
|
||||
|
||||
Examples:
|
||||
>>> from langflow.components.inputs import ChatInput
|
||||
>>> from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
>>> from pydantic import Field
|
||||
>>>
|
||||
>>> chat_input = ChatInput()
|
||||
>>> chat_output = ChatOutput()
|
||||
>>>
|
||||
>>> # Create a model with a method from a component
|
||||
>>> StateModel = create_state_model(method_one=chat_input.message_response)
|
||||
>>> state = StateModel()
|
||||
>>> assert state.method_one is UNDEFINED
|
||||
>>> chat_input.set_output_value("message", "test")
|
||||
>>> assert state.method_one == "test"
|
||||
>>>
|
||||
>>> # Create a model with multiple components and a Pydantic Field
|
||||
>>> NewStateModel = create_state_model(
|
||||
... model_name="NewStateModel",
|
||||
... first_method=chat_input.message_response,
|
||||
... second_method=chat_output.message_response,
|
||||
... my_attribute=Field(None)
|
||||
... )
|
||||
>>> new_state = NewStateModel()
|
||||
>>> new_state.first_method = "test"
|
||||
>>> new_state.my_attribute = 123
|
||||
>>> assert new_state.first_method == "test"
|
||||
>>> assert new_state.my_attribute == 123
|
||||
>>>
|
||||
>>> # Create a model with tuple-based field definitions
|
||||
>>> TupleStateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123))
|
||||
>>> tuple_state = TupleStateModel()
|
||||
>>> assert tuple_state.field_one == "default"
|
||||
>>> assert tuple_state.field_two == 123
|
||||
|
||||
Notes:
|
||||
- The function handles empty keyword arguments gracefully.
|
||||
- For tuple-based field definitions, the first element must be a valid Python type.
|
||||
- Unsupported value types in keyword arguments will raise a ValueError.
|
||||
- Callable methods must have proper return type annotations and belong to a class
|
||||
with a 'get_output_by_method' attribute when validate is True.
|
||||
"""
|
||||
fields = {}
|
||||
|
||||
for name, value in kwargs.items():
|
||||
# Extract the return type from the method's type annotations
|
||||
if callable(value):
|
||||
# Define the field with the return type
|
||||
try:
|
||||
__validate_method(value)
|
||||
getter = build_output_getter(value, validate)
|
||||
setter = build_output_setter(value, validate)
|
||||
property_method = property(getter, setter)
|
||||
except ValueError as e:
|
||||
# If the method is not valid,assume it is already a getter
|
||||
if "get_output_by_method" not in str(e) and "__self__" not in str(e) or validate:
|
||||
raise e
|
||||
property_method = value
|
||||
fields[name] = computed_field(property_method)
|
||||
elif isinstance(value, FieldInfo):
|
||||
field_tuple = (value.annotation or Any, value)
|
||||
fields[name] = field_tuple
|
||||
elif isinstance(value, tuple) and len(value) == 2:
|
||||
# Fields are defined by one of the following tuple forms:
|
||||
|
||||
# (<type>, <default value>)
|
||||
# (<type>, Field(...))
|
||||
# typing.Annotated[<type>, Field(...)]
|
||||
if not isinstance(value[0], type):
|
||||
raise ValueError(f"Invalid type for field {name}: {type(value[0])}")
|
||||
fields[name] = (value[0], value[1])
|
||||
else:
|
||||
raise ValueError(f"Invalid value type {type(value)} for field {name}")
|
||||
|
||||
# Create the model dynamically
|
||||
config_dict = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)
|
||||
model = create_model(model_name, __config__=config_dict, **fields)
|
||||
|
||||
return model
|
||||
139
src/backend/tests/unit/graph/graph/state/test_state_model.py
Normal file
139
src/backend/tests/unit/graph/graph/state/test_state_model.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_input_component():
|
||||
return ChatInput()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_output_component():
|
||||
return ChatOutput()
|
||||
|
||||
|
||||
class TestCreateStateModel:
|
||||
# Successfully create a model with valid method return type annotations
|
||||
|
||||
def test_create_model_with_valid_return_type_annotations(self, chat_input_component):
|
||||
StateModel = create_state_model(method_one=chat_input_component.message_response)
|
||||
|
||||
state_instance = StateModel()
|
||||
assert state_instance.method_one is UNDEFINED
|
||||
chat_input_component.set_output_value("message", "test")
|
||||
assert state_instance.method_one == "test"
|
||||
|
||||
def test_create_model_and_assign_values_fails(self, chat_input_component):
|
||||
StateModel = create_state_model(method_one=chat_input_component.message_response)
|
||||
|
||||
state_instance = StateModel()
|
||||
state_instance.method_one = "test"
|
||||
assert state_instance.method_one == "test"
|
||||
|
||||
def test_create_with_multiple_components(self, chat_input_component, chat_output_component):
|
||||
NewStateModel = create_state_model(
|
||||
model_name="NewStateModel",
|
||||
first_method=chat_input_component.message_response,
|
||||
second_method=chat_output_component.message_response,
|
||||
)
|
||||
state_instance = NewStateModel()
|
||||
assert state_instance.first_method is UNDEFINED
|
||||
assert state_instance.second_method is UNDEFINED
|
||||
state_instance.first_method = "test"
|
||||
state_instance.second_method = 123
|
||||
assert state_instance.first_method == "test"
|
||||
assert state_instance.second_method == 123
|
||||
|
||||
def test_create_with_pydantic_field(self, chat_input_component):
|
||||
StateModel = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None))
|
||||
|
||||
state_instance = StateModel()
|
||||
state_instance.method_one = "test"
|
||||
state_instance.my_attribute = "test"
|
||||
assert state_instance.method_one == "test"
|
||||
assert state_instance.my_attribute == "test"
|
||||
# my_attribute should be of type Any
|
||||
state_instance.my_attribute = 123
|
||||
assert state_instance.my_attribute == 123
|
||||
|
||||
# Creates a model with fields based on provided keyword arguments
|
||||
def test_create_model_with_fields_from_kwargs(self):
|
||||
StateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123))
|
||||
state_instance = StateModel()
|
||||
assert state_instance.field_one == "default"
|
||||
assert state_instance.field_two == 123
|
||||
|
||||
# Raises ValueError for invalid field type in tuple-based definitions
|
||||
def test_raise_valueerror_for_invalid_field_type_in_tuple(self):
|
||||
with pytest.raises(ValueError, match="Invalid type for field invalid_field"):
|
||||
create_state_model(invalid_field=("not_a_type", "default"))
|
||||
|
||||
# Raises ValueError for unsupported value types in keyword arguments
|
||||
def test_raise_valueerror_for_unsupported_value_types(self):
|
||||
with pytest.raises(ValueError, match="Invalid value type <class 'int'> for field invalid_field"):
|
||||
create_state_model(invalid_field=123)
|
||||
|
||||
# Handles empty keyword arguments gracefully
|
||||
def test_handle_empty_kwargs_gracefully(self):
|
||||
StateModel = create_state_model()
|
||||
state_instance = StateModel()
|
||||
assert state_instance is not None
|
||||
|
||||
# Ensures model name defaults to "State" if not provided
|
||||
def test_default_model_name_to_state(self):
|
||||
StateModel = create_state_model()
|
||||
assert StateModel.__name__ == "State"
|
||||
OtherNameModel = create_state_model(model_name="OtherName")
|
||||
assert OtherNameModel.__name__ == "OtherName"
|
||||
|
||||
# Validates that callable values are properly type-annotated
|
||||
|
||||
def test_create_model_with_invalid_callable(self):
|
||||
class MockComponent:
|
||||
def method_one(self) -> str:
|
||||
return "test"
|
||||
|
||||
def method_two(self) -> int:
|
||||
return 123
|
||||
|
||||
mock_component = MockComponent()
|
||||
with pytest.raises(ValueError, match="get_output_by_method"):
|
||||
create_state_model(method_one=mock_component.method_one, method_two=mock_component.method_two)
|
||||
|
||||
def test_graph_functional_start_state_update(self):
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
chat_output.set(sender_name=chat_input.message_response)
|
||||
ChatStateModel = create_state_model(model_name="ChatState", message=chat_output.message_response)
|
||||
chat_state_model = ChatStateModel()
|
||||
assert chat_state_model.__class__.__name__ == "ChatState"
|
||||
assert chat_state_model.message is UNDEFINED
|
||||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
graph.prepare()
|
||||
# Now iterate through the graph
|
||||
# and check that the graph is running
|
||||
# correctly
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
assert results[-1] == Finish()
|
||||
|
||||
assert chat_state_model.__class__.__name__ == "ChatState"
|
||||
assert chat_state_model.message.get_text() == "test"
|
||||
173
src/backend/tests/unit/graph/graph/test_graph_state_model.py
Normal file
173
src/backend/tests/unit/graph/graph/test_graph_state_model.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.components.helpers.Memory import MemoryComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
from langflow.components.prompts.Prompt import PromptComponent
|
||||
from langflow.graph import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.graph.state_model import create_state_model_from_graph
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
pass
|
||||
|
||||
|
||||
def test_graph_state_model():
|
||||
session_id = "test_session_id"
|
||||
template = """{context}
|
||||
|
||||
User: {user_message}
|
||||
AI: """
|
||||
memory_component = MemoryComponent(_id="chat_memory")
|
||||
memory_component.set(session_id=session_id)
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
prompt_component = PromptComponent(_id="prompt")
|
||||
prompt_component.set(
|
||||
template=template, user_message=chat_input.message_response, context=memory_component.retrieve_messages_as_text
|
||||
)
|
||||
openai_component = OpenAIModelComponent(_id="openai")
|
||||
openai_component.set(
|
||||
input_value=prompt_component.build_prompt, max_tokens=100, temperature=0.1, api_key="test_api_key"
|
||||
)
|
||||
openai_component.get_output("text_output").value = "Mock response"
|
||||
|
||||
chat_output = ChatOutput(_id="chat_output")
|
||||
chat_output.set(input_value=openai_component.text_response)
|
||||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
assert GraphStateModel.__name__ == "GraphStateModel"
|
||||
assert list(GraphStateModel.model_computed_fields.keys()) == [
|
||||
"chat_input",
|
||||
"chat_output",
|
||||
"openai",
|
||||
"prompt",
|
||||
"chat_memory",
|
||||
]
|
||||
|
||||
|
||||
def test_graph_functional_start_graph_state_update():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_input.set(input_value="Test Sender Name")
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
chat_output.set(sender_name=chat_input.message_response)
|
||||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
graph.prepare()
|
||||
# Now iterate through the graph
|
||||
# and check that the graph is running
|
||||
# correctly
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
graph_state_model = GraphStateModel()
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
assert results[-1] == Finish()
|
||||
|
||||
assert graph_state_model.__class__.__name__ == "GraphStateModel"
|
||||
assert graph_state_model.chat_input.message.get_text() == "Test Sender Name"
|
||||
assert graph_state_model.chat_output.message.get_text() == "test"
|
||||
|
||||
|
||||
def test_graph_state_model_serialization():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_input.set(input_value="Test Sender Name")
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
chat_output.set(sender_name=chat_input.message_response)
|
||||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
graph.prepare()
|
||||
# Now iterate through the graph
|
||||
# and check that the graph is running
|
||||
# correctly
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
graph_state_model = GraphStateModel()
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
assert results[-1] == Finish()
|
||||
|
||||
assert graph_state_model.__class__.__name__ == "GraphStateModel"
|
||||
assert graph_state_model.chat_input.message.get_text() == "Test Sender Name"
|
||||
assert graph_state_model.chat_output.message.get_text() == "test"
|
||||
|
||||
serialized_state_model = graph_state_model.model_dump()
|
||||
assert serialized_state_model["chat_input"]["message"]["text"] == "Test Sender Name"
|
||||
|
||||
|
||||
def test_graph_state_model_json_schema():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_input.set(input_value="Test Sender Name")
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
chat_output.set(sender_name=chat_input.message_response)
|
||||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
graph.prepare()
|
||||
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
graph_state_model: BaseModel = GraphStateModel()
|
||||
json_schema = graph_state_model.model_json_schema(mode="serialization")
|
||||
|
||||
# Test main schema structure
|
||||
assert json_schema["title"] == "GraphStateModel"
|
||||
assert json_schema["type"] == "object"
|
||||
assert set(json_schema["required"]) == {"chat_input", "chat_output"}
|
||||
|
||||
# Test chat_input and chat_output properties
|
||||
for prop in ["chat_input", "chat_output"]:
|
||||
assert prop in json_schema["properties"]
|
||||
assert json_schema["properties"][prop]["allOf"][0]["$ref"].startswith("#/$defs/")
|
||||
assert json_schema["properties"][prop]["readOnly"] is True
|
||||
|
||||
# Test $defs
|
||||
assert set(json_schema["$defs"].keys()) == {"ChatInputStateModel", "ChatOutputStateModel", "Image", "Message"}
|
||||
|
||||
# Test ChatInputStateModel and ChatOutputStateModel
|
||||
for model in ["ChatInputStateModel", "ChatOutputStateModel"]:
|
||||
assert json_schema["$defs"][model]["type"] == "object"
|
||||
assert json_schema["$defs"][model]["title"] == model
|
||||
assert "message" in json_schema["$defs"][model]["properties"]
|
||||
assert json_schema["$defs"][model]["properties"]["message"]["allOf"][0]["$ref"] == "#/$defs/Message"
|
||||
assert json_schema["$defs"][model]["properties"]["message"]["readOnly"] is True
|
||||
assert json_schema["$defs"][model]["required"] == ["message"]
|
||||
|
||||
# Test Message model
|
||||
message_props = json_schema["$defs"]["Message"]["properties"]
|
||||
assert set(message_props.keys()) == {
|
||||
"text_key",
|
||||
"data",
|
||||
"default_value",
|
||||
"text",
|
||||
"sender",
|
||||
"sender_name",
|
||||
"files",
|
||||
"session_id",
|
||||
"timestamp",
|
||||
"flow_id",
|
||||
}
|
||||
assert message_props["text_key"]["type"] == "string"
|
||||
assert message_props["data"]["type"] == "object"
|
||||
assert "anyOf" in message_props["default_value"]
|
||||
assert "anyOf" in message_props["files"]
|
||||
assert message_props["timestamp"]["type"] == "string"
|
||||
|
||||
# Test Image model
|
||||
image_props = json_schema["$defs"]["Image"]["properties"]
|
||||
assert set(image_props.keys()) == {"path", "url"}
|
||||
for prop in ["path", "url"]:
|
||||
assert "anyOf" in image_props[prop]
|
||||
assert {"type": "string"} in image_props[prop]["anyOf"]
|
||||
assert {"type": "null"} in image_props[prop]["anyOf"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue