fix: add file format validation to component save-to-file (#7593)
* add _check_file_format * [autofix.ci] apply automated fixes * change to __adjust_file_path_with_format * [autofix.ci] apply automated fixes * Refactor and enhance tests for _adjust_file_path_with_format method - Added parameterized tests to verify correct file extension handling for various formats. - Ensured existing extensions are preserved and incorrect extensions are handled appropriately. - Included a test for expanding the home directory symbol '~' in file paths. - Removed outdated tests related to _check_file_format method. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
90128ca4e3
commit
dde91e2581
3 changed files with 83 additions and 1 deletions
|
|
@ -103,6 +103,8 @@ class SaveToFileComponent(Component):
|
|||
if not file_path.parent.exists():
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = self._adjust_file_path_with_format(file_path, file_format)
|
||||
|
||||
if input_type == "DataFrame":
|
||||
dataframe = self.df
|
||||
return self._save_dataframe(dataframe, file_path, file_format)
|
||||
|
|
@ -116,6 +118,14 @@ class SaveToFileComponent(Component):
|
|||
error_msg = f"Unsupported input type: {input_type}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _adjust_file_path_with_format(self, path: Path, fmt: str) -> Path:
|
||||
file_extension = path.suffix.lower().lstrip(".")
|
||||
|
||||
if fmt == "excel":
|
||||
return Path(f"{path}.xlsx").expanduser() if file_extension not in ["xlsx", "xls"] else path
|
||||
|
||||
return Path(f"{path}.{fmt}").expanduser() if file_extension != fmt else path
|
||||
|
||||
def _save_dataframe(self, dataframe: DataFrame, path: Path, fmt: str) -> str:
|
||||
if fmt == "csv":
|
||||
dataframe.to_csv(path, index=False)
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -163,3 +163,75 @@ class TestSaveToFileComponent(ComponentTestBaseWithoutClient):
|
|||
|
||||
with pytest.raises(ValueError, match="Unsupported input type"):
|
||||
component.save_to_file()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("path_str", "fmt", "expected_suffix"),
|
||||
[
|
||||
("./test_output", "csv", ".csv"),
|
||||
("./test_output", "json", ".json"),
|
||||
("./test_output", "markdown", ".markdown"),
|
||||
("./test_output", "txt", ".txt"),
|
||||
],
|
||||
)
|
||||
def test_adjust_path_adds_extension(self, component_class, path_str, fmt, expected_suffix):
|
||||
"""Test that the correct extension is added when none exists."""
|
||||
component = component_class()
|
||||
input_path = Path(path_str)
|
||||
expected_path = Path(f"{path_str}{expected_suffix}")
|
||||
result = component._adjust_file_path_with_format(input_path, fmt)
|
||||
assert str(result) == str(expected_path.expanduser())
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("path_str", "fmt"),
|
||||
[
|
||||
("./test_output.csv", "csv"),
|
||||
("./test_output.json", "json"),
|
||||
("./test_output.markdown", "markdown"),
|
||||
("./test_output.txt", "txt"),
|
||||
],
|
||||
)
|
||||
def test_adjust_path_keeps_existing_correct_extension(self, component_class, path_str, fmt):
|
||||
"""Test that the existing correct extension is kept."""
|
||||
component = component_class()
|
||||
input_path = Path(path_str)
|
||||
result = component._adjust_file_path_with_format(input_path, fmt)
|
||||
assert str(result) == str(input_path.expanduser())
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("path_str", "fmt", "expected_path_str"),
|
||||
[
|
||||
("./test_output.txt", "csv", "./test_output.txt.csv"), # Incorrect extension
|
||||
("./test_output", "excel", "./test_output.xlsx"), # Add .xlsx for excel
|
||||
("./test_output.txt", "excel", "./test_output.txt.xlsx"), # Incorrect extension for excel
|
||||
],
|
||||
)
|
||||
def test_adjust_path_handles_incorrect_or_excel_add(self, component_class, path_str, fmt, expected_path_str):
|
||||
"""Test handling incorrect extensions and adding .xlsx for excel."""
|
||||
component = component_class()
|
||||
input_path = Path(path_str)
|
||||
expected_path = Path(expected_path_str)
|
||||
result = component._adjust_file_path_with_format(input_path, fmt)
|
||||
assert str(result) == str(expected_path.expanduser())
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path_str",
|
||||
[
|
||||
"./test_output.xlsx",
|
||||
"./test_output.xls",
|
||||
],
|
||||
)
|
||||
def test_adjust_path_keeps_existing_excel_extension(self, component_class, path_str):
|
||||
"""Test that existing .xlsx or .xls extensions are kept for excel format."""
|
||||
component = component_class()
|
||||
input_path = Path(path_str)
|
||||
result = component._adjust_file_path_with_format(input_path, "excel")
|
||||
assert str(result) == str(input_path.expanduser())
|
||||
|
||||
def test_adjust_path_expands_home(self, component_class):
|
||||
"""Test that the home directory symbol '~' is expanded."""
|
||||
component = component_class()
|
||||
input_path = Path("~/test_output")
|
||||
expected_path = Path("~/test_output.csv").expanduser()
|
||||
result = component._adjust_file_path_with_format(input_path, "csv")
|
||||
assert str(result) == str(expected_path)
|
||||
assert "~" not in str(result) # Ensure ~ was expanded
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue