feat: adds metadata and batch_index to batch_run (#6318)
* Update batch_run.py
* updates to test component and fixes formatting
* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: anovazzi1 <otavio2204@gmail.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
ec5259a0fc
commit
a1967bc472
2 changed files with 287 additions and 47 deletions
|
|
@ -1,9 +1,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.io import DataFrameInput, HandleInput, MultilineInput, Output, StrInput
|
||||
from langflow.io import (
|
||||
BoolInput,
|
||||
DataFrameInput,
|
||||
HandleInput,
|
||||
MessageTextInput,
|
||||
MultilineInput,
|
||||
Output,
|
||||
)
|
||||
from langflow.schema import DataFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -14,8 +23,8 @@ class BatchRunComponent(Component):
|
|||
display_name = "Batch Run"
|
||||
description = (
|
||||
"Runs a language model over each row of a DataFrame's text column and returns a new "
|
||||
"DataFrame with two columns: 'text_input' (the original text) and 'model_response' "
|
||||
"containing the model's response."
|
||||
"DataFrame with three columns: '**text_input**' (the original text), "
|
||||
"'**model_response**' (the model's response),and '**batch_index**' (the processing order)."
|
||||
)
|
||||
icon = "List"
|
||||
beta = True
|
||||
|
|
@ -26,6 +35,7 @@ class BatchRunComponent(Component):
|
|||
display_name="Language Model",
|
||||
info="Connect the 'Language Model' output from your LLM component here.",
|
||||
input_types=["LanguageModel"],
|
||||
required=True,
|
||||
),
|
||||
MultilineInput(
|
||||
name="system_message",
|
||||
|
|
@ -37,12 +47,23 @@ class BatchRunComponent(Component):
|
|||
name="df",
|
||||
display_name="DataFrame",
|
||||
info="The DataFrame whose column (specified by 'column_name') we'll treat as text messages.",
|
||||
required=True,
|
||||
),
|
||||
StrInput(
|
||||
MessageTextInput(
|
||||
name="column_name",
|
||||
display_name="Column Name",
|
||||
info="The name of the DataFrame column to treat as text messages. Default='text'.",
|
||||
value="text",
|
||||
required=True,
|
||||
advanced=True,
|
||||
),
|
||||
BoolInput(
|
||||
name="enable_metadata",
|
||||
display_name="Enable Metadata",
|
||||
info="If True, add metadata to the output DataFrame.",
|
||||
value=True,
|
||||
required=False,
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -51,51 +72,123 @@ class BatchRunComponent(Component):
|
|||
display_name="Batch Results",
|
||||
name="batch_results",
|
||||
method="run_batch",
|
||||
info="A DataFrame with two columns: 'text_input' and 'model_response'.",
|
||||
info="A DataFrame with columns: 'text_input', 'model_response', 'batch_index', and 'metadata'.",
|
||||
),
|
||||
]
|
||||
|
||||
async def run_batch(self) -> DataFrame:
|
||||
"""For each row in df[column_name], combine that text with system_message, then invoke the model asynchronously.
|
||||
def _create_base_row(self, text_input: str = "", model_response: str = "", batch_index: int = -1) -> dict[str, Any]:
    """Build the core columns of one output row.

    Only the three base columns are produced here; metadata is attached
    separately by ``_add_metadata``.

    Args:
        text_input: The text that was sent to the model.
        model_response: The text returned by the model.
        batch_index: Position of the input within the batch. The default of
            -1, together with the empty-string defaults, describes the
            placeholder row that ``run_batch`` returns when processing fails.

    Returns:
        A dict with the keys ``text_input``, ``model_response`` and
        ``batch_index``.
    """
    return {
        "text_input": text_input,
        "model_response": model_response,
        "batch_index": batch_index,
    }
|
||||
|
||||
Returns a new DataFrame of the same length, with columns 'text_input' and 'model_response'.
|
||||
def _add_metadata(
    self, row: dict[str, Any], *, success: bool = True, system_msg: str = "", error: str | None = None
) -> None:
    """Attach a ``metadata`` entry to ``row`` in place, if metadata is enabled.

    Does nothing when ``self.enable_metadata`` is falsy.

    Args:
        row: A row holding ``text_input`` and ``model_response``; it is
            mutated by adding a ``metadata`` key.
        success: Selects the success or the failure metadata layout.
        system_msg: The system message used for the request; only its
            truthiness is recorded (``has_system_message``).
        error: Error description stored in the failure layout. Ignored when
            ``success`` is true.
    """
    if not self.enable_metadata:
        return

    if not success:
        row["metadata"] = {
            "error": error,
            "processing_status": "failed",
        }
        return

    def _text_length(value: Any) -> int:
        # Row values are not guaranteed to be plain strings (a model's content
        # may be structured, or missing). Measure the text form instead of
        # calling len() on the raw value, which would raise or report a
        # container size rather than a text length.
        if value is None:
            return 0
        return len(value) if isinstance(value, str) else len(str(value))

    row["metadata"] = {
        "has_system_message": bool(system_msg),
        "input_length": _text_length(row["text_input"]),
        "response_length": _text_length(row["model_response"]),
        "processing_status": "success",
    }
|
||||
|
||||
async def run_batch(self) -> DataFrame:
    """Run the language model over every value of ``df[column_name]`` in one async batch.

    Each value is converted to ``str`` and sent as a user message, preceded
    by the system message when one is set.

    Returns:
        DataFrame: One row per input, containing:
            - text_input: The original input text
            - model_response: The model's response
            - batch_index: The position of the input in the batch
            - metadata: Additional processing information (only when
              ``enable_metadata`` is set)
        If a ``KeyError`` or ``AttributeError`` is raised while processing,
        a single placeholder row (empty text fields, ``batch_index`` of -1)
        is returned instead. Failure details live only in ``metadata``, so
        with metadata disabled the error is visible only in the log.

    Raises:
        TypeError: If ``df`` is not a DataFrame
        ValueError: If the specified column is not found in the DataFrame
    """
    model: Runnable = self.model
    system_msg = self.system_message or ""
    df: DataFrame = self.df
    col_name = self.column_name or "text"

    # Validate inputs first
    if not isinstance(df, DataFrame):
        msg = f"Expected DataFrame input, got {type(df)}"
        raise TypeError(msg)

    if col_name not in df.columns:
        # Column labels are not necessarily strings; stringify them so that
        # building the message cannot itself fail with a TypeError.
        available = ", ".join(map(str, df.columns))
        msg = f"Column '{col_name}' not found in the DataFrame. Available columns: {available}"
        raise ValueError(msg)

    try:
        # Convert the specified column to a list of strings
        user_texts = df[col_name].astype(str).tolist()
        total_rows = len(user_texts)

        logger.info(f"Processing {total_rows} rows with batch run")

        # Prepare the batch of conversations
        conversations = [
            [{"role": "system", "content": system_msg}, {"role": "user", "content": text}]
            if system_msg
            else [{"role": "user", "content": text}]
            for text in user_texts
        ]

        # Configure the model with project info and callbacks
        model = model.with_config(
            {
                "run_name": self.display_name,
                "project_name": self.get_project_name(),
                "callbacks": self.get_langchain_callbacks(),
            }
        )

        # The batch call returns one output per input, in input order, so the
        # enumeration index is the batch index and no re-sorting is needed.
        responses = await model.abatch(conversations)

        # Build the final data with enhanced metadata
        rows: list[dict[str, Any]] = []
        log_every = max(1, total_rows // 10)
        # strict=True keeps the guarantee that a response-count mismatch is an error.
        for idx, (original_text, response) in enumerate(zip(user_texts, responses, strict=True)):
            resp_text = response.content if hasattr(response, "content") else str(response)
            row = self._create_base_row(
                text_input=original_text,
                model_response=resp_text,
                batch_index=idx,
            )
            self._add_metadata(row, success=True, system_msg=system_msg)
            rows.append(row)

            # Log roughly every tenth of the rows while assembling the output
            if (idx + 1) % log_every == 0:
                logger.info(f"Processed {idx + 1}/{total_rows} rows")

        logger.info("Batch processing completed successfully")
        return DataFrame(rows)

    except (KeyError, AttributeError) as e:
        # Handle data structure and attribute access errors
        logger.error(f"Data processing error: {e!s}")
        error_row = self._create_base_row()
        self._add_metadata(error_row, success=False, error=str(e))
        return DataFrame([error_row])
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
|
|||
"model": MockLanguageModel(),
|
||||
"df": DataFrame({"text": ["Hello"]}),
|
||||
"column_name": "text",
|
||||
"enable_metadata": True,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -33,7 +34,11 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
|
|||
test_df = DataFrame({"text": ["Hello", "World", "Test"]})
|
||||
|
||||
component = BatchRunComponent(
|
||||
model=MockLanguageModel(), system_message="You are a helpful assistant", df=test_df, column_name="text"
|
||||
model=MockLanguageModel(),
|
||||
system_message="You are a helpful assistant",
|
||||
df=test_df,
|
||||
column_name="text",
|
||||
enable_metadata=True,
|
||||
)
|
||||
|
||||
# Run the batch process
|
||||
|
|
@ -43,46 +48,188 @@ class TestBatchRunComponent(ComponentTestBaseWithoutClient):
|
|||
assert isinstance(result, DataFrame)
|
||||
assert "text_input" in result.columns
|
||||
assert "model_response" in result.columns
|
||||
assert "metadata" in result.columns
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(resp, str) for resp in result["model_response"])
|
||||
# Convert DataFrame to list of dicts for easier testing
|
||||
result_dicts = result.to_dict("records")
|
||||
# Verify metadata
|
||||
assert all(row["metadata"]["has_system_message"] for row in result_dicts)
|
||||
assert all(row["metadata"]["processing_status"] == "success" for row in result_dicts)
|
||||
|
||||
async def test_batch_run_without_system_message(self):
|
||||
async def test_batch_run_without_metadata(self):
|
||||
test_df = DataFrame({"text": ["Hello", "World"]})
|
||||
|
||||
component = BatchRunComponent(model=MockLanguageModel(), df=test_df, column_name="text")
|
||||
component = BatchRunComponent(
|
||||
model=MockLanguageModel(),
|
||||
df=test_df,
|
||||
column_name="text",
|
||||
enable_metadata=False,
|
||||
)
|
||||
|
||||
result = await component.run_batch()
|
||||
|
||||
assert isinstance(result, DataFrame)
|
||||
assert len(result) == 2
|
||||
assert "metadata" not in result.columns
|
||||
assert all(isinstance(resp, str) for resp in result["model_response"])
|
||||
|
||||
async def test_batch_run_error_with_metadata(self):
    """A non-DataFrame ``df`` is rejected with TypeError when metadata is enabled."""
    expected_message = re.escape("Expected DataFrame input, got <class 'str'>")
    batch_component = BatchRunComponent(
        model=MockLanguageModel(),
        df="not_a_dataframe",  # deliberately the wrong type
        column_name="text",
        enable_metadata=True,
    )

    with pytest.raises(TypeError, match=expected_message):
        await batch_component.run_batch()
|
||||
|
||||
async def test_batch_run_error_without_metadata(self):
    """A non-DataFrame ``df`` is rejected with TypeError when metadata is disabled."""
    expected_message = re.escape("Expected DataFrame input, got <class 'str'>")
    batch_component = BatchRunComponent(
        model=MockLanguageModel(),
        df="not_a_dataframe",  # deliberately the wrong type
        column_name="text",
        enable_metadata=False,
    )

    with pytest.raises(TypeError, match=expected_message):
        await batch_component.run_batch()
|
||||
|
||||
async def test_operational_error_with_metadata(self):
    """An AttributeError from the model yields one placeholder row with failure metadata."""
    failure_text = "Mock error during batch processing"

    class ErrorModel:
        """Stand-in model whose batch call always fails with AttributeError."""

        def with_config(self, *_, **__):
            return self

        async def abatch(self, *_):
            raise AttributeError(failure_text)

    batch_component = BatchRunComponent(
        model=ErrorModel(),
        df=DataFrame({"text": ["test1", "test2"]}),
        column_name="text",
        enable_metadata=True,
    )

    outcome = await batch_component.run_batch()

    assert isinstance(outcome, DataFrame)
    assert len(outcome) == 1  # a single error row is returned
    placeholder = outcome.iloc[0]

    # Failure metadata
    failure_meta = placeholder["metadata"]
    assert failure_meta["processing_status"] == "failed"
    assert "Mock error during batch processing" in failure_meta["error"]

    # Base row structure
    assert placeholder["text_input"] == ""
    assert placeholder["model_response"] == ""
    assert placeholder["batch_index"] == -1
|
||||
|
||||
async def test_operational_error_without_metadata(self):
    """An AttributeError from the model yields one placeholder row without metadata."""
    failure_text = "Mock error during batch processing"

    class ErrorModel:
        """Stand-in model whose batch call always fails with AttributeError."""

        def with_config(self, *_, **__):
            return self

        async def abatch(self, *_):
            raise AttributeError(failure_text)

    batch_component = BatchRunComponent(
        model=ErrorModel(),
        df=DataFrame({"text": ["test1", "test2"]}),
        column_name="text",
        enable_metadata=False,
    )

    outcome = await batch_component.run_batch()

    assert isinstance(outcome, DataFrame)
    assert len(outcome) == 1  # a single error row is returned
    placeholder = outcome.iloc[0]

    # No metadata when the feature is disabled
    assert "metadata" not in placeholder

    # Base row structure
    assert placeholder["text_input"] == ""
    assert placeholder["model_response"] == ""
    assert placeholder["batch_index"] == -1
|
||||
|
||||
def test_create_base_row(self):
    """``_create_base_row`` returns exactly the three core fields it was given."""
    expected = {
        "text_input": "test_input",
        "model_response": "test_response",
        "batch_index": 1,
    }

    built = BatchRunComponent()._create_base_row(**expected)

    assert built == expected
|
||||
|
||||
def test_add_metadata_success(self):
    """Success metadata records the system-message flag, both lengths and the status."""
    source_text, reply_text = "test_input", "test_response"
    batch_component = BatchRunComponent(enable_metadata=True)
    row = batch_component._create_base_row(text_input=source_text, model_response=reply_text, batch_index=1)

    batch_component._add_metadata(row, success=True, system_msg="test_system")

    assert "metadata" in row
    meta = row["metadata"]
    assert meta["has_system_message"] is True
    assert meta["processing_status"] == "success"
    assert meta["input_length"] == len(source_text)
    assert meta["response_length"] == len(reply_text)
|
||||
|
||||
def test_add_metadata_failure(self):
    """Failure metadata carries the error text and a 'failed' status."""
    batch_component = BatchRunComponent(enable_metadata=True)
    placeholder = batch_component._create_base_row()

    batch_component._add_metadata(placeholder, success=False, error="test_error")

    assert "metadata" in placeholder
    failure_meta = placeholder["metadata"]
    assert failure_meta["processing_status"] == "failed"
    assert failure_meta["error"] == "test_error"
|
||||
|
||||
def test_metadata_disabled(self):
    """With metadata disabled, ``_add_metadata`` leaves the row without a metadata key."""
    batch_component = BatchRunComponent(enable_metadata=False)
    plain_row = batch_component._create_base_row(text_input="test")

    batch_component._add_metadata(plain_row, success=True)

    assert "metadata" not in plain_row
|
||||
|
||||
async def test_invalid_column_name(self):
|
||||
component = BatchRunComponent(
|
||||
model=MockLanguageModel(), df=DataFrame({"text": ["Hello"]}), column_name="nonexistent_column"
|
||||
model=MockLanguageModel(),
|
||||
df=DataFrame({"text": ["Hello"]}),
|
||||
column_name="nonexistent_column",
|
||||
enable_metadata=True,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=re.escape("Column 'nonexistent_column' not found in the DataFrame.")):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("Column 'nonexistent_column' not found in the DataFrame. Available columns: text"),
|
||||
):
|
||||
await component.run_batch()
|
||||
|
||||
async def test_empty_dataframe(self):
|
||||
component = BatchRunComponent(model=MockLanguageModel(), df=DataFrame({"text": []}), column_name="text")
|
||||
component = BatchRunComponent(
|
||||
model=MockLanguageModel(),
|
||||
df=DataFrame({"text": []}),
|
||||
column_name="text",
|
||||
enable_metadata=True,
|
||||
)
|
||||
|
||||
result = await component.run_batch()
|
||||
assert isinstance(result, DataFrame)
|
||||
assert len(result) == 0
|
||||
|
||||
async def test_non_string_column_conversion(self):
|
||||
test_df = DataFrame(
|
||||
{
|
||||
"text": [123, 456, 789] # Numeric values
|
||||
}
|
||||
)
|
||||
test_df = DataFrame({"text": [123, 456, 789]}) # Numeric values
|
||||
|
||||
component = BatchRunComponent(model=MockLanguageModel(), df=test_df, column_name="text")
|
||||
component = BatchRunComponent(
|
||||
model=MockLanguageModel(),
|
||||
df=test_df,
|
||||
column_name="text",
|
||||
enable_metadata=True,
|
||||
)
|
||||
|
||||
result = await component.run_batch()
|
||||
|
||||
assert isinstance(result, DataFrame)
|
||||
assert all(isinstance(text, str) for text in result["text_input"])
|
||||
assert all(str(num) in text for num, text in zip(test_df["text"], result["text_input"], strict=False))
|
||||
result_dicts = result.to_dict("records")
|
||||
assert all(row["metadata"]["processing_status"] == "success" for row in result_dicts)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue