diff --git a/tests/io/test_hive_style_partitions.py b/tests/io/test_hive_style_partitions.py
index 7457eee37a..9797799135 100644
--- a/tests/io/test_hive_style_partitions.py
+++ b/tests/io/test_hive_style_partitions.py
@@ -42,10 +42,10 @@ def unify_timestamp(table):
     )


-def assert_tables_equal(daft_df, pa_table):
+def assert_tables_equal(daft_table, pa_table):
     sorted_pa_table = pa_table.sort_by([("id", "ascending")]).select(SCHEMA.names)
     sorted_pa_table = unify_timestamp(sorted_pa_table)
-    sorted_daft_table = daft_df.sort(daft.col("id")).to_arrow().combine_chunks().select(SCHEMA.names)
+    sorted_daft_table = daft_table.sort_by([("id", "ascending")]).select(SCHEMA.names)
     sorted_daft_table = unify_timestamp(sorted_daft_table)
     assert sorted_pa_table == sorted_daft_table

@@ -56,9 +56,9 @@ def assert_tables_equal(daft_df, pa_table):
         ["str_col"],
         ["int_col"],
         ["date_col"],
-        # TODO(desmond): pyarrow does not conform to RFC 3339, so their timestamp format differs from
-        # ours. Specifically, their timestamps are output as `%Y-%m-%d %H:%M:%S%.f%:z` but we parse ours
-        # as %Y-%m-%dT%H:%M:%S%.f%:z.
+        # TODO(desmond): pyarrow does not conform to RFC 3339, so their timestamp format differs
+        # from ours. Specifically, their timestamps are output as `%Y-%m-%d %H:%M:%S%.f%:z` but we
+        # parse ours as %Y-%m-%dT%H:%M:%S%.f%:z.
         # ['timestamp_col'],
         ["nullable_str"],
         ["nullable_int"],
@@ -90,7 +90,14 @@ def test_hive_pyarrow_daft_compatibility(tmpdir, partition_by, file_format, filt
     )

     if file_format == "json":
-        daft_df = daft.read_json(glob_path, hive_partitioning=True)
+        daft_df = daft.read_json(
+            glob_path,
+            schema={
+                "nullable_str": daft.DataType.string(),
+                "nullable_int": daft.DataType.int64(),
+            },
+            hive_partitioning=True,
+        )
     if file_format == "parquet":
         daft_df = daft.read_parquet(
             glob_path,
@@ -112,137 +119,77 @@ def test_hive_pyarrow_daft_compatibility(tmpdir, partition_by, file_format, filt
         sample_value = SAMPLE_DATA[first_col][0].as_py()
         daft_df = daft_df.where(daft.col(first_col) == sample_value)
         pa_table = pa_ds.to_table(filter=ds.field(first_col) == sample_value)
-    assert_tables_equal(daft_df, pa_table)
-
-
-# def test_hive_daft_write_read(tmpdir, sample_data):
-#     # Convert PyArrow table to Daft DataFrame
-#     daft_df = daft.from_pyarrow(sample_data)
-
-#     # Write using Daft with different partition columns
-#     partition_configs = [
-#         ["str_col"],
-#         ["int_col"],
-#         ["date_col"],
-#         ["timestamp_col"],
-#         ["nullable_str"],
-#         ["str_col", "int_col"]
-#     ]
-
-#     for partition_by in partition_configs:
-#         partition_dir = os.path.join(tmpdir, "_".join(partition_by))
-#         os.makedirs(partition_dir, exist_ok=True)
-
-#         # Write partitioned dataset using Daft
-#         daft_df.write_parquet(
-#             partition_dir,
-#             partition_cols=partition_by
-#         )
-
-#         # Read back with Daft
-#         read_df = daft.read_parquet(
-#             os.path.join(partition_dir, "**"),
-#             hive_partitioning=True
-#         )
-
-#         assert_tables_equal(read_df, sample_data)
-
-# def test_null_handling(sample_data):
-#     with tempfile.TemporaryDirectory() as tmpdir:
-#         # Write dataset with nullable columns as partitions
-#         ds.write_dataset(
-#             sample_data,
-#             tmpdir,
-#             format="parquet",
-#             partitioning=ds.partitioning(
-#                 pa.schema([
-#                     sample_data.schema.field("nullable_str"),
-#                     sample_data.schema.field("nullable_int")
-#                 ]),
-#                 flavor="hive"
-#             )
-#         )
-
-#         # Read back with Daft
-#         daft_df = daft.read_parquet(
-#             os.path.join(tmpdir, "**"),
-#             hive_partitioning=True
-#         )
-
-#         # Verify null values are handled correctly
-#         # PyArrow uses __HIVE_DEFAULT_PARTITION__ for nulls
-#         null_count_daft = daft_df.where(daft.col("nullable_str").is_null()).count_rows()
-#         expected_null_count = len([x for x in sample_data["nullable_str"] if x.is_null()])
-#         assert null_count_daft == expected_null_count
-
-
-# @pytest.fixture(scope="session")
-# def public_storage_io_config() -> daft.io.IOConfig:
-#     return daft.io.IOConfig(
-#         azure=daft.io.AzureConfig(storage_account="dafttestdata", anonymous=True),
-#         s3=daft.io.S3Config(region_name="us-west-2", anonymous=True),
-#         gcs=daft.io.GCSConfig(anonymous=True),
-#     )
+    assert_tables_equal(daft_df.to_arrow(), pa_table)
-# def check_file(public_storage_io_config, read_fn, uri):
-#     # These tables are partitioned on id1 (utf8) and id4 (int64).
-#     df = read_fn(uri, hive_partitioning=True, io_config=public_storage_io_config)
-#     column_names = df.schema().column_names()
-#     assert "id1" in column_names
-#     assert "id4" in column_names
-#     assert len(column_names) == 9
-#     assert df.count_rows() == 100000
-#     # Test schema inference on partition columns.
-#     pa_schema = df.schema().to_pyarrow_schema()
-#     assert pa_schema.field("id1").type == pa.large_string()
-#     assert pa_schema.field("id4").type == pa.int64()
-#     # Test that schema hints work on partition columns.
-#     df = read_fn(
-#         uri,
-#         hive_partitioning=True,
-#         io_config=public_storage_io_config,
-#         schema={
-#             "id4": daft.DataType.int32(),
-#         },
-#     )
-#     pa_schema = df.schema().to_pyarrow_schema()
-#     assert pa_schema.field("id1").type == pa.large_string()
-#     assert pa_schema.field("id4").type == pa.int32()
-
-#     # Test selects on a partition column and a non-partition columns.
-#     df_select = df.select("id2", "id1")
-#     column_names = df_select.schema().column_names()
-#     assert "id1" in column_names
-#     assert "id2" in column_names
-#     assert "id4" not in column_names
-#     assert len(column_names) == 2
-#     # TODO(desmond): .count_rows currently returns 0 if the first column in the Select is a
-#     # partition column.
-#     assert df_select.count_rows() == 100000
-
-#     # Test filtering on partition columns.
-#     df_filter = df.where((daft.col("id1") == "id003") & (daft.col("id4") == 4) & (daft.col("id3") == "id0000000971"))
-#     column_names = df_filter.schema().column_names()
-#     assert "id1" in column_names
-#     assert "id4" in column_names
-#     assert len(column_names) == 9
-#     assert df_filter.count_rows() == 1
-
-
-# @pytest.mark.integration()
-# def test_hive_style_reads_s3_csv(public_storage_io_config):
-#     uri = "s3://daft-public-data/test_fixtures/hive-style/test.csv/**"
-#     check_file(public_storage_io_config, daft.read_csv, uri)
-
-
-# @pytest.mark.integration()
-# def test_hive_style_reads_s3_json(public_storage_io_config):
-#     uri = "s3://daft-public-data/test_fixtures/hive-style/test.json/**"
-#     check_file(public_storage_io_config, daft.read_json, uri)
-
+@pytest.mark.parametrize(
+    "partition_by",
+    [
+        ["str_col"],
+        ["int_col"],
+        ["date_col"],
+        # TODO(desmond): Same issue as the timestamp issue mentioned above.
+        # ['timestamp_col'],
+        ["nullable_str"],
+        ["nullable_int"],
+        ["str_col", "int_col"],  # Test multiple partition columns.
+        ["nullable_str", "nullable_int"],  # Test multiple partition columns with nulls.
+    ],
+)
+# TODO(desmond): Daft does not currently have a write_json API.
+@pytest.mark.parametrize("file_format", ["csv", "parquet"])
+@pytest.mark.parametrize("filter", [True, False])
+def test_hive_daft_roundtrip(tmpdir, partition_by, file_format, filter):
+    filepath = f"{tmpdir}"
+    source = daft.from_arrow(SAMPLE_DATA)
-# @pytest.mark.integration()
-# def test_hive_style_reads_s3_parquet(public_storage_io_config):
-#     uri = "s3://daft-public-data/test_fixtures/hive-style/test.parquet/**"
-#     check_file(public_storage_io_config, daft.read_parquet, uri)
+    glob_path = os.path.join(tmpdir, "**")
+    target = ()
+    if file_format == "csv":
+        source.write_csv(filepath, partition_cols=[daft.col(col) for col in partition_by])
+        target = daft.read_csv(
+            glob_path,
+            schema={
+                "nullable_str": daft.DataType.string(),
+                "nullable_int": daft.DataType.int64(),
+            },
+            hive_partitioning=True,
+        )
+        # TODO(desmond): Daft has an inconsistency with handling null string columns when using
+        # `from_arrow` vs `read_csv`. For now we read back an unpartitioned CSV table to check the
+        # result against.
+        plain_filepath = f"{tmpdir}-plain"
+        source.write_csv(plain_filepath)
+        source = daft.read_csv(
+            f"{plain_filepath}/**",
+            schema={
+                "nullable_str": daft.DataType.string(),
+                "nullable_int": daft.DataType.int64(),
+            },
+        )
+    if file_format == "json":
+        source.write_json(filepath, partition_cols=[daft.col(col) for col in partition_by])
+        target = daft.read_json(
+            glob_path,
+            schema={
+                "nullable_str": daft.DataType.string(),
+                "nullable_int": daft.DataType.int64(),
+            },
+            hive_partitioning=True,
+        )
+    if file_format == "parquet":
+        source.write_parquet(filepath, partition_cols=[daft.col(col) for col in partition_by])
+        target = daft.read_parquet(
+            glob_path,
+            schema={
+                "nullable_str": daft.DataType.string(),
+                "nullable_int": daft.DataType.int64(),
+            },
+            hive_partitioning=True,
+        )
+    if filter:
+        first_col = partition_by[0]
+        sample_value = SAMPLE_DATA[first_col][0].as_py()
+        source = source.where(daft.col(first_col) == sample_value)
+        target = target.where(daft.col(first_col) == sample_value)
+    assert_tables_equal(target.to_arrow(), source.to_arrow())
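
Reviewer note (not part of the patch): a minimal sketch of the hive-style write/read round-trip that `test_hive_daft_roundtrip` exercises, assuming the APIs used in the diff (`daft.from_arrow`, `write_parquet(partition_cols=...)`, `read_parquet(..., hive_partitioning=True)`). The data and the output directory are hypothetical, chosen only for illustration.

```python
import os

import daft
import pyarrow as pa

# Hypothetical sample data; the real test uses SAMPLE_DATA from the module.
data = pa.table({"id": [1, 2, 3], "str_col": ["a", "b", "a"]})
source = daft.from_arrow(data)

out_dir = "/tmp/hive_roundtrip_example"  # hypothetical scratch directory
# Writes files under hive-style directories such as str_col=a/ and str_col=b/.
source.write_parquet(out_dir, partition_cols=[daft.col("str_col")])

# hive_partitioning=True reconstructs str_col from the directory names.
target = daft.read_parquet(os.path.join(out_dir, "**"), hive_partitioning=True)
assert target.sort(daft.col("id")).to_arrow().column("id").to_pylist() == [1, 2, 3]
```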