
only split files for rows with pyarrow 7+
samster25 committed Jan 23, 2024
1 parent fa1a8cc commit 83be7b2
Showing 2 changed files with 18 additions and 10 deletions.
14 changes: 11 additions & 3 deletions daft/table/table_io.py
@@ -355,6 +355,9 @@ def write_tabular(
     io_config: IOConfig | None = None,
     partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__",
 ) -> MicroPartition:
+
+    from daft.utils import ARROW_VERSION
+
     [resolved_path], fs = _resolve_paths_and_filesystem(path, io_config=io_config)
 
     tables_to_write: list[MicroPartition]
@@ -431,6 +434,13 @@ def file_visitor(written_file):
             visited_sizes.append(written_file.size)
             partition_idx.append(i)
 
+        kwargs = dict()
+
+        if ARROW_VERSION >= (7, 0, 0):
+            kwargs["max_rows_per_file"] = rows_per_file
+            kwargs["min_rows_per_group"] = rows_per_row_group
+            kwargs["max_rows_per_group"] = rows_per_row_group
+
         pads.write_dataset(
             arrow_table,
             base_dir=full_path,
@@ -442,9 +452,7 @@ def file_visitor(written_file):
             use_threads=True,
             existing_data_behavior="overwrite_or_ignore",
             filesystem=fs,
-            max_rows_per_file=rows_per_file,
-            min_rows_per_group=rows_per_row_group,
-            max_rows_per_group=rows_per_row_group,
+            **kwargs,
         )
 
     data_dict: dict[str, Any] = {schema.column_names()[0]: visited_paths, schema.column_names()[1]: visited_sizes}
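
Note on the change above: the row-splitting options (max_rows_per_file, min_rows_per_group, max_rows_per_group) are only accepted by pyarrow.dataset.write_dataset on pyarrow 7.0.0 and newer, hence the commit title; on older releases, passing them fails at call time. The fix builds the options dict conditionally and splats it into the call. A minimal standalone sketch of the same pattern, assuming daft.utils.ARROW_VERSION is a parsed version tuple like the one in the test file below (the table, path, and row counts are illustrative, not from the diff):

    import pyarrow as pa
    import pyarrow.dataset as pads

    # Assumed equivalent of daft.utils.ARROW_VERSION: pa.__version__
    # parsed into an integer tuple, e.g. "7.0.0" -> (7, 0, 0).
    ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric())

    kwargs = {}
    if ARROW_VERSION >= (7, 0, 0):
        # These keyword arguments only exist on pyarrow >= 7.0.0.
        kwargs["max_rows_per_file"] = 512_000
        kwargs["min_rows_per_group"] = 128_000
        kwargs["max_rows_per_group"] = 128_000

    pads.write_dataset(
        pa.table({"x": [1, 2, 3]}),
        base_dir="/tmp/example_out",  # illustrative output path
        format="parquet",
        existing_data_behavior="overwrite_or_ignore",
        **kwargs,
    )
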
14 changes: 7 additions & 7 deletions tests/cookbook/test_write.py
@@ -8,7 +8,7 @@
 from tests.conftest import assert_df_equals
 from tests.cookbook.assets import COOKBOOK_DATA_CSV
 
-PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0)
+PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0)
 
 
 def test_parquet_write(tmp_path):
@@ -71,8 +71,8 @@ def test_parquet_write_with_partitioning_readback_values(tmp_path):
 
 
 @pytest.mark.skipif(
-    not PYARROW_GE_11_0_0,
-    reason="We only use pyarrow datasets 11 for this test",
+    not PYARROW_GE_7_0_0,
+    reason="We only use pyarrow datasets 7 for this test",
 )
 def test_parquet_write_with_null_values(tmp_path):
     df = daft.from_pydict({"x": [1, 2, 3, None]})
@@ -83,8 +83,8 @@ def test_parquet_write_with_null_values(tmp_path):
 
 
 @pytest.mark.skipif(
-    not PYARROW_GE_11_0_0,
-    reason="We only use pyarrow datasets 11 for this test",
+    not PYARROW_GE_7_0_0,
+    reason="We only use pyarrow datasets 7 for this test",
 )
 def test_parquet_write_multifile(tmp_path):
     daft.set_execution_config(parquet_target_filesize=1024)
@@ -98,8 +98,8 @@ def test_parquet_write_multifile(tmp_path):
 
 
 @pytest.mark.skipif(
-    not PYARROW_GE_11_0_0,
-    reason="We only use pyarrow datasets 11 for this test",
+    not PYARROW_GE_7_0_0,
+    reason="We only use pyarrow datasets 7 for this test",
 )
 def test_parquet_write_multifile_with_partitioning(tmp_path):
     daft.set_execution_config(parquet_target_filesize=1024)
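
The skip guards above gate each test on the same parsed version tuple that the library change relies on. A small illustration of how that exact expression behaves (the version strings are examples, not from the diff):

    def parse_version(version: str) -> tuple:
        # Same expression used for PYARROW_GE_7_0_0 above, as a helper.
        return tuple(int(s) for s in version.split(".") if s.isnumeric())

    assert parse_version("7.0.0") >= (7, 0, 0)         # exactly 7.0.0 passes
    assert parse_version("11.0.1") >= (7, 0, 0)        # newer releases pass
    assert not parse_version("6.0.1") >= (7, 0, 0)     # older releases are skipped
    assert parse_version("8.0.0.dev123") == (8, 0, 0)  # isnumeric() drops "dev123"
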
