refactor: update BatchRunComponent to enhance functionality and usability (#7318)
* refactor: update BatchRunComponent to enhance functionality and usability - Added TOML formatting for rows when no specific column is set. - Updated display names and descriptions for clarity. - Introduced an output column name option for customizable model response storage. - Improved metadata handling and error management. - Refactored row creation to include original columns and enhanced metadata. * [autofix.ci] apply automated fixes * fix: ruff errors * [autofix.ci] apply automated fixes * fix: component tests * [autofix.ci] apply automated fixes * Update src/backend/base/langflow/components/helpers/batch_run.py Co-authored-by: Edwin Jose <edwin.jose@datastax.com> * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * ♻️ (batch_run.py): refactor type annotations to use Hashable for dictionary keys to improve type safety and compatibility with different types of keys * youtube fix * 🔧 (batch_run.py): remove unnecessary StrInput import and update MessageTextInput import to improve code cleanliness and remove redundancy * 📝 (batch_run.py): Update import statement to include Hashable from collections.abc for better readability and maintainability 📝 (Youtube Analysis.json): Update display name from "Batch Results" to "DataFrame" for better clarity and consistency * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * uv ruff fixes * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com> Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: Cristhian Zanforlin Lousa <cristhian.lousa@gmail.com> Co-authored-by: Carlos Coelho <80289056+carlosrcoelho@users.noreply.github.com>
This commit is contained in:
parent
404e04989a
commit
67009190cd
3 changed files with 130 additions and 76 deletions
|
|
@ -1,32 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import toml # type: ignore[import-untyped]
|
||||
from loguru import logger
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.io import (
|
||||
BoolInput,
|
||||
DataFrameInput,
|
||||
HandleInput,
|
||||
MessageTextInput,
|
||||
MultilineInput,
|
||||
Output,
|
||||
)
|
||||
from langflow.io import BoolInput, DataFrameInput, HandleInput, MessageTextInput, MultilineInput, Output
|
||||
from langflow.schema import DataFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
|
||||
class BatchRunComponent(Component):
|
||||
display_name = "Batch Run"
|
||||
description = (
|
||||
"Runs a language model over each row of a DataFrame's text column and returns a new "
|
||||
"DataFrame with three columns: '**text_input**' (the original text), "
|
||||
"'**model_response**' (the model's response),and '**batch_index**' (the processing order)."
|
||||
)
|
||||
description = "Runs an LLM over each row of a DataFrame's column. If no column is set, the entire row is passed."
|
||||
icon = "List"
|
||||
beta = True
|
||||
|
||||
|
|
@ -40,7 +31,7 @@ class BatchRunComponent(Component):
|
|||
),
|
||||
MultilineInput(
|
||||
name="system_message",
|
||||
display_name="System Message",
|
||||
display_name="Instructions",
|
||||
info="Multi-line system instruction for all rows in the DataFrame.",
|
||||
required=False,
|
||||
),
|
||||
|
|
@ -53,16 +44,26 @@ class BatchRunComponent(Component):
|
|||
MessageTextInput(
|
||||
name="column_name",
|
||||
display_name="Column Name",
|
||||
info="The name of the DataFrame column to treat as text messages. Default='text'.",
|
||||
value="text",
|
||||
required=True,
|
||||
info=(
|
||||
"The name of the DataFrame column to treat as text messages. "
|
||||
"If empty, all columns will be formatted in TOML."
|
||||
),
|
||||
required=False,
|
||||
advanced=False,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="output_column_name",
|
||||
display_name="Output Column Name",
|
||||
info="Name of the column where the model's response will be stored.",
|
||||
value="model_response",
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="enable_metadata",
|
||||
display_name="Enable Metadata",
|
||||
info="If True, add metadata to the output DataFrame.",
|
||||
value=True,
|
||||
value=False,
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
|
|
@ -70,23 +71,29 @@ class BatchRunComponent(Component):
|
|||
|
||||
outputs = [
|
||||
Output(
|
||||
display_name="Batch Results",
|
||||
display_name="DataFrame",
|
||||
name="batch_results",
|
||||
method="run_batch",
|
||||
info="A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.",
|
||||
info="A DataFrame with all original columns plus the model's response column.",
|
||||
),
|
||||
]
|
||||
|
||||
def _create_base_row(self, text_input: str = "", model_response: str = "", batch_index: int = -1) -> dict[str, Any]:
|
||||
"""Create a base row with optional metadata."""
|
||||
return {
|
||||
"text_input": text_input,
|
||||
"model_response": model_response,
|
||||
"batch_index": batch_index,
|
||||
}
|
||||
def _format_row_as_toml(self, row: dict[Hashable, Any]) -> str:
    """Serialize one row into TOML text.

    Every column is mapped to a nested ``{"value": ...}`` entry keyed by the
    stringified column name; both the key and the cell content are passed
    through ``str`` before serialization.

    Args:
        row: Mapping of column key to cell value.

    Returns:
        The string produced by ``toml.dumps`` for the nested mapping.
    """
    tables: dict[str, dict[str, str]] = {}
    for column, cell in row.items():
        # Keys that stringify identically collapse into one entry; the last one wins.
        tables[str(column)] = {"value": str(cell)}
    return toml.dumps(tables)
|
||||
|
||||
def _create_base_row(
|
||||
self, original_row: dict[Hashable, Any], model_response: str = "", batch_index: int = -1
|
||||
) -> dict[Hashable, Any]:
|
||||
"""Create a base row with original columns and additional metadata."""
|
||||
row = original_row.copy()
|
||||
row[self.output_column_name] = model_response
|
||||
row["batch_index"] = batch_index
|
||||
return row
|
||||
|
||||
def _add_metadata(
|
||||
self, row: dict[str, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
|
||||
self, row: dict[Hashable, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
|
||||
) -> None:
|
||||
"""Add metadata to a row if enabled."""
|
||||
if not self.enable_metadata:
|
||||
|
|
@ -95,8 +102,8 @@ class BatchRunComponent(Component):
|
|||
if success:
|
||||
row["metadata"] = {
|
||||
"has_system_message": bool(system_msg),
|
||||
"input_length": len(row["text_input"]),
|
||||
"response_length": len(row["model_response"]),
|
||||
"input_length": len(row.get("text_input", "")),
|
||||
"response_length": len(row[self.output_column_name]),
|
||||
"processing_status": "success",
|
||||
}
|
||||
else:
|
||||
|
|
@ -110,10 +117,10 @@ class BatchRunComponent(Component):
|
|||
|
||||
Returns:
|
||||
DataFrame: A new DataFrame containing:
|
||||
- text_input: The original input text
|
||||
- model_response: The model's response
|
||||
- batch_index: The processing order
|
||||
- metadata: Additional processing information
|
||||
- All original columns
|
||||
- The model's response column (customizable name)
|
||||
- 'batch_index' column for processing order
|
||||
- 'metadata' (optional)
|
||||
|
||||
Raises:
|
||||
ValueError: If the specified column is not found in the DataFrame
|
||||
|
|
@ -122,22 +129,25 @@ class BatchRunComponent(Component):
|
|||
model: Runnable = self.model
|
||||
system_msg = self.system_message or ""
|
||||
df: DataFrame = self.df
|
||||
col_name = self.column_name or "text"
|
||||
col_name = self.column_name or ""
|
||||
|
||||
# Validate inputs first
|
||||
if not isinstance(df, DataFrame):
|
||||
msg = f"Expected DataFrame input, got {type(df)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
if col_name not in df.columns:
|
||||
if col_name and col_name not in df.columns:
|
||||
msg = f"Column '{col_name}' not found in the DataFrame. Available columns: {', '.join(df.columns)}"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
# Convert the specified column to a list of strings
|
||||
user_texts = df[col_name].astype(str).tolist()
|
||||
total_rows = len(user_texts)
|
||||
# Determine text input for each row
|
||||
if col_name:
|
||||
user_texts = df[col_name].astype(str).tolist()
|
||||
else:
|
||||
user_texts = [self._format_row_as_toml(row) for row in df.to_dict(orient="records")]
|
||||
|
||||
total_rows = len(user_texts)
|
||||
logger.info(f"Processing {total_rows} rows with batch run")
|
||||
|
||||
# Prepare the batch of conversations
|
||||
|
|
@ -166,17 +176,15 @@ class BatchRunComponent(Component):
|
|||
]
|
||||
|
||||
# Sort by index to maintain order
|
||||
responses_with_idx.sort(key=operator.itemgetter(0))
|
||||
responses_with_idx.sort(key=lambda x: x[0])
|
||||
|
||||
# Build the final data with enhanced metadata
|
||||
rows: list[dict[str, Any]] = []
|
||||
for idx, response in responses_with_idx:
|
||||
resp_text = response.content if hasattr(response, "content") else str(response)
|
||||
row = self._create_base_row(
|
||||
text_input=user_texts[idx],
|
||||
model_response=resp_text,
|
||||
batch_index=idx,
|
||||
)
|
||||
rows: list[dict[Hashable, Any]] = []
|
||||
for idx, (original_row, response) in enumerate(
|
||||
zip(df.to_dict(orient="records"), responses_with_idx, strict=False)
|
||||
):
|
||||
response_text = response[1].content if hasattr(response[1], "content") else str(response[1])
|
||||
row = self._create_base_row(original_row, model_response=response_text, batch_index=idx)
|
||||
self._add_metadata(row, success=True, system_msg=system_msg)
|
||||
rows.append(row)
|
||||
|
||||
|
|
@ -190,6 +198,6 @@ class BatchRunComponent(Component):
|
|||
except (KeyError, AttributeError) as e:
|
||||
# Handle data structure and attribute access errors
|
||||
logger.error(f"Data processing error: {e!s}")
|
||||
error_row = self._create_base_row()
|
||||
error_row = self._create_base_row({col: "" for col in df.columns}, model_response="", batch_index=-1)
|
||||
self._add_metadata(error_row, success=False, error=str(e))
|
||||
return DataFrame([error_row])
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -46,7 +46,7 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
# Verify the results
|
||||
assert isinstance(result, DataFrame)
|
||||
assert "text_input" in result.columns
|
||||
assert "text" in result.columns
|
||||
assert "model_response" in result.columns
|
||||
assert "metadata" in result.columns
|
||||
assert len(result) == 3
|
||||
|
|
@ -121,7 +121,7 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
|
|||
assert error_row["metadata"]["processing_status"] == "failed"
|
||||
assert "Mock error during batch processing" in error_row["metadata"]["error"]
|
||||
# Verify base row structure
|
||||
assert error_row["text_input"] == ""
|
||||
assert error_row["text"] == ""
|
||||
assert error_row["model_response"] == ""
|
||||
assert error_row["batch_index"] == -1
|
||||
|
||||
|
|
@ -149,45 +149,66 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
|
|||
# Verify no metadata
|
||||
assert "metadata" not in error_row
|
||||
# Verify base row structure
|
||||
assert error_row["text_input"] == ""
|
||||
assert error_row["text"] == ""
|
||||
assert error_row["model_response"] == ""
|
||||
assert error_row["batch_index"] == -1
|
||||
|
||||
def test_create_base_row(self):
|
||||
component = BatchRunComponent()
|
||||
row = component._create_base_row(text_input="test_input", model_response="test_response", batch_index=1)
|
||||
|
||||
assert row == {
|
||||
"text_input": "test_input",
|
||||
"model_response": "test_response",
|
||||
"batch_index": 1,
|
||||
}
|
||||
row = component._create_base_row(
|
||||
original_row={"text_input": "test_input"},
|
||||
model_response="test_response",
|
||||
batch_index=1,
|
||||
)
|
||||
assert row["text_input"] == "test_input"
|
||||
assert row["model_response"] == "test_response"
|
||||
assert row["batch_index"] == 1
|
||||
|
||||
def test_add_metadata_success(self):
|
||||
component = BatchRunComponent(enable_metadata=True)
|
||||
row = component._create_base_row(text_input="test_input", model_response="test_response", batch_index=1)
|
||||
component._add_metadata(row, success=True, system_msg="test_system")
|
||||
|
||||
# Pass text_input inside the original_row dictionary
|
||||
original_row = {"text_input": "test_input"}
|
||||
row = component._create_base_row(
|
||||
original_row=original_row,
|
||||
model_response="test_response",
|
||||
batch_index=1,
|
||||
)
|
||||
|
||||
component._add_metadata(row, success=True, system_msg="Instructions here")
|
||||
|
||||
assert "metadata" in row
|
||||
assert row["metadata"]["has_system_message"] is True
|
||||
assert row["metadata"]["processing_status"] == "success"
|
||||
assert row["metadata"]["input_length"] == len("test_input")
|
||||
assert row["metadata"]["response_length"] == len("test_response")
|
||||
assert row["metadata"]["processing_status"] == "success"
|
||||
|
||||
def test_add_metadata_failure(self):
|
||||
component = BatchRunComponent(enable_metadata=True)
|
||||
row = component._create_base_row()
|
||||
component._add_metadata(row, success=False, error="test_error")
|
||||
|
||||
# Provide an empty original_row (it could contain other keys if needed)
|
||||
row = component._create_base_row(original_row={}, model_response="", batch_index=1)
|
||||
|
||||
# Add metadata simulating a failure
|
||||
component._add_metadata(row, success=False, error="Simulated error")
|
||||
|
||||
assert "metadata" in row
|
||||
assert row["metadata"]["processing_status"] == "failed"
|
||||
assert row["metadata"]["error"] == "test_error"
|
||||
assert row["metadata"]["error"] == "Simulated error"
|
||||
|
||||
def test_metadata_disabled(self):
|
||||
component = BatchRunComponent(enable_metadata=False)
|
||||
row = component._create_base_row(text_input="test")
|
||||
component._add_metadata(row, success=True)
|
||||
|
||||
# Provide text_input inside the original_row dictionary
|
||||
row = component._create_base_row(
|
||||
original_row={"text_input": "test"},
|
||||
model_response="response",
|
||||
batch_index=0,
|
||||
)
|
||||
|
||||
component._add_metadata(row, success=True, system_msg="test")
|
||||
|
||||
# Since metadata is disabled, it must not be present
|
||||
assert "metadata" not in row
|
||||
|
||||
async def test_invalid_column_name(self):
|
||||
|
|
@ -229,7 +250,9 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
|
|||
result = await component.run_batch()
|
||||
|
||||
assert isinstance(result, DataFrame)
|
||||
assert all(isinstance(text, str) for text in result["text_input"])
|
||||
assert all(str(num) in text for num, text in zip(test_df["text"], result["text_input"], strict=False))
|
||||
assert all(isinstance(text, int) for text in result["text"])
|
||||
assert all(
|
||||
str(num) in response for num, response in zip(test_df["text"], result["model_response"], strict=False)
|
||||
)
|
||||
result_dicts = result.to_dict("records")
|
||||
assert all(row["metadata"]["processing_status"] == "success" for row in result_dicts)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue