Refactor RoutingVertex to handle missing condition and result values

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-22 23:36:02 -03:00
commit 75ee16acb0

View file

@ -373,35 +373,18 @@ class RoutingVertex(StatelessVertex):
return self.artifacts["repr"] or super()._built_object_repr()
return super()._built_object_repr()
def _build(self, *args, **kwargs):
super()._build(*args, **kwargs)
# After building, the _built_object should be a dict with
# {"result": Any, "condition": bool}
# if true, we need to set should_run attr in the target of true edge
# to true and should_run attr in the target of false edge to false
# TODO: Add support for multiple conditions
def _run(self, *args, **kwargs):
if self._built_object:
condition = self._built_object.get("condition")
result = self._built_object.get("result")
if condition is not None:
for edge in self.edges:
if edge.source_id == self.id:
target_vertex = self.graph.get_vertex(edge.target_id)
# source_handle.channel and condition should be the same
channel_bool = edge.source_handle.channel == "true"
if condition == channel_bool:
target_vertex.should_run = True
else:
target_vertex.should_run = False
if condition is None:
raise ValueError("Condition is required for the routing vertex.")
if result is None:
raise ValueError("Result is required for the routing vertex.")
if condition is True:
self._built_result = result
else:
raise ValueError(f"RoutingVertex {self.id} must have a condition in the _built_object")
self._built_result = result
else:
raise ValueError(f"RoutingVertex {self.id} must have a _built_object with a condition and a result")
self.graph.mark_branch(self.id, "INACTIVE")
def dict_to_codeblock(d: dict) -> str: