fix: make chat input sorting logic exclude other components (#5760)

* fix: Adjust chat input handling in layer sorting logic

* fix: Update chat input sorting logic and assertions in tests

* [autofix.ci] apply automated fixes

* fix: Update test assertion for chat input sorting to reflect changes in expected output

* [autofix.ci] apply automated fixes

* fix: update vertex sorting logic to handle start_component_id condition

- Modified the condition for sorting layers to include a check for start_component_id being None, ensuring correct behavior when this parameter is not set.
- This change improves the accuracy of the vertex sorting process in the graph utility functions.

* fix: update test assertion for vertex IDs in test_get_vertices

- Modified the test assertion in `test_get_vertices` to reflect the expected output, ensuring it only checks for "ChatInput" in the returned IDs.
- This change improves the accuracy of the test by aligning it with the current expected behavior of the endpoint.

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* [autofix.ci] apply automated fixes (attempt 3/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>
This commit is contained in: (branch/tag list lost in extraction)
Author: Gabriel Luiz Freitas Almeida — 2025-01-20 10:49:44 -03:00, committed via GitHub
commit 18e21cfec9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 15 additions and 12 deletions

View file

@@ -743,13 +743,15 @@ def sort_chat_inputs_first(
if not chat_input:
return vertices_layers
# If chat input already in first layer, just move it to index 0
if chat_input_layer_idx == 0:
first_layer = vertices_layers[0]
first_layer.remove(chat_input)
first_layer.insert(0, chat_input)
return vertices_layers
# If chat input is alone in first layer, keep as-is
if len(vertices_layers[0]) == 1:
return vertices_layers
# Otherwise move chat input to its own layer at the start
vertices_layers[0].remove(chat_input)
return [[chat_input], *vertices_layers]
# Otherwise create new layers with chat input first
result_layers = []
@@ -865,7 +867,7 @@ def get_sorted_vertices(
# Sort chat inputs first and sort each layer by dependencies
all_layers = [first_layer, *remaining_layers]
if get_vertex_predecessors is not None:
if get_vertex_predecessors is not None and start_component_id is None:
all_layers = sort_chat_inputs_first(all_layers, get_vertex_predecessors)
if get_vertex_successors is not None:
all_layers = sort_layer_by_dependency(all_layers, get_vertex_successors)

View file

@@ -483,10 +483,11 @@ def test_chat_inputs_at_start():
return []
result = utils.sort_chat_inputs_first(vertices_layers, get_vertex_predecessors)
assert len(result) == 3 # [chat_input] + original 3 layers
assert result[0] == ["ChatInput1", "B"]
assert result[1] == ["C"] # Original second layer
assert result[2] == ["D"] # Original third layer
assert len(result) == 4 # [chat_input] + original 3 layers
assert result[0] == ["ChatInput1"] # First layer contains only ChatInput1
assert result[1] == ["B"] # Second layer contains B
assert result[2] == ["C"] # Original second layer
assert result[3] == ["D"] # Original third layer
# Test that multiple chat inputs raise an error
vertices_layers_multiple = [["ChatInput1", "B"], ["ChatInput2", "C"], ["D"]]

View file

@@ -73,7 +73,7 @@ async def consume_and_assert_stream(r):
assert parsed["event"] == "vertices_sorted"
ids = parsed["data"]["ids"]
ids.sort()
assert ids == ["ChatInput-CIGht", "Memory-amN4Z"]
assert ids == ["ChatInput-CIGht"]
to_run = parsed["data"]["to_run"]
to_run.sort()

View file

@@ -257,7 +257,7 @@ async def test_get_vertices(client, added_flow_webhook_test, logged_in_headers):
# The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain)
ids = [_id.split("-")[0] for _id in response.json()["ids"]]
assert set(ids) == {"ChatInput", "Webhook"}
assert set(ids) == {"ChatInput"}
async def test_build_vertex_invalid_flow_id(client, logged_in_headers):