feat: update if-else component (#8756)
* add ifelse component * [autofix.ci] apply automated fixes * fix test * fix test * fix test * fix tests * [autofix.ci] apply automated fixes * fix tests * fix test --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Lucas Democh <democh@datax.dev>
This commit is contained in:
parent
e28093234b
commit
19c9514d54
3 changed files with 119 additions and 57 deletions
|
|
@ -22,20 +22,31 @@ class ConditionalRouterComponent(Component):
|
|||
info="The primary text input for the operation.",
|
||||
required=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="operator",
|
||||
display_name="Operator",
|
||||
options=[
|
||||
"equals",
|
||||
"not equals",
|
||||
"contains",
|
||||
"starts with",
|
||||
"ends with",
|
||||
"regex",
|
||||
"less than",
|
||||
"less than or equal",
|
||||
"greater than",
|
||||
"greater than or equal",
|
||||
],
|
||||
info="The operator to apply for comparing the texts.",
|
||||
value="equals",
|
||||
real_time_refresh=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="match_text",
|
||||
display_name="Match Text",
|
||||
info="The text input to compare against.",
|
||||
required=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="operator",
|
||||
display_name="Operator",
|
||||
options=["equals", "not equals", "contains", "starts with", "ends with", "regex"],
|
||||
info="The operator to apply for comparing the texts.",
|
||||
value="equals",
|
||||
real_time_refresh=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="case_sensitive",
|
||||
display_name="Case Sensitive",
|
||||
|
|
@ -44,9 +55,15 @@ class ConditionalRouterComponent(Component):
|
|||
advanced=True,
|
||||
),
|
||||
MessageInput(
|
||||
name="message",
|
||||
display_name="Alternative Output",
|
||||
info="The message to pass through either route.",
|
||||
name="true_case_message",
|
||||
display_name="Case True",
|
||||
info="The message to pass if the condition is True.",
|
||||
advanced=True,
|
||||
),
|
||||
MessageInput(
|
||||
name="false_case_message",
|
||||
display_name="Case False",
|
||||
info="The message to pass if the condition is False.",
|
||||
advanced=True,
|
||||
),
|
||||
IntInput(
|
||||
|
|
@ -94,6 +111,20 @@ class ConditionalRouterComponent(Component):
|
|||
return bool(re.match(match_text, input_text))
|
||||
except re.error:
|
||||
return False # Return False if the regex is invalid
|
||||
if operator in ["less than", "less than or equal", "greater than", "greater than or equal"]:
|
||||
try:
|
||||
input_num = float(input_text)
|
||||
match_num = float(match_text)
|
||||
if operator == "less than":
|
||||
return input_num < match_num
|
||||
if operator == "less than or equal":
|
||||
return input_num <= match_num
|
||||
if operator == "greater than":
|
||||
return input_num > match_num
|
||||
if operator == "greater than or equal":
|
||||
return input_num >= match_num
|
||||
except ValueError:
|
||||
return False # Invalid number format for comparison
|
||||
return False
|
||||
|
||||
def iterate_and_stop_once(self, route_to_stop: str):
|
||||
|
|
@ -109,9 +140,9 @@ class ConditionalRouterComponent(Component):
|
|||
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
|
||||
)
|
||||
if result:
|
||||
self.status = self.message
|
||||
self.status = self.true_case_message
|
||||
self.iterate_and_stop_once("false_result")
|
||||
return self.message
|
||||
return self.true_case_message
|
||||
self.iterate_and_stop_once("true_result")
|
||||
return Message(content="")
|
||||
|
||||
|
|
@ -120,9 +151,9 @@ class ConditionalRouterComponent(Component):
|
|||
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
|
||||
)
|
||||
if not result:
|
||||
self.status = self.message
|
||||
self.status = self.false_case_message
|
||||
self.iterate_and_stop_once("true_result")
|
||||
return self.message
|
||||
return self.false_case_message
|
||||
self.iterate_and_stop_once("false_result")
|
||||
return Message(content="")
|
||||
|
||||
|
|
@ -130,8 +161,6 @@ class ConditionalRouterComponent(Component):
|
|||
if field_name == "operator":
|
||||
if field_value == "regex":
|
||||
build_config.pop("case_sensitive", None)
|
||||
|
||||
# Ensure case_sensitive is present for all other operators
|
||||
elif "case_sensitive" not in build_config:
|
||||
case_sensitive_input = next(
|
||||
(input_field for input_field in self.inputs if input_field.name == "case_sensitive"), None
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -32,15 +32,16 @@ class Concatenate(Component):
|
|||
def test_cycle_in_graph():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
router = ConditionalRouterComponent(_id="router", default_route="true_result")
|
||||
# Use router's message output instead of false_response
|
||||
chat_input.set(input_value=router.message)
|
||||
# Use router's true case message output instead of message
|
||||
chat_input.set(input_value=router.true_case_message)
|
||||
concat_component = Concatenate(_id="concatenate")
|
||||
concat_component.set(text=chat_input.message_response)
|
||||
router.set(
|
||||
input_text=chat_input.message_response,
|
||||
match_text="testtesttesttest",
|
||||
operator="equals",
|
||||
message=concat_component.concatenate,
|
||||
true_case_message=concat_component.concatenate,
|
||||
false_case_message=concat_component.concatenate,
|
||||
)
|
||||
text_output = TextOutputComponent(_id="text_output")
|
||||
text_output.set(input_value=router.true_response)
|
||||
|
|
@ -83,14 +84,16 @@ def test_cycle_in_graph():
|
|||
def test_cycle_in_graph_max_iterations():
|
||||
text_input = TextInputComponent(_id="text_input")
|
||||
router = ConditionalRouterComponent(_id="router")
|
||||
# Connect text_input to router's input
|
||||
text_input.set(input_value=router.false_response)
|
||||
concat_component = Concatenate(_id="concatenate")
|
||||
concat_component.set(text=text_input.text_response)
|
||||
# Connect concatenate output back to router's input to create cycle
|
||||
router.set(
|
||||
input_text=text_input.text_response,
|
||||
match_text="testtesttesttest",
|
||||
operator="equals",
|
||||
message=concat_component.concatenate,
|
||||
false_case_message=concat_component.concatenate,
|
||||
)
|
||||
text_output = TextOutputComponent(_id="text_output")
|
||||
text_output.set(input_value=router.true_response)
|
||||
|
|
@ -100,7 +103,7 @@ def test_cycle_in_graph_max_iterations():
|
|||
graph = Graph(text_input, chat_output)
|
||||
assert graph.is_cyclic is True
|
||||
|
||||
# Run queue should contain chat_input and not router
|
||||
# Run queue should contain text_input and not router
|
||||
assert "text_input" in graph._run_queue
|
||||
assert "router" not in graph._run_queue
|
||||
|
||||
|
|
@ -109,24 +112,24 @@ def test_cycle_in_graph_max_iterations():
|
|||
|
||||
|
||||
def test_that_outputs_cache_is_set_to_false_in_cycle():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
text_input = TextInputComponent(_id="text_input")
|
||||
router = ConditionalRouterComponent(_id="router")
|
||||
# Use router's message output instead of false_response
|
||||
chat_input.set(input_value=router.message)
|
||||
concat_component = Concatenate(_id="concatenate")
|
||||
concat_component.set(text=chat_input.message_response)
|
||||
text_input.set(input_value=router.false_response)
|
||||
concat_component.set(text=text_input.text_response)
|
||||
router.set(
|
||||
input_text=chat_input.message_response,
|
||||
input_text=text_input.text_response,
|
||||
match_text="testtesttesttest",
|
||||
operator="equals",
|
||||
message=concat_component.concatenate,
|
||||
true_case_message=concat_component.concatenate,
|
||||
false_case_message=concat_component.concatenate,
|
||||
)
|
||||
text_output = TextOutputComponent(_id="text_output")
|
||||
text_output.set(input_value=router.true_response)
|
||||
chat_output = ChatOutput(_id="chat_output")
|
||||
chat_output.set(input_value=text_output.text_response)
|
||||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
graph = Graph(text_input, chat_output)
|
||||
cycle_vertices = find_cycle_vertices(graph._get_edges_as_list_of_tuples())
|
||||
cycle_outputs_lists = [
|
||||
graph.vertex_map[vertex_id].custom_component._outputs_map.values() for vertex_id in cycle_vertices
|
||||
|
|
@ -167,7 +170,8 @@ def test_updated_graph_with_prompts():
|
|||
input_text=openai_component_1.text_response,
|
||||
match_text=chat_input.message_response,
|
||||
operator="contains",
|
||||
message=openai_component_1.text_response,
|
||||
true_case_message=openai_component_1.text_response,
|
||||
false_case_message=openai_component_1.text_response,
|
||||
)
|
||||
|
||||
# Second prompt: After the last try, provide a new hint
|
||||
|
|
@ -236,7 +240,8 @@ def test_updated_graph_with_max_iterations():
|
|||
input_text=openai_component_1.text_response,
|
||||
match_text=chat_input.message_response,
|
||||
operator="contains",
|
||||
message=openai_component_1.text_response,
|
||||
true_case_message=openai_component_1.text_response,
|
||||
false_case_message=openai_component_1.text_response,
|
||||
)
|
||||
|
||||
# Second prompt: After the last try, provide a new hint
|
||||
|
|
@ -290,7 +295,8 @@ def test_conditional_router_max_iterations():
|
|||
input_text=text_input.text_response,
|
||||
match_text="bacon",
|
||||
operator="equals",
|
||||
message="This message should not be routed to true_result",
|
||||
true_case_message="This message should not be routed to true_result",
|
||||
false_case_message="This message should not be routed to false_result",
|
||||
max_iterations=5,
|
||||
default_route="true_result",
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue