start adding shuffle-based drop_duplicates code path
rjzamora committed Sep 28, 2023
1 parent 2b45b21 commit d0c9df7
Showing 3 changed files with 109 additions and 34 deletions.
102 changes: 100 additions & 2 deletions dask/dataframe/core.py
@@ -21,7 +21,7 @@
from tlz import first, merge, partition_all, remove, unique

import dask.array as da
-from dask import core
+from dask import config, core
from dask.array.core import Array, normalize_arg
from dask.bag import map_partitions as map_bag_partitions
from dask.base import (
@@ -196,6 +196,23 @@ def _concat(args, ignore_index=False):
    )


+def _determine_split_out_shuffle(shuffle, split_out):
+    """Determine the default shuffle behavior based on split_out"""
+    if shuffle is None:
+        if split_out > 1:
+            # FIXME: This is using a different default but it is not fully
+            # understood why this is a better choice.
+            # For more context, see
+            # https://github.com/dask/dask/pull/9826/files#r1072395307
+            # https://github.com/dask/distributed/issues/5502
+            return config.get("dataframe.shuffle.method", None) or "tasks"
+        else:
+            return False
+    if shuffle is True:
+        return config.get("dataframe.shuffle.method", None) or "tasks"
+    return shuffle


def finalize(results):
    return _concat(results)

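For orientation, here is a minimal sketch of the new helper's decision table, assuming the "dataframe.shuffle.method" config key is left unset (its default), so the fallback resolves to "tasks":

from dask.dataframe.core import _determine_split_out_shuffle

assert _determine_split_out_shuffle(None, 1) is False       # one output partition: no shuffle
assert _determine_split_out_shuffle(None, 4) == "tasks"     # split_out > 1: shuffle by default
assert _determine_split_out_shuffle(True, 1) == "tasks"     # True resolves to a concrete method
assert _determine_split_out_shuffle("tasks", 4) == "tasks"  # explicit names pass through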
@@ -855,7 +872,13 @@ def get_partition(self, n):
        inconsistencies="keep=False will raise a ``NotImplementedError``",
    )
    def drop_duplicates(
-        self, subset=None, split_every=None, split_out=1, ignore_index=False, **kwargs
+        self,
+        subset=None,
+        split_every=None,
+        split_out=1,
+        shuffle=None,
+        ignore_index=False,
+        **kwargs,
    ):
        if subset is not None:
            # Let pandas error on bad inputs
@@ -871,6 +894,81 @@ def drop_duplicates(
            raise NotImplementedError("drop_duplicates with keep=False")

        chunk = M.drop_duplicates

+        if isinstance(self, Index) and self.known_divisions:
+            # Simple case that we are acting on an Index
+            # with known divisions
+            repartition_npartitions = max(
+                self.npartitions // (split_every or self.npartitions),
+                split_out,
+            )
+            return self.map_partitions(
+                chunk,
+                token="drop-duplicates-chunk",
+                meta=self._meta,
+                transform_divisions=False,
+                **kwargs,
+            ).repartition(
+                npartitions=repartition_npartitions
+            ).map_partitions(
+                chunk,
+                token="drop-duplicates-agg",
+                meta=self._meta,
+                transform_divisions=False,
+                **kwargs,
+            ).repartition(npartitions=split_out)
+
+        shuffle = _determine_split_out_shuffle(shuffle, split_out)
+        if shuffle:
+            # Make sure we have a DataFrame to shuffle
+            if isinstance(self, Index):
+                df = self.to_frame(name=self.name or "__index__")
+            elif isinstance(self, Series):
+                df = self.to_frame(name=self.name or "__series__")
+            else:
+                df = self
+
+            # Choose appropriate shuffle partitioning
+            split_every = 8 if split_every is None else split_every
+            shuffle_npartitions = max(
+                df.npartitions // (split_every or df.npartitions),
+                split_out,
+            )
+
+            # Deduplicate, then shuffle, then deduplicate again
+            deduplicated = df.map_partitions(
+                chunk,
+                token="drop-duplicates-chunk",
+                meta=df._meta,
+                enforce_metadata=False,
+                transform_divisions=False,
+                **kwargs,
+            ).shuffle(
+                subset or list(df.columns),
+                ignore_index=ignore_index,
+                npartitions=shuffle_npartitions,
+                shuffle=shuffle,
+            ).map_partitions(
+                chunk,
+                meta=df._meta,
+                token="drop-duplicates-agg",
+                transform_divisions=False,
+                **kwargs,
+            )
+
+            # Convert back to Series/Index if necessary
+            if isinstance(self, Index):
+                deduplicated = deduplicated.set_index(self.name or "__index__", sort=False).index
+                if deduplicated.name == "__index__":
+                    deduplicated.name = None
+            elif isinstance(self, Series):
+                deduplicated = deduplicated[self.name or "__series__"]
+                if deduplicated.name == "__series__":
+                    deduplicated.name = None
+
+            # Return `split_out` partitions
+            return deduplicated.repartition(npartitions=split_out)

        return aca(
            self,
            chunk=chunk,
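The new code path can be exercised end to end. A hedged usage sketch (the data and column names are hypothetical, not from the commit): passing shuffle="tasks" routes drop_duplicates through the map_partitions -> shuffle -> map_partitions pipeline above instead of the aca tree reduction.

import pandas as pd

import dask.dataframe as dd

pdf = pd.DataFrame({"x": [1, 2, 1, 3, 2, 1], "y": ["a", "a", "b", "a", "b", "a"]})
ddf = dd.from_pandas(pdf, npartitions=3)

# Each partition is deduplicated locally, rows are shuffled so that equal
# rows land in the same partition, then a final local drop_duplicates
# removes the remaining cross-partition duplicates.
result = ddf.drop_duplicates(split_out=2, shuffle="tasks")
assert result.npartitions == 2  # final repartition honors split_out
print(result.compute())

# A named Series follows the same path: it round-trips through
# to_frame(name=...) for the shuffle and is converted back afterwards.
print(ddf.x.drop_duplicates(shuffle="tasks").compute())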
32 changes: 4 additions & 28 deletions dask/dataframe/groupby.py
@@ -28,6 +28,7 @@
    DataFrame,
    Series,
    _convert_to_numeric,
+    _determine_split_out_shuffle,
    _extract_meta,
    _Frame,
    aca,
@@ -107,21 +108,6 @@ def _determine_levels(by):
    return 0


-def _determine_shuffle(shuffle, split_out):
-    """Determine the default shuffle behavior based on split_out"""
-    if shuffle is None:
-        if split_out > 1:
-            # FIXME: This is using a different default but it is not fully
-            # understood why this is a better choice.
-            # For more context, see
-            # https://github.com/dask/dask/pull/9826/files#r1072395307
-            # https://github.com/dask/distributed/issues/5502
-            return config.get("dataframe.shuffle.method", None) or "tasks"
-        else:
-            return False
-    return shuffle


def _normalize_by(df, by):
    """Replace series with column names wherever possible."""
    if not isinstance(df, DataFrame):
@@ -1538,7 +1524,7 @@ def _single_agg(
        Aggregation with a single function/aggfunc rather than a compound spec
        like in GroupBy.aggregate
        """
-        shuffle = _determine_shuffle(shuffle, split_out)
+        shuffle = _determine_split_out_shuffle(shuffle, split_out)

        if self.sort is None and split_out > 1:
            warnings.warn(SORT_SPLIT_OUT_WARNING, FutureWarning)
@@ -1998,11 +1984,7 @@ def median(
                "aggregation (e.g., shuffle='tasks')"
            )

-        # FIXME: This is using a different default but it is not fully
-        # understood why this is a better choice. For more context, see
-        # https://github.com/dask/dask/pull/9826/files#r1072395307
-        # https://github.com/dask/distributed/issues/5502
-        shuffle = shuffle or config.get("dataframe.shuffle.method", None) or "tasks"
+        shuffle = shuffle or _determine_split_out_shuffle(True, split_out)
        numeric_only_kwargs = get_numeric_only_kwargs(numeric_only)

        with check_numeric_only_deprecation(name="median"):
@@ -2229,7 +2211,7 @@ def aggregate(
                category=FutureWarning,
            )
            split_out = 1
-        shuffle = _determine_shuffle(shuffle, split_out)
+        shuffle = _determine_split_out_shuffle(shuffle, split_out)

        relabeling = None
        columns = None
@@ -2337,12 +2319,6 @@

        # If we have a median in the spec, we cannot do an initial
        # aggregation.
-        # FIXME: This is using a different default but it is not fully
-        # understood why this is a better choice. For more context, see
-        # https://github.com/dask/dask/pull/9826/files#r1072395307
-        # https://github.com/dask/distributed/issues/5502
-        if not isinstance(shuffle, str):
-            shuffle = config.get("dataframe.shuffle.method", None) or "tasks"
        if has_median:
            result = _shuffle_aggregate(
                chunk_args,
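In median, the inline config lookup is replaced by a call with shuffle=True, which forces the helper to return a concrete method name even when split_out == 1 (median always requires a shuffle). A hedged sketch of the resulting behavior, with hypothetical data:

import pandas as pd

import dask.dataframe as dd

ddf = dd.from_pandas(
    pd.DataFrame({"k": [0, 1, 0, 1, 0], "v": [1.0, 2.0, 3.0, 4.0, 5.0]}),
    npartitions=2,
)

# No method requested by the user: _determine_split_out_shuffle(True, 1)
# resolves to "tasks" (or the configured "dataframe.shuffle.method"), so
# median works without an explicit shuffle= argument.
print(ddf.groupby("k").v.median().compute())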
9 changes: 5 additions & 4 deletions dask/dataframe/tests/test_dataframe.py
@@ -1140,23 +1140,24 @@ def test_align_dataframes():
    assert_eq(actual, expected, check_index=False, check_divisions=False)


-def test_drop_duplicates():
+@pytest.mark.parametrize("shuffle", [None, True])
+def test_drop_duplicates(shuffle):
    res = d.drop_duplicates()
-    res2 = d.drop_duplicates(split_every=2)
+    res2 = d.drop_duplicates(split_every=2, shuffle=shuffle)
    sol = full.drop_duplicates()
    assert_eq(res, sol)
    assert_eq(res2, sol)
    assert res._name != res2._name

    res = d.a.drop_duplicates()
-    res2 = d.a.drop_duplicates(split_every=2)
+    res2 = d.a.drop_duplicates(split_every=2, shuffle=shuffle)
    sol = full.a.drop_duplicates()
    assert_eq(res, sol)
    assert_eq(res2, sol)
    assert res._name != res2._name

    res = d.index.drop_duplicates()
-    res2 = d.index.drop_duplicates(split_every=2)
+    res2 = d.index.drop_duplicates(split_every=2, shuffle=shuffle)
    sol = full.index.drop_duplicates()
    assert_eq(res, sol)
    assert_eq(res2, sol)
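The parametrized test relies on the module fixtures d and full; a self-contained sketch of the same assertions, with hypothetical stand-in data:

import pandas as pd

import dask.dataframe as dd
from dask.dataframe.utils import assert_eq

full = pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [5, 5, 6, 6, 7]})
d = dd.from_pandas(full, npartitions=3)

res = d.drop_duplicates()                              # aca tree-reduction path
res2 = d.drop_duplicates(split_every=2, shuffle=True)  # new shuffle-based path
assert_eq(res, full.drop_duplicates())
assert_eq(res2, full.drop_duplicates())
assert res._name != res2._name  # different parameters yield distinct graph keys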
