Add daft roundtrip tests
desmondcheongzx committed Oct 28, 2024
1 parent 9e9cdf7 commit 8446076
Showing 1 changed file with 84 additions and 137 deletions.
221 changes: 84 additions & 137 deletions tests/io/test_hive_style_partitions.py
@@ -42,10 +42,10 @@ def unify_timestamp(table):
)


def assert_tables_equal(daft_df, pa_table):
def assert_tables_equal(daft_table, pa_table):
sorted_pa_table = pa_table.sort_by([("id", "ascending")]).select(SCHEMA.names)
sorted_pa_table = unify_timestamp(sorted_pa_table)
sorted_daft_table = daft_df.sort(daft.col("id")).to_arrow().combine_chunks().select(SCHEMA.names)
sorted_daft_table = daft_table.sort_by([("id", "ascending")]).select(SCHEMA.names)
sorted_daft_table = unify_timestamp(sorted_daft_table)
assert sorted_pa_table == sorted_daft_table
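
The updated helper takes two pyarrow Tables, so the equality check reduces to sorting both by "id", projecting the same columns, and comparing. A minimal sketch of that pattern (illustrative only, not part of this commit; the table contents are made up):

import pyarrow as pa

left = pa.table({"id": [2, 1], "val": ["b", "a"]})
right = pa.table({"id": [1, 2], "val": ["a", "b"]})

# Sort by "id" and project the same column order on both sides;
# pyarrow Table equality compares both schema and contents.
sorted_left = left.sort_by([("id", "ascending")]).select(["id", "val"])
sorted_right = right.sort_by([("id", "ascending")]).select(["id", "val"])
assert sorted_left == sorted_right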

@@ -56,9 +56,9 @@ def assert_tables_equal(daft_df, pa_table):
["str_col"],
["int_col"],
["date_col"],
# TODO(desmond): pyarrow does not conform to RFC 3339, so their timestamp format differs from
# ours. Specifically, their timestamps are output as `%Y-%m-%d %H:%M:%S%.f%:z` but we parse ours
# as %Y-%m-%dT%H:%M:%S%.f%:z.
# TODO(desmond): pyarrow does not conform to RFC 3339, so their timestamp format differs
# from ours. Specifically, their timestamps are output as `%Y-%m-%d %H:%M:%S%.f%:z` but we
# parse ours as %Y-%m-%dT%H:%M:%S%.f%:z.
# ['timestamp_col'],
["nullable_str"],
["nullable_int"],
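
The timestamp TODO above comes down to the separator between date and time: pyarrow renders partition values with a space, while RFC 3339 (the form Daft parses) uses a literal 'T'. A rough illustration in plain Python (the strftime patterns below only approximate the chrono-style format strings quoted in the comment):

from datetime import datetime, timezone

ts = datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc)

# pyarrow-style rendering: space between the date and time parts.
pyarrow_like = ts.strftime("%Y-%m-%d %H:%M:%S%z")   # '2024-01-01 12:30:00+0000'
# RFC 3339-style rendering: literal 'T' between the date and time parts.
rfc3339_like = ts.strftime("%Y-%m-%dT%H:%M:%S%z")   # '2024-01-01T12:30:00+0000'

assert pyarrow_like != rfc3339_like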
@@ -90,7 +90,14 @@ def test_hive_pyarrow_daft_compatibility(tmpdir, partition_by, file_format, filt
)

if file_format == "json":
daft_df = daft.read_json(glob_path, hive_partitioning=True)
daft_df = daft.read_json(
glob_path,
schema={
"nullable_str": daft.DataType.string(),
"nullable_int": daft.DataType.int64(),
},
hive_partitioning=True,
)
if file_format == "parquet":
daft_df = daft.read_parquet(
glob_path,
@@ -112,137 +119,77 @@
sample_value = SAMPLE_DATA[first_col][0].as_py()
daft_df = daft_df.where(daft.col(first_col) == sample_value)
pa_table = pa_ds.to_table(filter=ds.field(first_col) == sample_value)
assert_tables_equal(daft_df, pa_table)


# def test_hive_daft_write_read(tmpdir, sample_data):
# # Convert PyArrow table to Daft DataFrame
# daft_df = daft.from_pyarrow(sample_data)

# # Write using Daft with different partition columns
# partition_configs = [
# ["str_col"],
# ["int_col"],
# ["date_col"],
# ["timestamp_col"],
# ["nullable_str"],
# ["str_col", "int_col"]
# ]

# for partition_by in partition_configs:
# partition_dir = os.path.join(tmpdir, "_".join(partition_by))
# os.makedirs(partition_dir, exist_ok=True)

# # Write partitioned dataset using Daft
# daft_df.write_parquet(
# partition_dir,
# partition_cols=partition_by
# )

# # Read back with Daft
# read_df = daft.read_parquet(
# os.path.join(partition_dir, "**"),
# hive_partitioning=True
# )

# assert_tables_equal(read_df, sample_data)

# def test_null_handling(sample_data):
# with tempfile.TemporaryDirectory() as tmpdir:
# # Write dataset with nullable columns as partitions
# ds.write_dataset(
# sample_data,
# tmpdir,
# format="parquet",
# partitioning=ds.partitioning(
# pa.schema([
# sample_data.schema.field("nullable_str"),
# sample_data.schema.field("nullable_int")
# ]),
# flavor="hive"
# )
# )

# # Read back with Daft
# daft_df = daft.read_parquet(
# os.path.join(tmpdir, "**"),
# hive_partitioning=True
# )

# # Verify null values are handled correctly
# # PyArrow uses __HIVE_DEFAULT_PARTITION__ for nulls
# null_count_daft = daft_df.where(daft.col("nullable_str").is_null()).count_rows()
# expected_null_count = len([x for x in sample_data["nullable_str"] if x.is_null()])
# assert null_count_daft == expected_null_count


# @pytest.fixture(scope="session")
# def public_storage_io_config() -> daft.io.IOConfig:
# return daft.io.IOConfig(
# azure=daft.io.AzureConfig(storage_account="dafttestdata", anonymous=True),
# s3=daft.io.S3Config(region_name="us-west-2", anonymous=True),
# gcs=daft.io.GCSConfig(anonymous=True),
# )
assert_tables_equal(daft_df.to_arrow(), pa_table)


# def check_file(public_storage_io_config, read_fn, uri):
# # These tables are partitioned on id1 (utf8) and id4 (int64).
# df = read_fn(uri, hive_partitioning=True, io_config=public_storage_io_config)
# column_names = df.schema().column_names()
# assert "id1" in column_names
# assert "id4" in column_names
# assert len(column_names) == 9
# assert df.count_rows() == 100000
# # Test schema inference on partition columns.
# pa_schema = df.schema().to_pyarrow_schema()
# assert pa_schema.field("id1").type == pa.large_string()
# assert pa_schema.field("id4").type == pa.int64()
# # Test that schema hints work on partition columns.
# df = read_fn(
# uri,
# hive_partitioning=True,
# io_config=public_storage_io_config,
# schema={
# "id4": daft.DataType.int32(),
# },
# )
# pa_schema = df.schema().to_pyarrow_schema()
# assert pa_schema.field("id1").type == pa.large_string()
# assert pa_schema.field("id4").type == pa.int32()

# # Test selects on a partition column and a non-partition columns.
# df_select = df.select("id2", "id1")
# column_names = df_select.schema().column_names()
# assert "id1" in column_names
# assert "id2" in column_names
# assert "id4" not in column_names
# assert len(column_names) == 2
# # TODO(desmond): .count_rows currently returns 0 if the first column in the Select is a
# # partition column.
# assert df_select.count_rows() == 100000

# # Test filtering on partition columns.
# df_filter = df.where((daft.col("id1") == "id003") & (daft.col("id4") == 4) & (daft.col("id3") == "id0000000971"))
# column_names = df_filter.schema().column_names()
# assert "id1" in column_names
# assert "id4" in column_names
# assert len(column_names) == 9
# assert df_filter.count_rows() == 1


# @pytest.mark.integration()
# def test_hive_style_reads_s3_csv(public_storage_io_config):
# uri = "s3://daft-public-data/test_fixtures/hive-style/test.csv/**"
# check_file(public_storage_io_config, daft.read_csv, uri)


# @pytest.mark.integration()
# def test_hive_style_reads_s3_json(public_storage_io_config):
# uri = "s3://daft-public-data/test_fixtures/hive-style/test.json/**"
# check_file(public_storage_io_config, daft.read_json, uri)

@pytest.mark.parametrize(
"partition_by",
[
["str_col"],
["int_col"],
["date_col"],
# TODO(desmond): Same issue as the timestamp issue mentioned above.
# ['timestamp_col'],
["nullable_str"],
["nullable_int"],
["str_col", "int_col"], # Test multiple partition columns.
["nullable_str", "nullable_int"], # Test multiple partition columns with nulls.
],
)
# TODO(desmond): Daft does not currently have a write_json API.
@pytest.mark.parametrize("file_format", ["csv", "parquet"])
@pytest.mark.parametrize("filter", [True, False])
def test_hive_daft_roundtrip(tmpdir, partition_by, file_format, filter):
filepath = f"{tmpdir}"
source = daft.from_arrow(SAMPLE_DATA)

# @pytest.mark.integration()
# def test_hive_style_reads_s3_parquet(public_storage_io_config):
# uri = "s3://daft-public-data/test_fixtures/hive-style/test.parquet/**"
# check_file(public_storage_io_config, daft.read_parquet, uri)
glob_path = os.path.join(tmpdir, "**")
target = ()
if file_format == "csv":
source.write_csv(filepath, partition_cols=[daft.col(col) for col in partition_by])
target = daft.read_csv(
glob_path,
schema={
"nullable_str": daft.DataType.string(),
"nullable_int": daft.DataType.int64(),
},
hive_partitioning=True,
)
# TODO(desmond): Daft has an inconsistency with handling null string columns when using
# `from_arrow` vs `read_csv`. For now we read back an unpartitioned CSV table to check the
# result against.
plain_filepath = f"{tmpdir}-plain"
source.write_csv(plain_filepath)
source = daft.read_csv(
f"{plain_filepath}/**",
schema={
"nullable_str": daft.DataType.string(),
"nullable_int": daft.DataType.int64(),
},
)
if file_format == "json":
source.write_json(filepath, partition_cols=[daft.col(col) for col in partition_by])
target = daft.read_json(
glob_path,
schema={
"nullable_str": daft.DataType.string(),
"nullable_int": daft.DataType.int64(),
},
hive_partitioning=True,
)
if file_format == "parquet":
source.write_parquet(filepath, partition_cols=[daft.col(col) for col in partition_by])
target = daft.read_parquet(
glob_path,
schema={
"nullable_str": daft.DataType.string(),
"nullable_int": daft.DataType.int64(),
},
hive_partitioning=True,
)
if filter:
first_col = partition_by[0]
sample_value = SAMPLE_DATA[first_col][0].as_py()
source = source.where(daft.col(first_col) == sample_value)
target = target.where(daft.col(first_col) == sample_value)
assert_tables_equal(target.to_arrow(), source.to_arrow())
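
For readers skimming the diff, the new round-trip test boils down to: write a hive-partitioned dataset with Daft, read it back with hive_partitioning=True, and compare against the source. A condensed sketch using only APIs that appear above (the data and temporary path are hypothetical, and the schema hints and null-handling workaround from the test are omitted):

import os
import tempfile

import daft
import pyarrow as pa

data = pa.table({"id": [1, 2, 3], "str_col": ["a", "a", "b"]})
source = daft.from_arrow(data)

with tempfile.TemporaryDirectory() as root:
    # Hive-style partitioned write on str_col, then a globbed read with
    # hive_partitioning=True to recover the partition column from the paths.
    source.write_parquet(root, partition_cols=[daft.col("str_col")])
    target = daft.read_parquet(os.path.join(root, "**"), hive_partitioning=True)

    assert target.count_rows() == source.count_rows()
    assert set(target.schema().column_names()) == {"id", "str_col"}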
