fix: correct loop component dependencies (#8091)
* feat: Minimal experiment with zipping pre- and post-loop lists Update test JSON to demonstrate a simple workflow using custom components for sequence generation and zipping, with a loop component to process the data. The changes include: - Replaced previous components with custom components - Added a sequence maker component - Added a zipper component - Configured loop component to work with the new components - Updated flow description and last tested version * feat: Refactor Loop Test workflow with enhanced component interactions Update LoopTest.json to demonstrate a more complex data processing workflow: - Modify MyZipper component to return Message instead of Data - Update Loop component's stop condition logic - Adjust node positions and connections - Upgrade last tested version to 1.2.0 * test: Enhance Loop Component Test with JSON Parsing and Assertion Add more robust testing for the Loop component by: - Parsing TextOutput event from the response - Extracting and parsing JSON data - Adding detailed assertions to verify loop output - Improving test coverage for loop component interactions * refactor: simplify LoopTest.json structure and update node definitions - Reduced the size of LoopTest.json by removing unnecessary edges and nodes. - Updated node definitions for `ParseData` and `MessagetoData` components to enhance clarity and maintainability. - Adjusted connections between nodes to reflect the new structure, ensuring proper data flow. - Improved documentation within the JSON structure for better understanding of component functionalities. * feat: add method to retrieve incoming edge by target parameter - Implemented `get_incoming_edge_by_target_param` method in both `Component` and `Vertex` classes to facilitate the retrieval of source vertex IDs for incoming edges targeting specific parameters. - Enhanced performance by caching outgoing and incoming edges in the `Vertex` class to avoid redundant calculations. * feat: add dependency update method in LoopComponent - Introduced `update_dependency` method to manage dependencies for the next iteration in the loop. - Refactored existing logic to ensure proper handling of current items and loop termination conditions. - Enhanced code clarity and maintainability by restructuring the flow of data processing within the loop. * [autofix.ci] apply automated fixes * refactor: update message assertions in TestLoopComponentWithAPI for accuracy * feat: enhance LoopTest.json structure and component definitions - Expanded the LoopTest.json file to include additional nodes and edges, improving the representation of component interactions. - Updated definitions for `MyZipper`, `LoopComponent`, `MessagetoData`, and `ChatOutput` to enhance clarity and functionality. - Introduced new properties and methods in components to support better data handling and processing. - Improved documentation within the JSON structure for better understanding of component functionalities and usage. * feat: add ran_at_least_once tracking to RunnableVerticesManager - Introduced a new set, `ran_at_least_once`, to track vertices that have been executed at least once. - Updated serialization methods to include the new property for state management. - Enhanced logic in `all_predecessors_are_fulfilled` to prevent infinite loops for loop vertices that have already run. * fix: add error handling for missing vertex in Component class * refactor: improve variable naming and enhance readability in TestLoopComponentWithAPI * feat: track vertex execution in run_manager by adding ran_at_least_once tracking * feat: Enhance LoopComponent with dependency management and improved item output handling - Added a method to update dependencies for the LoopComponent to ensure proper execution order. - Improved item output logic to handle stopping conditions more effectively and update dependencies for subsequent runs. - Refactored the item_output method to streamline the flow of data processing and context management. * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
parent
d3d06be8e5
commit
d50c90522e
8 changed files with 476 additions and 1034 deletions
|
|
@ -67,20 +67,27 @@ class LoopComponent(Component):
|
|||
|
||||
if self.evaluate_stop_loop():
|
||||
self.stop("item")
|
||||
return Data(text="")
|
||||
else:
|
||||
# Get data list and current index
|
||||
data_list, current_index = self.loop_variables()
|
||||
if current_index < len(data_list):
|
||||
# Output current item and increment index
|
||||
try:
|
||||
current_item = data_list[current_index]
|
||||
except IndexError:
|
||||
current_item = Data(text="")
|
||||
self.aggregated_output()
|
||||
self.update_ctx({f"{self._id}_index": current_index + 1})
|
||||
|
||||
# Get data list and current index
|
||||
data_list, current_index = self.loop_variables()
|
||||
if current_index < len(data_list):
|
||||
# Output current item and increment index
|
||||
try:
|
||||
current_item = data_list[current_index]
|
||||
except IndexError:
|
||||
current_item = Data(text="")
|
||||
self.aggregated_output()
|
||||
self.update_ctx({f"{self._id}_index": current_index + 1})
|
||||
# Now we need to update the dependencies for the next run
|
||||
self.update_dependency()
|
||||
return current_item
|
||||
|
||||
def update_dependency(self):
|
||||
item_dependency_id = self.get_incoming_edge_by_target_param("item")
|
||||
|
||||
self.graph.run_manager.run_predecessors[self._id].append(item_dependency_id)
|
||||
|
||||
def done_output(self) -> DataFrame:
|
||||
"""Trigger the done output when iteration is complete."""
|
||||
self.initialize_data()
|
||||
|
|
|
|||
|
|
@ -160,6 +160,23 @@ class Component(CustomComponent):
|
|||
self.set_class_code()
|
||||
self._set_output_required_inputs()
|
||||
|
||||
def get_incoming_edge_by_target_param(self, target_param: str) -> str | None:
|
||||
"""Get the source vertex ID for an incoming edge that targets a specific parameter.
|
||||
|
||||
This method delegates to the underlying vertex to find an incoming edge that connects
|
||||
to the specified target parameter.
|
||||
|
||||
Args:
|
||||
target_param (str): The name of the target parameter to find an incoming edge for
|
||||
|
||||
Returns:
|
||||
str | None: The ID of the source vertex if an incoming edge is found, None otherwise
|
||||
"""
|
||||
if self._vertex is None:
|
||||
msg = "Vertex not found. Please build the graph first."
|
||||
raise ValueError(msg)
|
||||
return self._vertex.get_incoming_edge_by_target_param(target_param)
|
||||
|
||||
@property
|
||||
def enabled_tools(self) -> list[str] | None:
|
||||
"""Dynamically determine which tools should be enabled.
|
||||
|
|
|
|||
|
|
@ -1585,6 +1585,7 @@ class Graph:
|
|||
async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: Vertex, *, cache: bool = True) -> list[str]:
|
||||
v_id = vertex.id
|
||||
v_successors_ids = vertex.successors_ids
|
||||
self.run_manager.ran_at_least_once.add(v_id)
|
||||
async with lock:
|
||||
self.run_manager.remove_vertex_from_runnables(v_id)
|
||||
next_runnable_vertices = self.find_next_runnable_vertices(v_successors_ids)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class RunnableVerticesManager:
|
|||
self.vertices_to_run: set[str] = set() # Set of vertices that are ready to run
|
||||
self.vertices_being_run: set[str] = set() # Set of vertices that are currently running
|
||||
self.cycle_vertices: set[str] = set() # Set of vertices that are in a cycle
|
||||
self.ran_at_least_once: set[str] = set() # Set of vertices that have been run at least once
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
|
@ -15,6 +16,7 @@ class RunnableVerticesManager:
|
|||
"run_predecessors": self.run_predecessors,
|
||||
"vertices_to_run": self.vertices_to_run,
|
||||
"vertices_being_run": self.vertices_being_run,
|
||||
"ran_at_least_once": self.ran_at_least_once,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -24,6 +26,7 @@ class RunnableVerticesManager:
|
|||
instance.run_predecessors = data["run_predecessors"]
|
||||
instance.vertices_to_run = data["vertices_to_run"]
|
||||
instance.vertices_being_run = data["vertices_being_run"]
|
||||
instance.ran_at_least_once = data.get("ran_at_least_once", set())
|
||||
return instance
|
||||
|
||||
def __getstate__(self) -> object:
|
||||
|
|
@ -32,6 +35,7 @@ class RunnableVerticesManager:
|
|||
"run_predecessors": self.run_predecessors,
|
||||
"vertices_to_run": self.vertices_to_run,
|
||||
"vertices_being_run": self.vertices_being_run,
|
||||
"ran_at_least_once": self.ran_at_least_once,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
|
|
@ -39,6 +43,7 @@ class RunnableVerticesManager:
|
|||
self.run_predecessors = state["run_predecessors"]
|
||||
self.vertices_to_run = state["vertices_to_run"]
|
||||
self.vertices_being_run = state["vertices_being_run"]
|
||||
self.ran_at_least_once = state["ran_at_least_once"]
|
||||
|
||||
def all_predecessors_are_fulfilled(self) -> bool:
|
||||
return all(not value for value in self.run_predecessors.values())
|
||||
|
|
@ -81,6 +86,12 @@ class RunnableVerticesManager:
|
|||
# For cycle vertices, check if any pending predecessors are also in cycle
|
||||
# Using set intersection is faster than iteration
|
||||
if vertex_id in self.cycle_vertices:
|
||||
# If this is a loop vertex that has run before and has no pending predecessors,
|
||||
# it should not run again to prevent infinite loops
|
||||
if is_loop and vertex_id in self.ran_at_least_once and bool(set(pending)):
|
||||
return False
|
||||
# For loop vertices, allow running if it's a loop or if none of its pending
|
||||
# predecessors are also cycle vertices (preventing circular dependencies)
|
||||
return is_loop or not bool(set(pending) & self.cycle_vertices)
|
||||
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -106,6 +106,8 @@ class Vertex:
|
|||
self.output_names: list[str] = [
|
||||
output["name"] for output in self.outputs if isinstance(output, dict) and "name" in output
|
||||
]
|
||||
self._incoming_edges: list[CycleEdge] | None = None
|
||||
self._outgoing_edges: list[CycleEdge] | None = None
|
||||
|
||||
@property
|
||||
def is_loop(self) -> bool:
|
||||
|
|
@ -185,11 +187,19 @@ class Vertex:
|
|||
|
||||
@property
|
||||
def outgoing_edges(self) -> list[CycleEdge]:
|
||||
return [edge for edge in self.edges if edge.source_id == self.id]
|
||||
if self._outgoing_edges is None:
|
||||
self._outgoing_edges = [edge for edge in self.edges if edge.source_id == self.id]
|
||||
return self._outgoing_edges
|
||||
|
||||
@property
|
||||
def incoming_edges(self) -> list[CycleEdge]:
|
||||
return [edge for edge in self.edges if edge.target_id == self.id]
|
||||
if self._incoming_edges is None:
|
||||
self._incoming_edges = [edge for edge in self.edges if edge.target_id == self.id]
|
||||
return self._incoming_edges
|
||||
|
||||
# Get edge connected to an output of a certain name
|
||||
def get_incoming_edge_by_target_param(self, target_param: str) -> str | None:
|
||||
return next((edge.source_id for edge in self.incoming_edges if edge.target_param == target_param), None)
|
||||
|
||||
@property
|
||||
def edges_source_names(self) -> set[str | None]:
|
||||
|
|
|
|||
|
|
@ -1728,7 +1728,7 @@
|
|||
"show": true,
|
||||
"title_case": false,
|
||||
"type": "code",
|
||||
"value": "from langflow.custom import Component\nfrom langflow.io import HandleInput, Output\nfrom langflow.schema import Data\nfrom langflow.schema.dataframe import DataFrame\n\n\nclass LoopComponent(Component):\n display_name = \"Loop\"\n description = (\n \"Iterates over a list of Data objects, outputting one item at a time and aggregating results from loop inputs.\"\n )\n icon = \"infinity\"\n\n inputs = [\n HandleInput(\n name=\"data\",\n display_name=\"Data or DataFrame\",\n info=\"The initial list of Data objects or DataFrame to iterate over.\",\n input_types=[\"Data\", \"DataFrame\"],\n ),\n ]\n\n outputs = [\n Output(display_name=\"Item\", name=\"item\", method=\"item_output\", allows_loop=True, group_outputs=True),\n Output(display_name=\"Done\", name=\"done\", method=\"done_output\", group_outputs=True),\n ]\n\n def initialize_data(self) -> None:\n \"\"\"Initialize the data list, context index, and aggregated list.\"\"\"\n if self.ctx.get(f\"{self._id}_initialized\", False):\n return\n\n # Ensure data is a list of Data objects\n data_list = self._validate_data(self.data)\n\n # Store the initial data and context variables\n self.update_ctx(\n {\n f\"{self._id}_data\": data_list,\n f\"{self._id}_index\": 0,\n f\"{self._id}_aggregated\": [],\n f\"{self._id}_initialized\": True,\n }\n )\n\n def _validate_data(self, data):\n \"\"\"Validate and return a list of Data objects.\"\"\"\n if isinstance(data, DataFrame):\n return data.to_data_list()\n if isinstance(data, Data):\n return [data]\n if isinstance(data, list) and all(isinstance(item, Data) for item in data):\n return data\n msg = \"The 'data' input must be a DataFrame, a list of Data objects, or a single Data object.\"\n raise TypeError(msg)\n\n def evaluate_stop_loop(self) -> bool:\n \"\"\"Evaluate whether to stop item or done output.\"\"\"\n current_index = self.ctx.get(f\"{self._id}_index\", 0)\n data_length = len(self.ctx.get(f\"{self._id}_data\", []))\n return current_index > data_length\n\n def item_output(self) -> Data:\n \"\"\"Output the next item in the list or stop if done.\"\"\"\n self.initialize_data()\n current_item = Data(text=\"\")\n\n if self.evaluate_stop_loop():\n self.stop(\"item\")\n return Data(text=\"\")\n\n # Get data list and current index\n data_list, current_index = self.loop_variables()\n if current_index < len(data_list):\n # Output current item and increment index\n try:\n current_item = data_list[current_index]\n except IndexError:\n current_item = Data(text=\"\")\n self.aggregated_output()\n self.update_ctx({f\"{self._id}_index\": current_index + 1})\n return current_item\n\n def done_output(self) -> DataFrame:\n \"\"\"Trigger the done output when iteration is complete.\"\"\"\n self.initialize_data()\n\n if self.evaluate_stop_loop():\n self.stop(\"item\")\n self.start(\"done\")\n\n aggregated = self.ctx.get(f\"{self._id}_aggregated\", [])\n\n return DataFrame(aggregated)\n self.stop(\"done\")\n return DataFrame([])\n\n def loop_variables(self):\n \"\"\"Retrieve loop variables from context.\"\"\"\n return (\n self.ctx.get(f\"{self._id}_data\", []),\n self.ctx.get(f\"{self._id}_index\", 0),\n )\n\n def aggregated_output(self) -> list[Data]:\n \"\"\"Return the aggregated list once all items are processed.\"\"\"\n self.initialize_data()\n\n # Get data list and aggregated list\n data_list = self.ctx.get(f\"{self._id}_data\", [])\n aggregated = self.ctx.get(f\"{self._id}_aggregated\", [])\n loop_input = self.item\n if loop_input is not None and not isinstance(loop_input, str) and len(aggregated) <= len(data_list):\n aggregated.append(loop_input)\n self.update_ctx({f\"{self._id}_aggregated\": aggregated})\n return aggregated\n"
|
||||
"value": "from langflow.custom import Component\nfrom langflow.io import HandleInput, Output\nfrom langflow.schema import Data\nfrom langflow.schema.dataframe import DataFrame\n\n\nclass LoopComponent(Component):\n display_name = \"Loop\"\n description = (\n \"Iterates over a list of Data objects, outputting one item at a time and aggregating results from loop inputs.\"\n )\n icon = \"infinity\"\n\n inputs = [\n HandleInput(\n name=\"data\",\n display_name=\"Data or DataFrame\",\n info=\"The initial list of Data objects or DataFrame to iterate over.\",\n input_types=[\"Data\", \"DataFrame\"],\n ),\n ]\n\n outputs = [\n Output(display_name=\"Item\", name=\"item\", method=\"item_output\", allows_loop=True, group_outputs=True),\n Output(display_name=\"Done\", name=\"done\", method=\"done_output\", group_outputs=True),\n ]\n\n def initialize_data(self) -> None:\n \"\"\"Initialize the data list, context index, and aggregated list.\"\"\"\n if self.ctx.get(f\"{self._id}_initialized\", False):\n return\n\n # Ensure data is a list of Data objects\n data_list = self._validate_data(self.data)\n\n # Store the initial data and context variables\n self.update_ctx(\n {\n f\"{self._id}_data\": data_list,\n f\"{self._id}_index\": 0,\n f\"{self._id}_aggregated\": [],\n f\"{self._id}_initialized\": True,\n }\n )\n\n def _validate_data(self, data):\n \"\"\"Validate and return a list of Data objects.\"\"\"\n if isinstance(data, DataFrame):\n return data.to_data_list()\n if isinstance(data, Data):\n return [data]\n if isinstance(data, list) and all(isinstance(item, Data) for item in data):\n return data\n msg = \"The 'data' input must be a DataFrame, a list of Data objects, or a single Data object.\"\n raise TypeError(msg)\n\n def evaluate_stop_loop(self) -> bool:\n \"\"\"Evaluate whether to stop item or done output.\"\"\"\n current_index = self.ctx.get(f\"{self._id}_index\", 0)\n data_length = len(self.ctx.get(f\"{self._id}_data\", []))\n return current_index > data_length\n\n def item_output(self) -> Data:\n \"\"\"Output the next item in the list or stop if done.\"\"\"\n self.initialize_data()\n current_item = Data(text=\"\")\n\n if self.evaluate_stop_loop():\n self.stop(\"item\")\n else:\n # Get data list and current index\n data_list, current_index = self.loop_variables()\n if current_index < len(data_list):\n # Output current item and increment index\n try:\n current_item = data_list[current_index]\n except IndexError:\n current_item = Data(text=\"\")\n self.aggregated_output()\n self.update_ctx({f\"{self._id}_index\": current_index + 1})\n\n # Now we need to update the dependencies for the next run\n self.update_dependency()\n return current_item\n\n def update_dependency(self):\n item_dependency_id = self.get_incoming_edge_by_target_param(\"item\")\n\n self.graph.run_manager.run_predecessors[self._id].append(item_dependency_id)\n\n def done_output(self) -> DataFrame:\n \"\"\"Trigger the done output when iteration is complete.\"\"\"\n self.initialize_data()\n\n if self.evaluate_stop_loop():\n self.stop(\"item\")\n self.start(\"done\")\n\n aggregated = self.ctx.get(f\"{self._id}_aggregated\", [])\n\n return DataFrame(aggregated)\n self.stop(\"done\")\n return DataFrame([])\n\n def loop_variables(self):\n \"\"\"Retrieve loop variables from context.\"\"\"\n return (\n self.ctx.get(f\"{self._id}_data\", []),\n self.ctx.get(f\"{self._id}_index\", 0),\n )\n\n def aggregated_output(self) -> list[Data]:\n \"\"\"Return the aggregated list once all items are processed.\"\"\"\n self.initialize_data()\n\n # Get data list and aggregated list\n data_list = self.ctx.get(f\"{self._id}_data\", [])\n aggregated = self.ctx.get(f\"{self._id}_aggregated\", [])\n loop_input = self.item\n if loop_input is not None and not isinstance(loop_input, str) and len(aggregated) <= len(data_list):\n aggregated.append(loop_input)\n self.update_ctx({f\"{self._id}_aggregated\": aggregated})\n return aggregated\n"
|
||||
},
|
||||
"data": {
|
||||
"_input_type": "HandleInput",
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import orjson
|
||||
|
|
@ -52,15 +53,12 @@ class TestLoopComponentWithAPI(ComponentTestBaseWithClient):
|
|||
|
||||
async def check_messages(self, flow_id):
|
||||
messages = await aget_messages(flow_id=UUID(flow_id), order="ASC")
|
||||
assert len(messages) == 2
|
||||
assert len(messages) == 1
|
||||
assert messages[0].session_id == flow_id
|
||||
assert messages[0].sender == "User"
|
||||
assert messages[0].sender_name == "User"
|
||||
assert messages[0].text != ""
|
||||
assert messages[1].session_id == flow_id
|
||||
assert messages[1].sender == "Machine"
|
||||
assert messages[1].sender_name == "AI"
|
||||
assert len(messages[1].text) > 0
|
||||
assert messages[0].sender == "Machine"
|
||||
assert messages[0].sender_name == "AI"
|
||||
assert len(messages[0].text) > 0
|
||||
return messages
|
||||
|
||||
async def test_build_flow_loop(self, client, json_loop_test, logged_in_headers):
|
||||
"""Test building a flow with a loop component."""
|
||||
|
|
@ -77,13 +75,31 @@ class TestLoopComponentWithAPI(ComponentTestBaseWithClient):
|
|||
assert events_response.status_code == 200
|
||||
|
||||
# Process the events stream
|
||||
chat_output = None
|
||||
lines = []
|
||||
async for line in events_response.aiter_lines():
|
||||
if not line: # Skip empty lines
|
||||
continue
|
||||
lines.append(line)
|
||||
if "ChatOutput" in line:
|
||||
chat_output = json.loads(line)
|
||||
# Process events if needed
|
||||
# We could add specific assertions here for loop-related events
|
||||
assert chat_output is not None
|
||||
messages = await self.check_messages(flow_id)
|
||||
ai_message = messages[0].text
|
||||
json_data = orjson.loads(ai_message)
|
||||
|
||||
await self.check_messages(flow_id)
|
||||
# Use a for loop for better debugging
|
||||
found = []
|
||||
json_data = [(data["text"], q_dict) for data, q_dict in json_data]
|
||||
for text, q_dict in json_data:
|
||||
expected_text = f"==> {q_dict['q']}"
|
||||
assert expected_text in text, (
|
||||
f"Found {found} until now, but expected '{expected_text}' not found in '{text}',"
|
||||
f"and the json_data is {json_data}"
|
||||
)
|
||||
found.append(expected_text)
|
||||
|
||||
async def test_run_flow_loop(self, client: AsyncClient, created_api_key, json_loop_test, logged_in_headers):
|
||||
flow_id = await self._create_flow(client, json_loop_test, logged_in_headers)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue