diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 8dbb33111e..52f0f7458e 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -37,6 +37,7 @@ from daft.datatype import DataType from daft.errors import ExpressionTypeError from daft.expressions import Expression, ExpressionsProjection, col, lit +from daft.filesystem import overwrite_files from daft.logical.builder import LogicalPlanBuilder from daft.runners.partitioning import LocalPartitionSet, PartitionCacheEntry, PartitionSet from daft.table import MicroPartition @@ -513,6 +514,7 @@ def write_parquet( self, root_dir: Union[str, pathlib.Path], compression: str = "snappy", + write_mode: Union[Literal["append"], Literal["overwrite"]] = "append", partition_cols: Optional[List[ColumnInputType]] = None, io_config: Optional[IOConfig] = None, ) -> "DataFrame": @@ -526,6 +528,7 @@ def write_parquet( Args: root_dir (str): root file path to write parquet files to. compression (str, optional): compression algorithm. Defaults to "snappy". + mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace table with new data. Defaults to "append". partition_cols (Optional[List[ColumnInputType]], optional): How to subpartition each partition further. Defaults to None. io_config (Optional[IOConfig], optional): configurations to use when interacting with remote storage. @@ -535,6 +538,9 @@ def write_parquet( .. NOTE:: This call is **blocking** and will execute the DataFrame when called """ + if write_mode not in ["append", "overwrite"]: + raise ValueError(f"Only support `append` or `overwrite` mode. {write_mode} is unsupported") + io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config cols: Optional[List[Expression]] = None @@ -553,6 +559,9 @@ def write_parquet( write_df.collect() assert write_df._result is not None + if write_mode == "overwrite": + overwrite_files(write_df, root_dir, io_config) + if len(write_df) > 0: # Populate and return a new disconnected DataFrame result_df = DataFrame(write_df._builder) @@ -577,6 +586,7 @@ def write_parquet( def write_csv( self, root_dir: Union[str, pathlib.Path], + write_mode: Union[Literal["append"], Literal["overwrite"]] = "append", partition_cols: Optional[List[ColumnInputType]] = None, io_config: Optional[IOConfig] = None, ) -> "DataFrame": @@ -589,12 +599,16 @@ def write_csv( Args: root_dir (str): root file path to write parquet files to. + write_mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace table with new data. Defaults to "append". partition_cols (Optional[List[ColumnInputType]], optional): How to subpartition each partition further. Defaults to None. io_config (Optional[IOConfig], optional): configurations to use when interacting with remote storage. Returns: DataFrame: The filenames that were written out as strings. """ + if write_mode not in ["append", "overwrite"]: + raise ValueError(f"Only support `append` or `overwrite` mode. {write_mode} is unsupported") + io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config cols: Optional[List[Expression]] = None @@ -612,6 +626,9 @@ def write_csv( write_df.collect() assert write_df._result is not None + if write_mode == "overwrite": + overwrite_files(write_df, root_dir, io_config) + if len(write_df) > 0: # Populate and return a new disconnected DataFrame result_df = DataFrame(write_df._builder) diff --git a/daft/filesystem.py b/daft/filesystem.py index 93ed1971fd..019405d32c 100644 --- a/daft/filesystem.py +++ b/daft/filesystem.py @@ -6,12 +6,17 @@ import pathlib import sys import urllib.parse -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal +from daft.convert import from_pydict from daft.daft import FileFormat, FileInfos, IOConfig, io_glob from daft.dependencies import fsspec, pafs +from daft.expressions.expressions import col, lit from daft.table import MicroPartition +if TYPE_CHECKING: + from daft import DataFrame + logger = logging.getLogger(__name__) _CACHED_FSES: dict[tuple[str, IOConfig | None], pafs.FileSystem] = {} @@ -353,3 +358,22 @@ def join_path(fs: pafs.FileSystem, base_path: str, *sub_paths: str) -> str: return os.path.join(base_path, *sub_paths) else: return f"{base_path.rstrip('/')}/{'/'.join(sub_paths)}" + + +def overwrite_files( + manifest: DataFrame, + root_dir: str | pathlib.Path, + io_config: IOConfig | None, +) -> None: + [resolved_path], fs = _resolve_paths_and_filesystem(root_dir, io_config=io_config) + file_selector = pafs.FileSelector(resolved_path, recursive=True) + paths = [info.path for info in fs.get_file_info(file_selector) if info.type == pafs.FileType.File] + all_file_paths_df = from_pydict({"path": paths}) + + assert manifest._result is not None + written_file_paths = manifest._result._get_merged_micropartition().get_column("path") + to_delete = all_file_paths_df.where(~(col("path").is_in(lit(written_file_paths)))) + + # TODO: Look into parallelizing this + for entry in to_delete: + fs.delete_file(entry["path"]) diff --git a/tests/io/test_write_modes.py b/tests/io/test_write_modes.py new file mode 100644 index 0000000000..dde38f89a2 --- /dev/null +++ b/tests/io/test_write_modes.py @@ -0,0 +1,150 @@ +import uuid +from typing import List, Optional + +import pytest +import s3fs + +import daft +from daft import context + +pytestmark = pytest.mark.skipif( + context.get_context().daft_execution_config.enable_native_executor is True, + reason="Native executor doesn't support writes yet", +) + + +def write( + df: daft.DataFrame, + path: str, + format: str, + write_mode: str, + partition_cols: Optional[List[str]] = None, + io_config: Optional[daft.io.IOConfig] = None, +): + if format == "parquet": + return df.write_parquet( + path, + write_mode=write_mode, + partition_cols=partition_cols, + io_config=io_config, + ) + elif format == "csv": + return df.write_csv( + path, + write_mode=write_mode, + partition_cols=partition_cols, + io_config=io_config, + ) + else: + raise ValueError(f"Unsupported format: {format}") + + +def read(path: str, format: str, io_config: Optional[daft.io.IOConfig] = None): + if format == "parquet": + return daft.read_parquet(path, io_config=io_config) + elif format == "csv": + return daft.read_csv(path, io_config=io_config) + else: + raise ValueError(f"Unsupported format: {format}") + + +def arrange_write_mode_test(existing_data, new_data, path, format, write_mode, partition_cols, io_config): + # Write some existing_data + write(existing_data, path, format, "append", partition_cols, io_config) + + # Write some new data + write(new_data, path, format, write_mode, partition_cols, io_config) + + # Read back the data + read_path = path + "/**" if partition_cols is not None else path + read_back = read(read_path, format, io_config).sort(["a", "b"]).to_pydict() + + return read_back + + +@pytest.mark.parametrize("write_mode", ["append", "overwrite"]) +@pytest.mark.parametrize("format", ["csv", "parquet"]) +@pytest.mark.parametrize("num_partitions", [1, 2]) +@pytest.mark.parametrize("partition_cols", [None, ["a"]]) +def test_write_modes_local(tmp_path, write_mode, format, num_partitions, partition_cols): + path = str(tmp_path) + existing_data = {"a": ["a", "a", "b", "b"], "b": [1, 2, 3, 4]} + new_data = { + "a": ["a", "a", "b", "b"], + "b": [5, 6, 7, 8], + } + + read_back = arrange_write_mode_test( + daft.from_pydict(existing_data).into_partitions(num_partitions), + daft.from_pydict(new_data).into_partitions(num_partitions), + path, + format, + write_mode, + partition_cols, + None, + ) + + # Check the data + if write_mode == "append": + assert read_back["a"] == ["a"] * 4 + ["b"] * 4 + assert read_back["b"] == [1, 2, 5, 6, 3, 4, 7, 8] + elif write_mode == "overwrite": + assert read_back["a"] == ["a", "a", "b", "b"] + assert read_back["b"] == [5, 6, 7, 8] + else: + raise ValueError(f"Unsupported write_mode: {write_mode}") + + +@pytest.fixture(scope="function") +def bucket(minio_io_config): + BUCKET = "write-modes-bucket" + + fs = s3fs.S3FileSystem( + key=minio_io_config.s3.key_id, + password=minio_io_config.s3.access_key, + client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url}, + ) + if not fs.exists(BUCKET): + fs.mkdir(BUCKET) + yield BUCKET + + +@pytest.mark.integration() +@pytest.mark.parametrize("write_mode", ["append", "overwrite"]) +@pytest.mark.parametrize("format", ["csv", "parquet"]) +@pytest.mark.parametrize("num_partitions", [1, 2]) +@pytest.mark.parametrize("partition_cols", [None, ["a"]]) +def test_write_modes_s3_minio( + minio_io_config, + bucket, + write_mode, + format, + num_partitions, + partition_cols, +): + path = f"s3://{bucket}/{str(uuid.uuid4())}" + existing_data = {"a": ["a", "a", "b", "b"], "b": [1, 2, 3, 4]} + new_data = { + "a": ["a", "a", "b", "b"], + "b": [5, 6, 7, 8], + } + + read_back = arrange_write_mode_test( + daft.from_pydict(existing_data).into_partitions(num_partitions), + daft.from_pydict(new_data).into_partitions(num_partitions), + path, + format, + write_mode, + partition_cols, + minio_io_config, + ) + + # Check the data + if write_mode == "append": + assert read_back["a"] == ["a"] * 4 + ["b"] * 4 + assert read_back["b"] == [1, 2, 5, 6, 3, 4, 7, 8] + elif write_mode == "overwrite": + assert read_back["a"] == ["a", "a", "b", "b"] + assert read_back["b"] == [5, 6, 7, 8] + else: + raise ValueError(f"Unsupported write_mode: {write_mode}")