feat: update if-else component (#8756)

* add ifelse component

* [autofix.ci] apply automated fixes

* fix test

* fix test

* fix test

* fix tests

* [autofix.ci] apply automated fixes

* fix tests

* fix test

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Lucas Democh <democh@datax.dev>
This commit is contained in:
Yuqi Tang 2025-07-02 15:56:57 -07:00 committed by GitHub
commit 19c9514d54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 119 additions and 57 deletions

View file

@ -22,20 +22,31 @@ class ConditionalRouterComponent(Component):
info="The primary text input for the operation.",
required=True,
),
DropdownInput(
name="operator",
display_name="Operator",
options=[
"equals",
"not equals",
"contains",
"starts with",
"ends with",
"regex",
"less than",
"less than or equal",
"greater than",
"greater than or equal",
],
info="The operator to apply for comparing the texts.",
value="equals",
real_time_refresh=True,
),
MessageTextInput(
name="match_text",
display_name="Match Text",
info="The text input to compare against.",
required=True,
),
DropdownInput(
name="operator",
display_name="Operator",
options=["equals", "not equals", "contains", "starts with", "ends with", "regex"],
info="The operator to apply for comparing the texts.",
value="equals",
real_time_refresh=True,
),
BoolInput(
name="case_sensitive",
display_name="Case Sensitive",
@ -44,9 +55,15 @@ class ConditionalRouterComponent(Component):
advanced=True,
),
MessageInput(
name="message",
display_name="Alternative Output",
info="The message to pass through either route.",
name="true_case_message",
display_name="Case True",
info="The message to pass if the condition is True.",
advanced=True,
),
MessageInput(
name="false_case_message",
display_name="Case False",
info="The message to pass if the condition is False.",
advanced=True,
),
IntInput(
@ -94,6 +111,20 @@ class ConditionalRouterComponent(Component):
return bool(re.match(match_text, input_text))
except re.error:
return False # Return False if the regex is invalid
if operator in ["less than", "less than or equal", "greater than", "greater than or equal"]:
try:
input_num = float(input_text)
match_num = float(match_text)
if operator == "less than":
return input_num < match_num
if operator == "less than or equal":
return input_num <= match_num
if operator == "greater than":
return input_num > match_num
if operator == "greater than or equal":
return input_num >= match_num
except ValueError:
return False # Invalid number format for comparison
return False
def iterate_and_stop_once(self, route_to_stop: str):
@ -109,9 +140,9 @@ class ConditionalRouterComponent(Component):
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if result:
self.status = self.message
self.status = self.true_case_message
self.iterate_and_stop_once("false_result")
return self.message
return self.true_case_message
self.iterate_and_stop_once("true_result")
return Message(content="")
@ -120,9 +151,9 @@ class ConditionalRouterComponent(Component):
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if not result:
self.status = self.message
self.status = self.false_case_message
self.iterate_and_stop_once("true_result")
return self.message
return self.false_case_message
self.iterate_and_stop_once("false_result")
return Message(content="")
@ -130,8 +161,6 @@ class ConditionalRouterComponent(Component):
if field_name == "operator":
if field_value == "regex":
build_config.pop("case_sensitive", None)
# Ensure case_sensitive is present for all other operators
elif "case_sensitive" not in build_config:
case_sensitive_input = next(
(input_field for input_field in self.inputs if input_field.name == "case_sensitive"), None

File diff suppressed because one or more lines are too long

View file

@ -32,15 +32,16 @@ class Concatenate(Component):
def test_cycle_in_graph():
chat_input = ChatInput(_id="chat_input")
router = ConditionalRouterComponent(_id="router", default_route="true_result")
# Use router's message output instead of false_response
chat_input.set(input_value=router.message)
# Use router's true case message output instead of message
chat_input.set(input_value=router.true_case_message)
concat_component = Concatenate(_id="concatenate")
concat_component.set(text=chat_input.message_response)
router.set(
input_text=chat_input.message_response,
match_text="testtesttesttest",
operator="equals",
message=concat_component.concatenate,
true_case_message=concat_component.concatenate,
false_case_message=concat_component.concatenate,
)
text_output = TextOutputComponent(_id="text_output")
text_output.set(input_value=router.true_response)
@ -83,14 +84,16 @@ def test_cycle_in_graph():
def test_cycle_in_graph_max_iterations():
text_input = TextInputComponent(_id="text_input")
router = ConditionalRouterComponent(_id="router")
# Connect text_input to router's input
text_input.set(input_value=router.false_response)
concat_component = Concatenate(_id="concatenate")
concat_component.set(text=text_input.text_response)
# Connect concatenate output back to router's input to create cycle
router.set(
input_text=text_input.text_response,
match_text="testtesttesttest",
operator="equals",
message=concat_component.concatenate,
false_case_message=concat_component.concatenate,
)
text_output = TextOutputComponent(_id="text_output")
text_output.set(input_value=router.true_response)
@ -100,7 +103,7 @@ def test_cycle_in_graph_max_iterations():
graph = Graph(text_input, chat_output)
assert graph.is_cyclic is True
# Run queue should contain chat_input and not router
# Run queue should contain text_input and not router
assert "text_input" in graph._run_queue
assert "router" not in graph._run_queue
@ -109,24 +112,24 @@ def test_cycle_in_graph_max_iterations():
def test_that_outputs_cache_is_set_to_false_in_cycle():
chat_input = ChatInput(_id="chat_input")
text_input = TextInputComponent(_id="text_input")
router = ConditionalRouterComponent(_id="router")
# Use router's message output instead of false_response
chat_input.set(input_value=router.message)
concat_component = Concatenate(_id="concatenate")
concat_component.set(text=chat_input.message_response)
text_input.set(input_value=router.false_response)
concat_component.set(text=text_input.text_response)
router.set(
input_text=chat_input.message_response,
input_text=text_input.text_response,
match_text="testtesttesttest",
operator="equals",
message=concat_component.concatenate,
true_case_message=concat_component.concatenate,
false_case_message=concat_component.concatenate,
)
text_output = TextOutputComponent(_id="text_output")
text_output.set(input_value=router.true_response)
chat_output = ChatOutput(_id="chat_output")
chat_output.set(input_value=text_output.text_response)
graph = Graph(chat_input, chat_output)
graph = Graph(text_input, chat_output)
cycle_vertices = find_cycle_vertices(graph._get_edges_as_list_of_tuples())
cycle_outputs_lists = [
graph.vertex_map[vertex_id].custom_component._outputs_map.values() for vertex_id in cycle_vertices
@ -167,7 +170,8 @@ def test_updated_graph_with_prompts():
input_text=openai_component_1.text_response,
match_text=chat_input.message_response,
operator="contains",
message=openai_component_1.text_response,
true_case_message=openai_component_1.text_response,
false_case_message=openai_component_1.text_response,
)
# Second prompt: After the last try, provide a new hint
@ -236,7 +240,8 @@ def test_updated_graph_with_max_iterations():
input_text=openai_component_1.text_response,
match_text=chat_input.message_response,
operator="contains",
message=openai_component_1.text_response,
true_case_message=openai_component_1.text_response,
false_case_message=openai_component_1.text_response,
)
# Second prompt: After the last try, provide a new hint
@ -290,7 +295,8 @@ def test_conditional_router_max_iterations():
input_text=text_input.text_response,
match_text="bacon",
operator="equals",
message="This message should not be routed to true_result",
true_case_message="This message should not be routed to true_result",
false_case_message="This message should not be routed to false_result",
max_iterations=5,
default_route="true_result",
)