refactor: update BatchRunComponent to enhance functionality and usability (#7318)

* refactor: update BatchRunComponent to enhance functionality and usability

- Added TOML formatting for rows when no specific column is set.
- Updated display names and descriptions for clarity.
- Introduced an output column name option for customizable model response storage.
- Improved metadata handling and error management.
- Refactored row creation to include original columns and enhanced metadata.

* [autofix.ci] apply automated fixes

* fix: ruff errors

* [autofix.ci] apply automated fixes

* fix: component tests

* [autofix.ci] apply automated fixes

* Update src/backend/base/langflow/components/helpers/batch_run.py

Co-authored-by: Edwin Jose <edwin.jose@datastax.com>

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* ♻️ (batch_run.py): refactor type annotations to use Hashable for dictionary keys to improve type safety and compatibility with different types of keys

* youtube fix

* 🔧 (batch_run.py): remove unnecessary StrInput import and update MessageTextInput import to improve code cleanliness and remove redundancy

* 📝 (batch_run.py): Update import statement to include Hashable from collections.abc for better readability and maintainability
📝 (Youtube Analysis.json): Update display name from "Batch Results" to "DataFrame" for better clarity and consistency

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* uv ruff fixes

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
Co-authored-by: Cristhian Zanforlin Lousa <cristhian.lousa@gmail.com>
Co-authored-by: Carlos Coelho <80289056+carlosrcoelho@users.noreply.github.com>
This commit is contained in:
Rodrigo Nader 2025-04-10 06:37:15 -07:00 committed by GitHub
commit 67009190cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 130 additions and 76 deletions

View file

@ -1,32 +1,23 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Any
import toml # type: ignore[import-untyped]
from loguru import logger
from langflow.custom import Component
from langflow.io import (
BoolInput,
DataFrameInput,
HandleInput,
MessageTextInput,
MultilineInput,
Output,
)
from langflow.io import BoolInput, DataFrameInput, HandleInput, MessageTextInput, MultilineInput, Output
from langflow.schema import DataFrame
if TYPE_CHECKING:
from collections.abc import Hashable
from langchain_core.runnables import Runnable
class BatchRunComponent(Component):
display_name = "Batch Run"
description = (
"Runs a language model over each row of a DataFrame's text column and returns a new "
"DataFrame with three columns: '**text_input**' (the original text), "
"'**model_response**' (the model's response),and '**batch_index**' (the processing order)."
)
description = "Runs an LLM over each row of a DataFrame's column. If no column is set, the entire row is passed."
icon = "List"
beta = True
@ -40,7 +31,7 @@ class BatchRunComponent(Component):
),
MultilineInput(
name="system_message",
display_name="System Message",
display_name="Instructions",
info="Multi-line system instruction for all rows in the DataFrame.",
required=False,
),
@ -53,16 +44,26 @@ class BatchRunComponent(Component):
MessageTextInput(
name="column_name",
display_name="Column Name",
info="The name of the DataFrame column to treat as text messages. Default='text'.",
value="text",
required=True,
info=(
"The name of the DataFrame column to treat as text messages. "
"If empty, all columns will be formatted in TOML."
),
required=False,
advanced=False,
),
MessageTextInput(
name="output_column_name",
display_name="Output Column Name",
info="Name of the column where the model's response will be stored.",
value="model_response",
required=False,
advanced=True,
),
BoolInput(
name="enable_metadata",
display_name="Enable Metadata",
info="If True, add metadata to the output DataFrame.",
value=True,
value=False,
required=False,
advanced=True,
),
@ -70,23 +71,29 @@ class BatchRunComponent(Component):
outputs = [
Output(
display_name="Batch Results",
display_name="DataFrame",
name="batch_results",
method="run_batch",
info="A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.",
info="A DataFrame with all original columns plus the model's response column.",
),
]
def _create_base_row(self, text_input: str = "", model_response: str = "", batch_index: int = -1) -> dict[str, Any]:
"""Create a base row holding the input text, the model response and the batch index (no metadata)."""
return {
"text_input": text_input,
"model_response": model_response,
"batch_index": batch_index,
}
def _format_row_as_toml(self, row: dict[Hashable, Any]) -> str:
    """Serialize a single row (column -> cell mapping) to a TOML string.

    Each column becomes its own TOML table, named after the stringified
    column key, with a single ``value`` entry holding the stringified cell.
    """
    tables: dict[str, dict[str, str]] = {}
    for column, cell in row.items():
        # Both the key and the cell are stringified before serialization.
        tables[str(column)] = {"value": str(cell)}
    return toml.dumps(tables)
