fix: add file format validation to component save-to-file (#7593)

* add _check_file_format

* [autofix.ci] apply automated fixes

* change to __adjust_file_path_with_format

* [autofix.ci] apply automated fixes

* Refactor and enhance tests for _adjust_file_path_with_format method

- Added parameterized tests to verify correct file extension handling for various formats.
- Ensured existing extensions are preserved and incorrect extensions are handled appropriately.
- Included a test for expanding the home directory symbol '~' in file paths.
- Removed outdated tests related to _check_file_format method.

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Gustavo Costa 2025-04-14 14:44:36 -03:00 committed by GitHub
commit dde91e2581
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 83 additions and 1 deletions

View file

@ -103,6 +103,8 @@ class SaveToFileComponent(Component):
if not file_path.parent.exists():
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path = self._adjust_file_path_with_format(file_path, file_format)
if input_type == "DataFrame":
dataframe = self.df
return self._save_dataframe(dataframe, file_path, file_format)
@ -116,6 +118,14 @@ class SaveToFileComponent(Component):
error_msg = f"Unsupported input type: {input_type}"
raise ValueError(error_msg)
def _adjust_file_path_with_format(self, path: Path, fmt: str) -> Path:
file_extension = path.suffix.lower().lstrip(".")
if fmt == "excel":
return Path(f"{path}.xlsx").expanduser() if file_extension not in ["xlsx", "xls"] else path
return Path(f"{path}.{fmt}").expanduser() if file_extension != fmt else path
def _save_dataframe(self, dataframe: DataFrame, path: Path, fmt: str) -> str:
if fmt == "csv":
dataframe.to_csv(path, index=False)

File diff suppressed because one or more lines are too long

View file

@ -163,3 +163,75 @@ class TestSaveToFileComponent(ComponentTestBaseWithoutClient):
with pytest.raises(ValueError, match="Unsupported input type"):
component.save_to_file()
@pytest.mark.parametrize(
("path_str", "fmt", "expected_suffix"),
[
("./test_output", "csv", ".csv"),
("./test_output", "json", ".json"),
("./test_output", "markdown", ".markdown"),
("./test_output", "txt", ".txt"),
],
)
def test_adjust_path_adds_extension(self, component_class, path_str, fmt, expected_suffix):
"""Test that the correct extension is added when none exists."""
component = component_class()
input_path = Path(path_str)
expected_path = Path(f"{path_str}{expected_suffix}")
result = component._adjust_file_path_with_format(input_path, fmt)
assert str(result) == str(expected_path.expanduser())
@pytest.mark.parametrize(
("path_str", "fmt"),
[
("./test_output.csv", "csv"),
("./test_output.json", "json"),
("./test_output.markdown", "markdown"),
("./test_output.txt", "txt"),
],
)
def test_adjust_path_keeps_existing_correct_extension(self, component_class, path_str, fmt):
"""Test that the existing correct extension is kept."""
component = component_class()
input_path = Path(path_str)
result = component._adjust_file_path_with_format(input_path, fmt)
assert str(result) == str(input_path.expanduser())
@pytest.mark.parametrize(
("path_str", "fmt", "expected_path_str"),
[
("./test_output.txt", "csv", "./test_output.txt.csv"), # Incorrect extension
("./test_output", "excel", "./test_output.xlsx"), # Add .xlsx for excel
("./test_output.txt", "excel", "./test_output.txt.xlsx"), # Incorrect extension for excel
],
)
def test_adjust_path_handles_incorrect_or_excel_add(self, component_class, path_str, fmt, expected_path_str):
"""Test handling incorrect extensions and adding .xlsx for excel."""
component = component_class()
input_path = Path(path_str)
expected_path = Path(expected_path_str)
result = component._adjust_file_path_with_format(input_path, fmt)
assert str(result) == str(expected_path.expanduser())
@pytest.mark.parametrize(
"path_str",
[
"./test_output.xlsx",
"./test_output.xls",
],
)
def test_adjust_path_keeps_existing_excel_extension(self, component_class, path_str):
"""Test that existing .xlsx or .xls extensions are kept for excel format."""
component = component_class()
input_path = Path(path_str)
result = component._adjust_file_path_with_format(input_path, "excel")
assert str(result) == str(input_path.expanduser())
def test_adjust_path_expands_home(self, component_class):
"""Test that the home directory symbol '~' is expanded."""
component = component_class()
input_path = Path("~/test_output")
expected_path = Path("~/test_output.csv").expanduser()
result = component._adjust_file_path_with_format(input_path, "csv")
assert str(result) == str(expected_path)
assert "~" not in str(result) # Ensure ~ was expanded