From f10d4dadd41fb4688e2d53cc4728e87d5b5712eb Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Mon, 30 Sep 2024 13:06:22 -0700
Subject: [PATCH] [BUG] Writes from empty partitions should return empty micropartitions with non-null schema (#2952)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

If one partition is empty, the write still returns a list of file paths / partition columns, but their data type is Null. This is problematic because it causes a schema mismatch with the results from partitions that did have writes.

```
import daft

df = (
    daft.from_pydict({"foo": [1, 2, 3], "bar": ["a", "b", "c"]})
    .into_partitions(4)
    .write_parquet("z", partition_cols=["bar"])
)
print(df)

daft.exceptions.DaftCoreException: DaftError::SchemaMismatch MicroPartition concat requires all schemas to match,
╭─────────────┬──────╮
│ Column Name ┆ Type │
╞═════════════╪══════╡
│ path        ┆ Utf8 │
├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ bar         ┆ Utf8 │
╰─────────────┴──────╯
vs
╭─────────────┬──────╮
│ Column Name ┆ Type │
╞═════════════╪══════╡
│ path        ┆ Null │
╰─────────────┴──────╯
```

---------

Co-authored-by: Colin Ho
---
 daft/iceberg/iceberg_write.py           |  7 ++++-
 daft/table/table_io.py                  | 22 +++++++++-----
 tests/cookbook/test_write.py            | 40 +++++++++++++++++++++++++
 tests/io/delta_lake/test_table_write.py | 25 ++++++++++++++++
 tests/io/iceberg/test_iceberg_writes.py | 12 ++++++++
 5 files changed, 98 insertions(+), 8 deletions(-)

diff --git a/daft/iceberg/iceberg_write.py b/daft/iceberg/iceberg_write.py
index 0de4c950d8..9fda932db7 100644
--- a/daft/iceberg/iceberg_write.py
+++ b/daft/iceberg/iceberg_write.py
@@ -4,6 +4,8 @@
 from typing import TYPE_CHECKING, Any, Iterator, List, Tuple
 
 from daft import Expression, col
+from daft.datatype import DataType
+from daft.io.common import _get_schema_from_dict
 from daft.table import MicroPartition
 from daft.table.partitioning import PartitionedTable, partition_strings_to_path
 
@@ -211,7 +213,10 @@ def visitor(self, partition_record: "IcebergRecord") -> "IcebergWriteVisitors.Fi
         return self.FileVisitor(self, partition_record)
 
     def to_metadata(self) -> MicroPartition:
-        return MicroPartition.from_pydict({"data_file": self.data_files})
+        col_name = "data_file"
+        if len(self.data_files) == 0:
+            return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()}))
+        return MicroPartition.from_pydict({col_name: self.data_files})
 
 
 def partitioned_table_to_iceberg_iter(
diff --git a/daft/table/table_io.py b/daft/table/table_io.py
index ba07fab8a4..0f892534d9 100644
--- a/daft/table/table_io.py
+++ b/daft/table/table_io.py
@@ -22,6 +22,7 @@
     PythonStorageConfig,
     StorageConfig,
 )
+from daft.datatype import DataType
 from daft.dependencies import pa, pacsv, pads, pajson, pq
 from daft.expressions import ExpressionsProjection, col
 from daft.filesystem import (
@@ -29,6 +30,7 @@
     canonicalize_protocol,
     get_protocol_from_path,
 )
+from daft.io.common import _get_schema_from_dict
 from daft.logical.schema import Schema
 from daft.runners.partitioning import (
     TableParseCSVOptions,
@@ -426,16 +428,22 @@ def __call__(self, written_file):
             self.parent.paths.append(written_file.path)
             self.parent.partition_indices.append(self.idx)
 
-    def __init__(self, partition_values: MicroPartition | None, path_key: str = "path"):
+    def __init__(self, partition_values: MicroPartition | None, schema: Schema):
         self.paths: list[str] = []
         self.partition_indices: list[int] = []
         self.partition_values = partition_values
-        self.path_key = path_key
+        self.path_key = schema.column_names()[
+            0
+        ]  # I kept this from our original code, but idk why it's the first column name -kevin
+        self.schema = schema
 
     def visitor(self, partition_idx: int) -> TabularWriteVisitors.FileVisitor:
         return self.FileVisitor(self, partition_idx)
 
     def to_metadata(self) -> MicroPartition:
+        if len(self.paths) == 0:
+            return MicroPartition.empty(self.schema)
+
         metadata: dict[str, Any] = {self.path_key: self.paths}
 
         if self.partition_values:
@@ -488,10 +496,7 @@ def write_tabular(
 
     partitioned = PartitionedTable(table, partition_cols)
 
-    # I kept this from our original code, but idk why it's the first column name -kevin
-    path_key = schema.column_names()[0]
-
-    visitors = TabularWriteVisitors(partitioned.partition_values(), path_key)
+    visitors = TabularWriteVisitors(partitioned.partition_values(), schema)
 
     for i, (part_table, part_path) in enumerate(partitioned_table_to_hive_iter(partitioned, resolved_path)):
         size_bytes = part_table.nbytes
@@ -686,7 +691,10 @@ def visitor(self, partition_values: dict[str, str | None]) -> DeltaLakeWriteVisi
         return self.FileVisitor(self, partition_values)
 
     def to_metadata(self) -> MicroPartition:
-        return MicroPartition.from_pydict({"add_action": self.add_actions})
+        col_name = "add_action"
+        if len(self.add_actions) == 0:
+            return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()}))
+        return MicroPartition.from_pydict({col_name: self.add_actions})
 
 
 def write_deltalake(
diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py
index 46db61d47e..ddd2c9b040 100644
--- a/tests/cookbook/test_write.py
+++ b/tests/cookbook/test_write.py
@@ -199,6 +199,26 @@ def test_parquet_write_multifile_with_partitioning(tmp_path, smaller_parquet_tar
     assert readback["y"] == [y % 2 for y in data["x"]]
 
 
+def test_parquet_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path)
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict()
+    assert read_back == data
+
+
+def test_parquet_partitioned_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path, partition_cols=["x"])
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict()
+    assert read_back == data
+
+
 def test_csv_write(tmp_path):
     df = daft.read_csv(COOKBOOK_DATA_CSV)
 
@@ -262,3 +282,23 @@ def test_empty_csv_write_with_partitioning(tmp_path):
 
     assert len(pd_df) == 1
     assert len(pd_df._preview.preview_partition) == 1
+
+
+def test_csv_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path)
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict()
+    assert read_back == data
+
+
+def test_csv_partitioned_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path, partition_cols=["x"])
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict()
+    assert read_back == data
diff --git a/tests/io/delta_lake/test_table_write.py b/tests/io/delta_lake/test_table_write.py
index 6519e85d0f..7a65d835cb 100644
--- a/tests/io/delta_lake/test_table_write.py
+++ b/tests/io/delta_lake/test_table_write.py
@@ -185,6 +185,21 @@ def test_deltalake_write_ignore(tmp_path):
     assert read_delta.to_pyarrow_table() == df1.to_arrow()
 
 
+def test_deltalake_write_with_empty_partition(tmp_path, base_table):
+    deltalake = pytest.importorskip("deltalake")
+    path = tmp_path / "some_table"
+    df = daft.from_arrow(base_table).into_partitions(4)
+    result = df.write_deltalake(str(path))
+    result = result.to_pydict()
+    assert result["operation"] == ["ADD", "ADD", "ADD"]
+    assert result["rows"] == [1, 1, 1]
+
+    read_delta = deltalake.DeltaTable(str(path))
+    expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow())
+    assert df.schema() == expected_schema
+    assert read_delta.to_pyarrow_table() == base_table
+
+
 def check_equal_both_daft_and_delta_rs(df: daft.DataFrame, path: Path, sort_order: list[tuple[str, str]]):
     deltalake = pytest.importorskip("deltalake")
 
@@ -256,6 +271,16 @@ def test_deltalake_write_partitioned_empty(tmp_path):
     check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")])
 
 
+def test_deltalake_write_partitioned_some_empty(tmp_path):
+    path = tmp_path / "some_table"
+
+    df = daft.from_pydict({"int": [1, 2, 3, None], "string": ["foo", "foo", "bar", None]}).into_partitions(5)
+
+    df.write_deltalake(str(path), partition_cols=["int"])
+
+    check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")])
+
+
 def test_deltalake_write_partitioned_existing_table(tmp_path):
     path = tmp_path / "some_table"
 
diff --git a/tests/io/iceberg/test_iceberg_writes.py b/tests/io/iceberg/test_iceberg_writes.py
index aab68a0c5c..7f85447dba 100644
--- a/tests/io/iceberg/test_iceberg_writes.py
+++ b/tests/io/iceberg/test_iceberg_writes.py
@@ -209,6 +209,18 @@ def test_read_after_write_nested_fields(local_catalog):
     assert as_arrow == read_back.to_arrow()
 
 
+def test_read_after_write_with_empty_partition(local_catalog):
+    df = daft.from_pydict({"x": [1, 2, 3]}).into_partitions(4)
+    as_arrow = df.to_arrow()
+    table = local_catalog.create_table("default.test", as_arrow.schema)
+    result = df.write_iceberg(table)
+    as_dict = result.to_pydict()
+    assert as_dict["operation"] == ["ADD", "ADD", "ADD"]
+    assert as_dict["rows"] == [1, 1, 1]
+    read_back = daft.read_iceberg(table)
+    assert as_arrow == read_back.to_arrow()
+
+
 @pytest.fixture
 def complex_table() -> tuple[pa.Table, Schema]:
     table = pa.table(
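For readers skimming the diff: the Iceberg and Delta Lake writers' `to_metadata` apply the same pattern (the tabular writer instead returns `MicroPartition.empty(self.schema)` with the table's own schema). A self-contained sketch of that pattern is below. It is illustrative only and not part of the patch: the helper name `metadata_or_empty` is made up here, but the calls (`MicroPartition.empty`, `MicroPartition.from_pydict`, `_get_schema_from_dict`, `DataType.python`) are the ones the diff introduces.

```
# Illustrative sketch (not part of the patch): build write metadata so that an
# empty partition still yields an explicitly typed, zero-row column instead of
# a Null-typed column inferred from an empty list.
from typing import Any, List

from daft.datatype import DataType
from daft.io.common import _get_schema_from_dict
from daft.table import MicroPartition


def metadata_or_empty(col_name: str, values: List[Any]) -> MicroPartition:
    if len(values) == 0:
        # Pin the column type for the empty case so it matches non-empty partitions.
        return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()}))
    return MicroPartition.from_pydict({col_name: values})
```

With this shape, a partition that wrote nothing contributes a zero-row table whose `data_file` / `add_action` column still has the expected type, so concatenating write results across partitions no longer raises `DaftError::SchemaMismatch`.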