[FEAT] Overwrite mode for write parquet/csv (#3108)
Addresses: #3112 and
#1768

Implements overwrite mode for `write_parquet` and `write_csv`.

Upon finishing the write, we are left with a manifest of written file
paths. We can use this to perform a `delete all files not in manifest`
step (see the sketch after this list) by:
1. Doing an `ls` to find all files currently in the root dir.
2. Using Daft's built-in `is_in` expression to compute the file paths to delete.
3. Deleting them.
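Roughly, that flow looks like the following minimal sketch (illustrative only, not the committed implementation: `delete_files_not_in_manifest`, `root_dir`, and `written_paths` are hypothetical names, and a local filesystem is assumed):

```python
import fsspec

import daft
from daft import col


def delete_files_not_in_manifest(root_dir: str, written_paths: list[str]) -> None:
    # Hypothetical helper sketching the overwrite cleanup, not Daft's internals.
    fs = fsspec.filesystem("file")
    # 1. `ls` the root dir recursively for every file currently present.
    current_paths = fs.find(root_dir)
    # 2. Use Daft's `is_in` expression to keep only paths NOT in the manifest.
    df = daft.from_pydict({"path": current_paths})
    to_delete = df.where(~col("path").is_in(written_paths)).to_pydict()["path"]
    # 3. Bulk-delete the stale files in a single `rm` call.
    if to_delete:
        fs.rm(to_delete)
```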

Notes:
- Relies on fsspec for the `ls` and `rm` functionality. This is favored
over the pyarrow filesystem because fsspec's `rm` is a **bulk** delete
method, i.e. the delete can happen in a single API call. The pyarrow
filesystem has no bulk deletes.
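To illustrate the difference (a hedged local-filesystem demo; the file names are hypothetical):

```python
import pathlib

import fsspec
import pyarrow.fs as pafs

# Hypothetical leftover files, created here only so the demo runs.
tmp = pathlib.Path("tmp_overwrite_demo").resolve()
tmp.mkdir(exist_ok=True)
paths = [str(tmp / name) for name in ("a.parquet", "b.parquet")]
for p in paths:
    pathlib.Path(p).touch()

# fsspec: `rm` accepts a list of paths, so the cleanup is one bulk call.
fsspec.filesystem("file").rm(paths)

# pyarrow.fs exposes only per-file deletes, i.e. one call per path.
local_fs = pafs.LocalFileSystem()
for p in paths:
    pathlib.Path(p).touch()  # recreate so this second demo has files to delete
    local_fs.delete_file(p)
```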

---------

Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
3 people authored Nov 6, 2024
1 parent d8cdf36 commit 0d669ca
Showing 3 changed files with 192 additions and 1 deletion.
17 changes: 17 additions & 0 deletions daft/dataframe/dataframe.py
@@ -37,6 +37,7 @@
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
from daft.expressions import Expression, ExpressionsProjection, col, lit
from daft.filesystem import overwrite_files
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import LocalPartitionSet, PartitionCacheEntry, PartitionSet
from daft.table import MicroPartition
@@ -513,6 +514,7 @@ def write_parquet(
self,
root_dir: Union[str, pathlib.Path],
compression: str = "snappy",
write_mode: Union[Literal["append"], Literal["overwrite"]] = "append",
partition_cols: Optional[List[ColumnInputType]] = None,
io_config: Optional[IOConfig] = None,
) -> "DataFrame":
@@ -526,6 +528,7 @@ def write_parquet(
Args:
root_dir (str): root file path to write parquet files to.
compression (str, optional): compression algorithm. Defaults to "snappy".
write_mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace the table with new data. Defaults to "append".
partition_cols (Optional[List[ColumnInputType]], optional): How to subpartition each partition further. Defaults to None.
io_config (Optional[IOConfig], optional): configurations to use when interacting with remote storage.
@@ -535,6 +538,9 @@ def write_parquet(
.. NOTE::
This call is **blocking** and will execute the DataFrame when called
"""
if write_mode not in ["append", "overwrite"]:
raise ValueError(f"Only support `append` or `overwrite` mode. {write_mode} is unsupported")

io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config

cols: Optional[List[Expression]] = None
@@ -553,6 +559,9 @@ def write_parquet(
write_df.collect()
assert write_df._result is not None

if write_mode == "overwrite":
overwrite_files(write_df, root_dir, io_config)

if len(write_df) > 0:
# Populate and return a new disconnected DataFrame
result_df = DataFrame(write_df._builder)
@@ -577,6 +586,7 @@ def write_parquet(
def write_csv(
self,
root_dir: Union[str, pathlib.Path],
write_mode: Union[Literal["append"], Literal["overwrite"]] = "append",
partition_cols: Optional[List[ColumnInputType]] = None,
io_config: Optional[IOConfig] = None,
) -> "DataFrame":
@@ -589,12 +599,16 @@ def write_csv(
Args:
root_dir (str): root file path to write csv files to.
write_mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace the table with new data. Defaults to "append".
partition_cols (Optional[List[ColumnInputType]], optional): How to subpartition each partition further. Defaults to None.
io_config (Optional[IOConfig], optional): configurations to use when interacting with remote storage.
Returns:
DataFrame: The filenames that were written out as strings.
"""
if write_mode not in ["append", "overwrite"]:
raise ValueError(f"Only support `append` or `overwrite` mode. {write_mode} is unsupported")

io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config

cols: Optional[List[Expression]] = None
@@ -612,6 +626,9 @@ def write_csv(
write_df.collect()
assert write_df._result is not None

if write_mode == "overwrite":
overwrite_files(write_df, root_dir, io_config)

if len(write_df) > 0:
# Populate and return a new disconnected DataFrame
result_df = DataFrame(write_df._builder)
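For reference, a usage sketch of the new argument (output paths below are illustrative):

```python
import daft

df = daft.from_pydict({"a": ["x", "y"], "b": [1, 2]})

# Default remains "append": new files are added under the root dir.
df.write_parquet("out_parquet/")

# "overwrite" replaces whatever previously lived under the root dir.
df.write_parquet("out_parquet/", write_mode="overwrite")
df.write_csv("out_csv/", write_mode="overwrite")
```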
26 changes: 25 additions & 1 deletion daft/filesystem.py
@@ -6,12 +6,17 @@
import pathlib
import sys
import urllib.parse
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal

from daft.convert import from_pydict
from daft.daft import FileFormat, FileInfos, IOConfig, io_glob
from daft.dependencies import fsspec, pafs
from daft.expressions.expressions import col, lit
from daft.table import MicroPartition

if TYPE_CHECKING:
from daft import DataFrame

logger = logging.getLogger(__name__)

_CACHED_FSES: dict[tuple[str, IOConfig | None], pafs.FileSystem] = {}
@@ -353,3 +358,22 @@ def join_path(fs: pafs.FileSystem, base_path: str, *sub_paths: str) -> str:
return os.path.join(base_path, *sub_paths)
else:
return f"{base_path.rstrip('/')}/{'/'.join(sub_paths)}"


def overwrite_files(
manifest: DataFrame,
root_dir: str | pathlib.Path,
io_config: IOConfig | None,
) -> None:
[resolved_path], fs = _resolve_paths_and_filesystem(root_dir, io_config=io_config)
file_selector = pafs.FileSelector(resolved_path, recursive=True)
paths = [info.path for info in fs.get_file_info(file_selector) if info.type == pafs.FileType.File]
all_file_paths_df = from_pydict({"path": paths})

assert manifest._result is not None
written_file_paths = manifest._result._get_merged_micropartition().get_column("path")
to_delete = all_file_paths_df.where(~(col("path").is_in(lit(written_file_paths))))

# TODO: Look into parallelizing this
for entry in to_delete:
fs.delete_file(entry["path"])
150 changes: 150 additions & 0 deletions tests/io/test_write_modes.py
@@ -0,0 +1,150 @@
import uuid
from typing import List, Optional

import pytest
import s3fs

import daft
from daft import context

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor doesn't support writes yet",
)


def write(
df: daft.DataFrame,
path: str,
format: str,
write_mode: str,
partition_cols: Optional[List[str]] = None,
io_config: Optional[daft.io.IOConfig] = None,
):
if format == "parquet":
return df.write_parquet(
path,
write_mode=write_mode,
partition_cols=partition_cols,
io_config=io_config,
)
elif format == "csv":
return df.write_csv(
path,
write_mode=write_mode,
partition_cols=partition_cols,
io_config=io_config,
)
else:
raise ValueError(f"Unsupported format: {format}")


def read(path: str, format: str, io_config: Optional[daft.io.IOConfig] = None):
if format == "parquet":
return daft.read_parquet(path, io_config=io_config)
elif format == "csv":
return daft.read_csv(path, io_config=io_config)
else:
raise ValueError(f"Unsupported format: {format}")


def arrange_write_mode_test(existing_data, new_data, path, format, write_mode, partition_cols, io_config):
# Write some existing_data
write(existing_data, path, format, "append", partition_cols, io_config)

# Write some new data
write(new_data, path, format, write_mode, partition_cols, io_config)

# Read back the data
read_path = path + "/**" if partition_cols is not None else path
read_back = read(read_path, format, io_config).sort(["a", "b"]).to_pydict()

return read_back


@pytest.mark.parametrize("write_mode", ["append", "overwrite"])
@pytest.mark.parametrize("format", ["csv", "parquet"])
@pytest.mark.parametrize("num_partitions", [1, 2])
@pytest.mark.parametrize("partition_cols", [None, ["a"]])
def test_write_modes_local(tmp_path, write_mode, format, num_partitions, partition_cols):
path = str(tmp_path)
existing_data = {"a": ["a", "a", "b", "b"], "b": [1, 2, 3, 4]}
new_data = {
"a": ["a", "a", "b", "b"],
"b": [5, 6, 7, 8],
}

read_back = arrange_write_mode_test(
daft.from_pydict(existing_data).into_partitions(num_partitions),
daft.from_pydict(new_data).into_partitions(num_partitions),
path,
format,
write_mode,
partition_cols,
None,
)

# Check the data
if write_mode == "append":
assert read_back["a"] == ["a"] * 4 + ["b"] * 4
assert read_back["b"] == [1, 2, 5, 6, 3, 4, 7, 8]
elif write_mode == "overwrite":
assert read_back["a"] == ["a", "a", "b", "b"]
assert read_back["b"] == [5, 6, 7, 8]
else:
raise ValueError(f"Unsupported write_mode: {write_mode}")


@pytest.fixture(scope="function")
def bucket(minio_io_config):
BUCKET = "write-modes-bucket"

fs = s3fs.S3FileSystem(
key=minio_io_config.s3.key_id,
password=minio_io_config.s3.access_key,
client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url},
)
if not fs.exists(BUCKET):
fs.mkdir(BUCKET)
yield BUCKET


@pytest.mark.integration()
@pytest.mark.parametrize("write_mode", ["append", "overwrite"])
@pytest.mark.parametrize("format", ["csv", "parquet"])
@pytest.mark.parametrize("num_partitions", [1, 2])
@pytest.mark.parametrize("partition_cols", [None, ["a"]])
def test_write_modes_s3_minio(
minio_io_config,
bucket,
write_mode,
format,
num_partitions,
partition_cols,
):
path = f"s3://{bucket}/{str(uuid.uuid4())}"
existing_data = {"a": ["a", "a", "b", "b"], "b": [1, 2, 3, 4]}
new_data = {
"a": ["a", "a", "b", "b"],
"b": [5, 6, 7, 8],
}

read_back = arrange_write_mode_test(
daft.from_pydict(existing_data).into_partitions(num_partitions),
daft.from_pydict(new_data).into_partitions(num_partitions),
path,
format,
write_mode,
partition_cols,
minio_io_config,
)

# Check the data
if write_mode == "append":
assert read_back["a"] == ["a"] * 4 + ["b"] * 4
assert read_back["b"] == [1, 2, 5, 6, 3, 4, 7, 8]
elif write_mode == "overwrite":
assert read_back["a"] == ["a", "a", "b", "b"]
assert read_back["b"] == [5, 6, 7, 8]
else:
raise ValueError(f"Unsupported write_mode: {write_mode}")
