feat: adds metadata and batch_index to batch_run (#6318)

* Update batch_run.py

* updates to test component and fixes formatting

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: anovazzi1 <otavio2204@gmail.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Edwin Jose 2025-02-14 16:04:50 -05:00 committed by GitHub
commit a1967bc472
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 287 additions and 47 deletions

View file

@@ -1,9 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from loguru import logger
from langflow.custom import Component
from langflow.io import DataFrameInput, HandleInput, MultilineInput, Output, StrInput
from langflow.io import (
BoolInput,
DataFrameInput,
HandleInput,
MessageTextInput,
MultilineInput,
Output,
)
from langflow.schema import DataFrame
if TYPE_CHECKING:
@@ -14,8 +23,8 @@ class BatchRunComponent(Component):
display_name = "Batch Run"
description = (
"Runs a language model over each row of a DataFrame's text column and returns a new "
"DataFrame with two columns: 'text_input' (the original text) and 'model_response' "
"containing the model's response."
"DataFrame with three columns: '**text_input**' (the original text), "
"'**model_response**' (the model's response),and '**batch_index**' (the processing order)."
)
icon = "List"
beta = True
@@ -26,6 +35,7 @@ class BatchRunComponent(Component):
display_name="Language Model",
info="Connect the 'Language Model' output from your LLM component here.",
input_types=["LanguageModel"],
required=True,
),
MultilineInput(
name="system_message",
@@ -37,12 +47,23 @@ class BatchRunComponent(Component):
name="df",
display_name="DataFrame",
info="The DataFrame whose column (specified by 'column_name') we'll treat as text messages.",
required=True,
),
StrInput(
MessageTextInput(
name="column_name",
display_name="Column Name",
info="The name of the DataFrame column to treat as text messages. Default='text'.",
value="text",
required=True,
advanced=True,
),
BoolInput(
name="enable_metadata",
display_name="Enable Metadata",
info="If True, add metadata to the output DataFrame.",
value=True,
required=False,
advanced=True,
),
]
@@ -51,51 +72,123 @@ class BatchRunComponent(Component):
display_name="Batch Results",
name="batch_results",
method="run_batch",
info="A DataFrame with two columns: 'text_input' and 'model_response'.",
info="A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.",
),
]
async def run_batch(self) -> DataFrame:
"""For each row in df[column_name], combine that text with system_message, then invoke the model asynchronously.
def _create_base_row(self, text_input: str = "", model_response: str = "", batch_index: int = -1) -> dict[str, Any]:
"""Create a base row with optional metadata."""
return {
"text_input": text_input,
"model_response": model_response,
"batch_index": batch_index,
}
Returns a new DataFrame of the same length, with columns 'text_input' and 'model_response'.
def _add_metadata(
self, row: dict[str, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
) -> None:
"""Add metadata to a row if enabled."""
if not self.enable_metadata:
return
if success:
row["metadata"] = {
"has_system_message": bool(system_msg),
"input_length": len(row["text_input"]),
"response_length": len(row["model_response"]),
"processing_status": "success",
}
else:
row["metadata"] = {
"error": error,
"processing_status": "failed",
}
async def run_batch(self) -> DataFrame:
"""Process each row in df[column_name] with the language model asynchronously.
Returns:
DataFrame: A new DataFrame containing:
- text_input: The original input text
- model_response: The model's response
- batch_index: The processing order
- metadata: Additional processing information
Raises:
ValueError: If the specified column is not found in the DataFrame
TypeError: If the model is not compatible or input types are wrong
"""
model: Runnable = self.model
system_msg = self.system_message or ""
df: DataFrame = self.df
col_name = self.column_name or "text"
# Validate inputs first
if not isinstance(df, DataFrame):
msg = f"Expected DataFrame input, got {type(df)}"
raise TypeError(msg)
if col_name not in df.columns:
msg = f"Column '{col_name}' not found in the DataFrame."
msg = f"Column '{col_name}' not found in the DataFrame. Available columns: {', '.join(df.columns)}"
raise ValueError(msg)
# Convert the specified column to a list of strings
user_texts = df[col_name].astype(str).tolist()
try:
# Convert the specified column to a list of strings
user_texts = df[col_name].astype(str).tolist()
total_rows = len(user_texts)
# Prepare the batch of conversations
conversations = [
[{"role": "system", "content": system_msg}, {"role": "user", "content": text}]
if system_msg
else [{"role": "user", "content": text}]
for text in user_texts
]
model = model.with_config(
{
"run_name": self.display_name,
"project_name": self.get_project_name(),
"callbacks": self.get_langchain_callbacks(),
}
)
logger.info(f"Processing {total_rows} rows with batch run")
responses = await model.abatch(conversations)
# Prepare the batch of conversations
conversations = [
[{"role": "system", "content": system_msg}, {"role": "user", "content": text}]
if system_msg
else [{"role": "user", "content": text}]
for text in user_texts
]
# Build the final data, each row has 'text_input' + 'model_response'
rows = []
for original_text, response in zip(user_texts, responses, strict=False):
resp_text = response.content if hasattr(response, "content") else str(response)
# Configure the model with project info and callbacks
model = model.with_config(
{
"run_name": self.display_name,
"project_name": self.get_project_name(),
"callbacks": self.get_langchain_callbacks(),
}
)
row = {"text_input": original_text, "model_response": resp_text}
rows.append(row)
# Process batches and track progress
responses_with_idx = [
(idx, response)
for idx, response in zip(
range(len(conversations)), await model.abatch(list(conversations)), strict=True
)
]
# Convert to a new DataFrame
return DataFrame(rows) # Langflow DataFrame from a list of dicts
# Sort by index to maintain order
responses_with_idx.sort(key=lambda x: x[0])
# Build the final data with enhanced metadata
rows: list[dict[str, Any]] = []
for idx, response in responses_with_idx:
resp_text = response.content if hasattr(response, "content") else str(response)
row = self._create_base_row(
text_input=user_texts[idx],
model_response=resp_text,
batch_index=idx,
)
self._add_metadata(row, success=True, system_msg=system_msg)
rows.append(row)
# Log progress
if (idx + 1) % max(1, total_rows // 10) == 0:
logger.info(f"Processed {idx + 1}/{total_rows} rows")
logger.info("Batch processing completed successfully")
return DataFrame(rows)
except (KeyError, AttributeError) as e:
# Handle data structure and attribute access errors
logger.error(f"Data processing error: {e!s}")
error_row = self._create_base_row()
self._add_metadata(error_row, success=False, error=str(e))
return DataFrame([error_row])

View file

@@ -21,6 +21,7 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
"model": MockLanguageModel(),
"df": DataFrame({"text": ["Hello"]}),
"column_name": "text",
"enable_metadata": True,
}
@pytest.fixture
@@ -33,7 +34,11 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
test_df = DataFrame({"text": ["Hello", "World", "Test"]})
component = BatchRunComponent(
model=MockLanguageModel(), system_message="You are a helpful assistant", df=test_df, column_name="text"
model=MockLanguageModel(),
system_message="You are a helpful assistant",
df=test_df,
column_name="text",
enable_metadata=True,
)
# Run the batch process
@@ -43,46 +48,188 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
assert isinstance(result, DataFrame)
assert "text_input" in result.columns
assert "model_response" in result.columns
assert "metadata" in result.columns
assert len(result) == 3
assert all(isinstance(resp, str) for resp in result["model_response"])
# Convert DataFrame to list of dicts for easier testing
result_dicts = result.to_dict("records")
# Verify metadata
assert all(row["metadata"]["has_system_message"] for row in result_dicts)
assert all(row["metadata"]["processing_status"] == "success" for row in result_dicts)
async def test_batch_run_without_system_message(self):
async def test_batch_run_without_metadata(self):
test_df = DataFrame({"text": ["Hello", "World"]})
component = BatchRunComponent(model=MockLanguageModel(), df=test_df, column_name="text")
component = BatchRunComponent(
model=MockLanguageModel(),
df=test_df,
column_name="text",
enable_metadata=False,
)
result = await component.run_batch()
assert isinstance(result, DataFrame)
assert len(result) == 2
assert "metadata" not in result.columns
assert all(isinstance(resp, str) for resp in result["model_response"])
async def test_batch_run_error_with_metadata(self):
    """A non-DataFrame input must raise TypeError even when metadata is enabled."""
    bad_inputs = {
        "model": MockLanguageModel(),
        "df": "not_a_dataframe",  # deliberately wrong type to trigger TypeError
        "column_name": "text",
        "enable_metadata": True,
    }
    component = BatchRunComponent(**bad_inputs)
    expected = re.escape("Expected DataFrame input, got <class 'str'>")
    with pytest.raises(TypeError, match=expected):
        await component.run_batch()
async def test_batch_run_error_without_metadata(self):
    """A non-DataFrame input must raise TypeError with metadata disabled too."""
    bad_inputs = {
        "model": MockLanguageModel(),
        "df": "not_a_dataframe",  # deliberately wrong type to trigger TypeError
        "column_name": "text",
        "enable_metadata": False,
    }
    component = BatchRunComponent(**bad_inputs)
    expected = re.escape("Expected DataFrame input, got <class 'str'>")
    with pytest.raises(TypeError, match=expected):
        await component.run_batch()
async def test_operational_error_with_metadata(self):
    """An AttributeError raised mid-batch becomes a single error row with failure metadata."""

    class FailingModel:
        # Mimics the minimal Runnable surface run_batch touches.
        def with_config(self, *_, **__):
            return self

        async def abatch(self, *_):
            msg = "Mock error during batch processing"
            raise AttributeError(msg)

    component = BatchRunComponent(
        model=FailingModel(),
        df=DataFrame({"text": ["test1", "test2"]}),
        column_name="text",
        enable_metadata=True,
    )

    result = await component.run_batch()

    assert isinstance(result, DataFrame)
    # The component collapses the whole failed batch into exactly one error row.
    assert len(result) == 1
    failure = result.iloc[0]
    assert failure["metadata"]["processing_status"] == "failed"
    assert "Mock error during batch processing" in failure["metadata"]["error"]
    # Base-row defaults signal that no input row was actually processed.
    assert failure["text_input"] == ""
    assert failure["model_response"] == ""
    assert failure["batch_index"] == -1
async def test_operational_error_without_metadata(self):
    """With metadata disabled, a mid-batch failure yields one bare error row."""

    class FailingModel:
        # Mimics the minimal Runnable surface run_batch touches.
        def with_config(self, *_, **__):
            return self

        async def abatch(self, *_):
            msg = "Mock error during batch processing"
            raise AttributeError(msg)

    component = BatchRunComponent(
        model=FailingModel(),
        df=DataFrame({"text": ["test1", "test2"]}),
        column_name="text",
        enable_metadata=False,
    )

    result = await component.run_batch()

    assert isinstance(result, DataFrame)
    # The component collapses the whole failed batch into exactly one error row.
    assert len(result) == 1
    failure = result.iloc[0]
    # Metadata must not appear when the toggle is off.
    assert "metadata" not in failure
    # Base-row defaults signal that no input row was actually processed.
    assert failure["text_input"] == ""
    assert failure["model_response"] == ""
    assert failure["batch_index"] == -1
def test_create_base_row(self):
    """_create_base_row returns exactly the three base columns it was given."""
    component = BatchRunComponent()
    built = component._create_base_row(text_input="test_input", model_response="test_response", batch_index=1)
    expected = {
        "text_input": "test_input",
        "model_response": "test_response",
        "batch_index": 1,
    }
    assert built == expected
def test_add_metadata_success(self):
    """Success metadata records lengths, status, and system-message presence."""
    component = BatchRunComponent(enable_metadata=True)
    row = component._create_base_row(text_input="test_input", model_response="test_response", batch_index=1)
    component._add_metadata(row, success=True, system_msg="test_system")
    assert "metadata" in row
    meta = row["metadata"]
    assert meta["has_system_message"] is True
    assert meta["processing_status"] == "success"
    assert meta["input_length"] == len("test_input")
    assert meta["response_length"] == len("test_response")
def test_add_metadata_failure(self):
    """Failure metadata records the error string and a 'failed' status."""
    component = BatchRunComponent(enable_metadata=True)
    row = component._create_base_row()
    component._add_metadata(row, success=False, error="test_error")
    assert "metadata" in row
    failure_meta = row["metadata"]
    assert failure_meta["processing_status"] == "failed"
    assert failure_meta["error"] == "test_error"
def test_metadata_disabled(self):
    """When enable_metadata is False, no metadata key is ever added."""
    component = BatchRunComponent(enable_metadata=False)
    row = component._create_base_row(text_input="test")
    component._add_metadata(row, success=True)
    assert "metadata" not in row
async def test_invalid_column_name(self):
component = BatchRunComponent(
model=MockLanguageModel(), df=DataFrame({"text": ["Hello"]}), column_name="nonexistent_column"
model=MockLanguageModel(),
df=DataFrame({"text": ["Hello"]}),
column_name="nonexistent_column",
enable_metadata=True,
)
with pytest.raises(ValueError, match=re.escape("Column 'nonexistent_column' not found in the DataFrame.")):
with pytest.raises(
ValueError,
match=re.escape("Column 'nonexistent_column' not found in the DataFrame. Available columns: text"),
):
await component.run_batch()
async def test_empty_dataframe(self):
component = BatchRunComponent(model=MockLanguageModel(), df=DataFrame({"text": []}), column_name="text")
component = BatchRunComponent(
model=MockLanguageModel(),
df=DataFrame({"text": []}),
column_name="text",
enable_metadata=True,
)
result = await component.run_batch()
assert isinstance(result, DataFrame)
assert len(result) == 0
async def test_non_string_column_conversion(self):
test_df = DataFrame(
{
"text": [123, 456, 789] # Numeric values
}
)
test_df = DataFrame({"text": [123, 456, 789]}) # Numeric values
component = BatchRunComponent(model=MockLanguageModel(), df=test_df, column_name="text")
component = BatchRunComponent(
model=MockLanguageModel(),
df=test_df,
column_name="text",
enable_metadata=True,
)
result = await component.run_batch()
assert isinstance(result, DataFrame)
assert all(isinstance(text, str) for text in result["text_input"])
assert all(str(num) in text for num, text in zip(test_df["text"], result["text_input"], strict=False))
result_dicts = result.to_dict("records")
assert all(row["metadata"]["processing_status"] == "success" for row in result_dicts)