def _create_base_row(
self, original_row: dict[Hashable, Any], model_response: str = "", batch_index: int = -1
) -> dict[Hashable, Any]:
"""Create a base row with original columns and additional metadata."""
row = original_row.copy()
row[self.output_column_name] = model_response
row["batch_index"] = batch_index
return row
def _add_metadata(
self, row: dict[str, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
self, row: dict[Hashable, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
) -> None:
"""Add metadata to a row if enabled."""
if not self.enable_metadata:
@ -95,8 +102,8 @@ class BatchRunComponent(Component):
if success:
row["metadata"] = {
"has_system_message": bool(system_msg),
"input_length": len(row["text_input"]),
"response_length": len(row["model_response"]),
"input_length": len(row.get("text_input", "")),
"response_length": len(row[self.output_column_name]),
"processing_status": "success",
}
else:
@ -110,10 +117,10 @@ class BatchRunComponent(Component):
Returns:
DataFrame: A new DataFrame containing:
- text_input: The original input text
- model_response: The model's response
- batch_index: The processing order
- metadata: Additional processing information
- All original columns
- The model's response column (customizable name)
- 'batch_index' column for processing order
- 'metadata' (optional)
Raises:
ValueError: If the specified column is not found in the DataFrame
@ -122,22 +129,25 @@ class BatchRunComponent(Component):
model: Runnable = self.model
system_msg = self.system_message or ""
df: DataFrame = self.df
col_name = self.column_name or "text"
col_name = self.column_name or ""
# Validate inputs first
if not isinstance(df, DataFrame):
msg = f"Expected DataFrame input, got {type(df)}"
raise TypeError(msg)
if col_name not in df.columns:
if col_name and col_name not in df.columns:
msg = f"Column '{col_name}' not found in the DataFrame. Available columns: {', '.join(df.columns)}"
raise ValueError(msg)
try:
# Convert the specified column to a list of strings
user_texts = df[col_name].astype(str).tolist()
total_rows = len(user_texts)
# Determine text input for each row
if col_name:
user_texts = df[col_name].astype(str).tolist()
else:
user_texts = [self._format_row_as_toml(row) for row in df.to_dict(orient="records")]
total_rows = len(user_texts)
logger.info(f"Processing {total_rows} rows with batch run")
# Prepare the batch of conversations
@ -166,17 +176,15 @@ class BatchRunComponent(Component):
]
# Sort by index to maintain order
responses_with_idx.sort(key=operator.itemgetter(0))
responses_with_idx.sort(key=lambda x: x[0])
# Build the final data with enhanced metadata
rows: list[dict[str, Any]] = []
for idx, response in responses_with_idx:
resp_text = response.content if hasattr(response, "content") else str(response)
row = self._create_base_row(
text_input=user_texts[idx],
model_response=resp_text,
batch_index=idx,
)
rows: list[dict[Hashable, Any]] = []
for idx, (original_row, response) in enumerate(
zip(df.to_dict(orient="records"), responses_with_idx, strict=False)
):
response_text = response[1].content if hasattr(response[1], "content") else str(response[1])
row = self._create_base_row(original_row, model_response=response_text, batch_index=idx)
self._add_metadata(row, success=True, system_msg=system_msg)
rows.append(row)
@ -190,6 +198,6 @@ class BatchRunComponent(Component):
except (KeyError, AttributeError) as e:
# Handle data structure and attribute access errors
logger.error(f"Data processing error: {e!s}")
error_row = self._create_base_row()
error_row = self._create_base_row({col: "" for col in df.columns}, model_response="", batch_index=-1)
self._add_metadata(error_row, success=False, error=str(e))
return DataFrame([error_row])

File diff suppressed because one or more lines are too long

View file

@ -46,7 +46,7 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
# Verify the results
assert isinstance(result, DataFrame)
assert "text_input" in result.columns
assert "text" in result.columns
assert "model_response" in result.columns
assert "metadata" in result.columns
assert len(result) == 3
@ -121,7 +121,7 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
assert error_row["metadata"]["processing_status"] == "failed"
assert "Mock error during batch processing" in error_row["metadata"]["error"]
# Verify base row structure
assert error_row["text_input"] == ""
assert error_row["text"] == ""
assert error_row["model_response"] == ""
assert error_row["batch_index"] == -1
@ -149,45 +149,66 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
# Verify no metadata
assert "metadata" not in error_row
# Verify base row structure
assert error_row["text_input"] == ""
assert error_row["text"] == ""
assert error_row["model_response"] == ""
assert error_row["batch_index"] == -1
def test_create_base_row(self):
# The base row must expose the original column plus the response and batch index.
component = BatchRunComponent()
row = component._create_base_row(text_input="test_input", model_response="test_response", batch_index=1)
assert row == {
"text_input": "test_input",
"model_response": "test_response",
"batch_index": 1,
}
row = component._create_base_row(
original_row={"text_input": "test_input"},
model_response="test_response",
batch_index=1,
)
assert row["text_input"] == "test_input"
assert row["model_response"] == "test_response"
assert row["batch_index"] == 1
def test_add_metadata_success(self):
component = BatchRunComponent(enable_metadata=True)
row = component._create_base_row(text_input="test_input", model_response="test_response", batch_index=1)
component._add_metadata(row, success=True, system_msg="test_system")
# Pass text_input inside the original_row dictionary
original_row = {"text_input": "test_input"}
row = component._create_base_row(
original_row=original_row,
model_response="test_response",
batch_index=1,
)
component._add_metadata(row, success=True, system_msg="Instructions here")
assert "metadata" in row
assert row["metadata"]["has_system_message"] is True
assert row["metadata"]["processing_status"] == "success"
assert row["metadata"]["input_length"] == len("test_input")
assert row["metadata"]["response_length"] == len("test_response")
assert row["metadata"]["processing_status"] == "success"
def test_add_metadata_failure(self):
component = BatchRunComponent(enable_metadata=True)
row = component._create_base_row()
component._add_metadata(row, success=False, error="test_error")
# Provide an empty original_row (it could contain other keys if needed)
row = component._create_base_row(original_row={}, model_response="", batch_index=1)
# Add metadata simulating a failure
component._add_metadata(row, success=False, error="Simulated error")
assert "metadata" in row
assert row["metadata"]["processing_status"] == "failed"
assert row["metadata"]["error"] == "test_error"
assert row["metadata"]["error"] == "Simulated error"
def test_metadata_disabled(self):
component = BatchRunComponent(enable_metadata=False)
row = component._create_base_row(text_input="test")
component._add_metadata(row, success=True)
# Provide text_input inside the original_row dictionary
row = component._create_base_row(
original_row={"text_input": "test"},
model_response="response",
batch_index=0,
)
component._add_metadata(row, success=True, system_msg="test")
# Since metadata is disabled, it must not be present
assert "metadata" not in row
async def test_invalid_column_name(self):
@ -229,7 +250,9 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
result = await component.run_batch()
assert isinstance(result, DataFrame)
assert all(isinstance(text, str) for text in result["text_input"])
assert all(str(num) in text for num, text in zip(test_df["text"], result["text_input"], strict=False))
assert all(isinstance(text, int) for text in result["text"])
assert all(
str(num) in response for num, response in zip(test_df["text"], result["model_response"], strict=False)
)
result_dicts = result.to_dict("records")
assert all(row["metadata"]["processing_status"] == "success" for row in result_dicts)