Commit

add logic to manually parse row-count statistics if the upstream read_parquet call doesn't populate the partition statistics (#178)
rjzamora authored Dec 1, 2022
1 parent 78f1f0b commit 4f73ff5
Showing 3 changed files with 56 additions and 19 deletions.
7 changes: 2 additions & 5 deletions merlin/io/dataset.py
@@ -208,11 +208,8 @@ class Dataset:
This overrides the derived schema behavior.
**kwargs :
Key-word arguments to pass through to Dask.dataframe IO function.
For the Parquet engine(s), notable arguments include `filters`,
`aggregate_files`, and `gather_statistics`. Note that users who
do not need to know the number of rows in their dataset (and do
not plan to preserve a file-partition mapping) may wish to use
`gather_statistics=False` for better client-side performance.
For the Parquet engine(s), notable arguments include `filters`
and `aggregate_files` (the latter is experimental).
"""

def __init__(
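For context, a minimal usage sketch of these pass-through keyword arguments (the path and the `timestamp` column are hypothetical; `filters` and `aggregate_files` are forwarded to the Dask.dataframe Parquet reader):

import merlin.io

ds = merlin.io.Dataset(
    "data/parquet_dir/",               # hypothetical Parquet directory
    engine="parquet",
    filters=[("timestamp", "==", 0)],  # forwarded to dd.read_parquet
    aggregate_files="timestamp",       # experimental, per the docstring above
)
ddf = ds.to_ddf()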
52 changes: 40 additions & 12 deletions merlin/io/parquet.py
@@ -184,6 +184,34 @@ def set_object_dtypes_from_pa_schema(df, schema):
df._data[col_name] = col.astype(typ)


def _read_partition_lens(part, fs):
# Manually read row-count statistics from the Parquet
# metadata for a single `read_parquet` `part`. Return
# the result in the same format that `read_metadata`
# returns statistics

if not isinstance(part, list):
part = [part]

num_rows = 0
for p in part:
piece = p["piece"]
path = piece[0]
row_groups = piece[1]
if row_groups == [None]:
row_groups = None
with fs.open(path, default_cache="none") as f:
md = pq.ParquetFile(f).metadata
if row_groups is None:
row_groups = list(range(md.num_row_groups))
for rg in row_groups:
row_group = md.row_group(rg)
num_rows += row_group.num_rows

# Ignore column-statistics (for now)
return {"num-rows": num_rows, "columns": []}

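# Illustrative example of the helper's contract (hypothetical path and
# filesystem): given a `read_parquet` part that points at row-groups 0 and 1
# of a single file, the row counts are read from the Parquet footer and
# summed into the same statistics format that `read_metadata` produces.
#
#   part = {"piece": ("part.0.parquet", [0, 1])}
#   _read_partition_lens(part, fsspec.filesystem("file"))
#   # -> {"num-rows": <rows in row-groups 0 and 1>, "columns": []}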

def _override_read_metadata(
parent_read_metadata,
fs,
@@ -205,29 +233,25 @@ def _override_read_metadata(
# For now, disallow the user from setting `chunksize`
if chunksize:
raise ValueError(
"NVTabular does not yet support the explicit use " "of Dask's `chunksize` argument."
"NVTabular does not yet support the explicit use of Dask's `chunksize` argument."
)

# Extract metadata_collector from the dataset "container"
dataset = dataset or {}
metadata_collector = dataset.pop("metadata_collector", None)

# Gather statistics by default.
# This enables optimized length calculations
if gather_statistics is None:
gather_statistics = True
# gather_statistics is deprecated in `dask.dataframe`
if gather_statistics:
warnings.warn(
"`gather_statistics` is now deprecated and will be ignored.",
FutureWarning,
)

# Use a local_kwarg dictionary to make it easier to exclude
# `aggregate_files` for older Dask versions
local_kwargs = {
"index": index,
"filters": filters,
# Use chunksize=1 to "ensure" statistics are gathered
# if `gather_statistics=True`. Note that Dask will bail
# from statistics gathering if it does not expect statistics
# to be "used" after `read_metadata` returns.
"chunksize": 1 if gather_statistics else None,
"gather_statistics": gather_statistics,
"split_row_groups": split_row_groups,
}
if aggregate_row_groups is not None:
@@ -247,7 +271,7 @@ def _override_read_metadata(
statistics = read_metadata_result[1].copy()

# Process the statistics.
# Note that these steps are usaually performed after
# Note that these steps are usually performed after
# `engine.read_metadata` returns (in Dask), but we are doing
# it ourselves in NVTabular (to capture the expected output
# partitioning plan)
@@ -400,6 +424,10 @@ def _process_parquet_metadata(self):
# in parts and stats
parts = self._pp_metadata["parts"]
stats = self._pp_metadata["stats"]
if not stats:
# Update stats if `dd.read_parquet` didn't populate it
stats = [_read_partition_lens(part, self.fs) for part in parts]
self._pp_metadata["stats"] = stats
_pp_map = {}
_pp_nrows = []
distinct_files = True
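Taken together with the fallback just above, the engine can still report per-partition row counts when `dd.read_parquet` returns no statistics. A rough sketch of how that can be observed (hypothetical path; this mirrors the `_check_partition_lens` utility added to the tests below):

import merlin.io

ds = merlin.io.Dataset("data/parquet_dir/", engine="parquet")  # hypothetical path
ddf = ds.to_ddf()
assert ds.engine._partition_lens == [len(part) for part in ddf.partitions]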
16 changes: 14 additions & 2 deletions tests/unit/io/test_io.py
@@ -38,6 +38,13 @@
dask_cudf = pytest.importorskip("dask_cudf")


def _check_partition_lens(ds):
# Simple utility to check that the Parquet metadata
# is correctly encoding the partition lengths
_lens = [len(part) for part in ds.to_ddf().partitions]
assert ds.engine._partition_lens == _lens


def test_validate_dataset_bad_schema(tmpdir):
if Version(dask.__version__) <= Version("2.30.0"):
# Older versions of Dask will not handle schema mismatch
@@ -157,6 +164,9 @@ def test_dask_dataset_itr(tmpdir, datasets, engine, gpu_memory_frac):
size += chunk.shape[0]
assert chunk["id"].dtype == np.int32

if engine == "parquet":
_check_partition_lens(ds)

assert size == df1.shape[0]
assert len(my_iter) == size

@@ -330,7 +340,7 @@ def test_dask_dataframe_methods(tmpdir):

@pytest.mark.parametrize("inp_format", ["dask", "dask_cudf", "cudf", "pandas"])
def test_ddf_dataset_itr(tmpdir, datasets, inp_format):
paths = glob.glob(str(datasets["parquet"]) + "/*." + "parquet".split("-")[0])
paths = glob.glob(str(datasets["parquet"]) + "/*." + "parquet".split("-", maxsplit=1)[0])
ddf1 = dask_cudf.read_parquet(paths)[mycols_pq]
df1 = ddf1.compute()
if inp_format == "dask":
@@ -678,7 +688,7 @@ def test_dataset_shuffle_on_keys(tmpdir, cpu, partition_on, keys, npartitions):

# A successful shuffle will return the same unique-value
# count for both the full dask algorithm and a partition-wise sum
n1 = sum([len(p[keys].drop_duplicates()) for p in ddf2.partitions])
n1 = sum(len(p[keys].drop_duplicates()) for p in ddf2.partitions)
n2 = len(ddf2[keys].drop_duplicates())
assert n1 == n2

@@ -768,6 +778,7 @@ def test_parquet_aggregate_files(tmpdir, cpu):
path, cpu=cpu, engine="parquet", aggregate_files="timestamp", part_size="1GB"
)
assert ds.to_ddf().npartitions == len(ddf.timestamp.unique())
_check_partition_lens(ds)

# Combining `aggregate_files` and `filters` should work
ds = merlin.io.Dataset(
Expand All @@ -780,3 +791,4 @@ def test_parquet_aggregate_files(tmpdir, cpu):
)
assert ds.to_ddf().npartitions == 1
assert len(ds.to_ddf().timestamp.unique()) == 1
_check_partition_lens(ds)
