add empty test cases
samster25 committed Jan 23, 2024
1 parent 80f2823 commit fa1a8cc
Showing 2 changed files with 59 additions and 8 deletions.
19 changes: 11 additions & 8 deletions daft/table/table_io.py
@@ -27,6 +27,7 @@
     PythonStorageConfig,
     StorageConfig,
 )
+from daft.datatype import DataType
 from daft.expressions import ExpressionsProjection
 from daft.filesystem import _resolve_paths_and_filesystem
 from daft.logical.schema import Schema
@@ -385,10 +386,7 @@ def write_tabular(
 
     visited_paths = []
     visited_sizes = []
-
-    def file_visitor(written_file):
-        visited_paths.append(written_file.path)
-        visited_sizes.append(written_file.size)
+    partition_idx = []
 
     execution_config = get_context().daft_execution_config
 
@@ -408,7 +406,7 @@ def file_visitor(written_file):
     else:
         raise ValueError(f"Unsupported file format {file_format}")
 
-    for tab, pf in zip(tables_to_write, part_keys_postfix_per_table):
+    for i, (tab, pf) in enumerate(zip(tables_to_write, part_keys_postfix_per_table)):
         full_path = resolved_path
         if pf is not None and len(pf) > 0:
             full_path = f"{full_path}/{pf}"
@@ -426,7 +424,12 @@
         rows_per_file = math.ceil(num_rows / target_num_files)
 
         target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1)
-        rows_per_row_group = math.ceil(num_rows / target_row_groups)
+        rows_per_row_group = min(math.ceil(num_rows / target_row_groups), rows_per_file)
+
+        def file_visitor(written_file):
+            visited_paths.append(written_file.path)
+            visited_sizes.append(written_file.size)
+            partition_idx.append(i)
 
         pads.write_dataset(
             arrow_table,
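
The key move in this hunk is that file_visitor now lives inside the per-partition loop: each iteration defines a fresh closure over i, and because pads.write_dataset invokes the visitor synchronously while that iteration is still running, every written file is tagged with the index of the partition that produced it. (The min() on rows_per_row_group also caps a row group at the per-file row count, so a row group is never asked to span more rows than a single file holds.) A minimal sketch of the visitor pattern, assuming a pyarrow new enough to expose WrittenFile.size and a hypothetical /tmp output path:

import pyarrow as pa
import pyarrow.dataset as pads

visited_paths, visited_sizes, partition_idx = [], [], []
tables = [pa.table({"x": [1, 2]}), pa.table({"x": [3, 4]})]

for i, tab in enumerate(tables):
    # A new closure per iteration; write_dataset calls it synchronously for
    # every file it writes, so `i` still names the partition being written.
    def file_visitor(written_file):
        visited_paths.append(written_file.path)
        visited_sizes.append(written_file.size)
        partition_idx.append(i)

    pads.write_dataset(
        tab,
        base_dir=f"/tmp/out/part={i}",  # hypothetical output location
        format="parquet",
        file_visitor=file_visitor,
        existing_data_behavior="overwrite_or_ignore",
    )

# partition_idx now holds one entry per written file, e.g. [0, 1].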
@@ -447,7 +450,7 @@
     data_dict: dict[str, Any] = {schema.column_names()[0]: visited_paths, schema.column_names()[1]: visited_sizes}
 
     if partition_values is not None:
+        partition_idx_series = Series.from_pylist(partition_idx).cast(DataType.int64())
         for c_name in partition_values.column_names():
-            data_dict[c_name] = partition_values.get_column(c_name)
-
+            data_dict[c_name] = partition_values.get_column(c_name).take(partition_idx_series)
     return MicroPartition.from_pydict(data_dict)
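
This is the actual fix for empty and multi-file writes: previously each partition-value column was copied over wholesale, which only lines up when every partition writes exactly one file. With zero or several files per partition, the per-file partition_idx list is used as a gather index, so each visited-file row gets the partition values of the partition that produced it. A sketch of the same gather using plain pyarrow (daft's Series.take is mirrored here by ChunkedArray.take; the data is illustrative):

import pyarrow as pa

visited_paths = ["b=Bronx/f0.parquet", "b=Bronx/f1.parquet", "b=Queens/f0.parquet"]
partition_idx = [0, 0, 1]  # which partition produced each file
partition_values = pa.table({"Borough": ["Bronx", "Queens"]})  # one row per partition

# Expanding one-row-per-partition values to one-row-per-file is a take/gather:
borough_per_file = partition_values["Borough"].take(pa.array(partition_idx, type=pa.int64()))
print(borough_per_file.to_pylist())  # ['Bronx', 'Bronx', 'Queens']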
48 changes: 48 additions & 0 deletions tests/cookbook/test_write.py
@@ -34,6 +34,22 @@ def test_parquet_write_with_partitioning(tmp_path):
     assert len(pd_df._preview.preview_partition) == 5
 
 
+def test_empty_parquet_write_without_partitioning(tmp_path):
+    df = daft.read_csv(COOKBOOK_DATA_CSV)
+    df = df.where(daft.lit(False))
+    output_files = df.write_parquet(tmp_path)
+    assert len(output_files) == 1
+    assert len(output_files._preview.preview_partition) == 1
+
+
+def test_empty_parquet_write_with_partitioning(tmp_path):
+    df = daft.read_csv(COOKBOOK_DATA_CSV)
+    df = df.where(daft.lit(False))
+    output_files = df.write_parquet(tmp_path, partition_cols=["Borough"])
+    assert len(output_files) == 0
+    assert len(output_files._preview.preview_partition) == 0
+
+
 def test_parquet_write_with_partitioning_readback_values(tmp_path):
     df = daft.read_csv(COOKBOOK_DATA_CSV)
 
@@ -66,6 +82,38 @@ def test_parquet_write_with_null_values(tmp_path):
     assert readback.to_pydict() == {"x": [1, 2, 3, None], "y": [1, 2, 3, None]}
 
 
+@pytest.mark.skipif(
+    not PYARROW_GE_11_0_0,
+    reason="We only use pyarrow datasets 11 for this test",
+)
+def test_parquet_write_multifile(tmp_path):
+    daft.set_execution_config(parquet_target_filesize=1024)
+    data = {"x": list(range(1_000))}
+    df = daft.from_pydict(data)
+    df2 = df.write_parquet(tmp_path)
+    assert len(df2) > 1
+    ds = pads.dataset(tmp_path, format="parquet")
+    readback = ds.to_table()
+    assert readback.to_pydict() == data
+
+
+@pytest.mark.skipif(
+    not PYARROW_GE_11_0_0,
+    reason="We only use pyarrow datasets 11 for this test",
+)
+def test_parquet_write_multifile_with_partitioning(tmp_path):
+    daft.set_execution_config(parquet_target_filesize=1024)
+    data = {"x": list(range(1_000))}
+    df = daft.from_pydict(data)
+    df2 = df.write_parquet(tmp_path, partition_cols=[df["x"].alias("y") % 2])
+    assert len(df2) >= 4
+    ds = pads.dataset(tmp_path, format="parquet", partitioning=pads.HivePartitioning(pa.schema([("y", pa.int64())])))
+    readback = ds.to_table()
+    readback = readback.sort_by("x").to_pydict()
+    assert readback["x"] == data["x"]
+    assert readback["y"] == [y % 2 for y in data["x"]]
+
+
 def test_csv_write(tmp_path):
     df = daft.read_csv(COOKBOOK_DATA_CSV)
 
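Why len(df2) > 1 is a safe assertion in the multifile tests above: 1,000 int64 values are roughly 8 KB of raw data, so a 1,024-byte parquet_target_filesize should fan the write out across several files. The exact count depends on parquet encoding and the writer's inflation estimate, so back-of-envelope only, under those assumptions:

import math

num_rows = 1_000
approx_bytes = num_rows * 8          # ~8 KB of raw int64 data
target_filesize = 1024               # parquet_target_filesize from the test
target_num_files = math.ceil(approx_bytes / target_filesize)  # ~8 files
assert target_num_files > 1          # hence len(df2) > 1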
