Commit c5d2811

cleanup

rjzamora committed Sep 28, 2023
1 parent d0c9df7 commit c5d2811
Showing 2 changed files with 109 additions and 66 deletions.
dask/dataframe/core.py (168 changes: 102 additions & 66 deletions)
@@ -867,6 +867,67 @@ def get_partition(self, n):
             msg = f"n must be 0 <= n < {self.npartitions}"
             raise ValueError(msg)
 
+    def _shuffle_drop_duplicates(
+        self, split_out, split_every, shuffle, ignore_index, **kwargs
+    ):
+        # Make sure we have a DataFrame to shuffle
+        if isinstance(self, Index):
+            df = self.to_frame(name=self.name or "__index__")
+        elif isinstance(self, Series):
+            df = self.to_frame(name=self.name or "__series__")
+        else:
+            df = self
+
+        # Choose appropriate shuffle partitioning
+        split_every = 8 if split_every is None else split_every
+        shuffle_npartitions = max(
+            df.npartitions // (split_every or df.npartitions),
+            split_out,
+        )
+
+        # Deduplicate, then shuffle, then deduplicate again
+        chunk = M.drop_duplicates
+        deduplicated = (
+            df.map_partitions(
+                chunk,
+                token="drop-duplicates-chunk",
+                meta=df._meta,
+                ignore_index=ignore_index,
+                enforce_metadata=False,
+                transform_divisions=False,
+                **kwargs,
+            )
+            .shuffle(
+                kwargs.get("subset", None) or list(df.columns),
+                ignore_index=ignore_index,
+                npartitions=shuffle_npartitions,
+                shuffle=shuffle,
+            )
+            .map_partitions(
+                chunk,
+                meta=df._meta,
+                ignore_index=ignore_index,
+                token="drop-duplicates-agg",
+                transform_divisions=False,
+                **kwargs,
+            )
+        )
+
+        # Convert back to Series/Index if necessary
+        if isinstance(self, Index):
+            deduplicated = deduplicated.set_index(
+                self.name or "__index__", sort=False
+            ).index
+            if deduplicated.name == "__index__":
+                deduplicated.name = None
+        elif isinstance(self, Series):
+            deduplicated = deduplicated[self.name or "__series__"]
+            if deduplicated.name == "__series__":
+                deduplicated.name = None
+
+        # Return `split_out` partitions
+        return deduplicated.repartition(npartitions=split_out)
+
     @derived_from(
         pd.DataFrame,
         inconsistencies="keep=False will raise a ``NotImplementedError``",
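For orientation, here is a minimal sketch of how the shuffle-based path above is reached from user code. It is illustrative only (not part of the commit), and assumes the `shuffle="tasks"` method accepted by `drop_duplicates` in this version of dask:

    import pandas as pd
    import dask.dataframe as dd

    # Duplicates that span partition boundaries can only be removed by a
    # global strategy, which is why the shuffle path deduplicates again
    # after moving equal rows into the same partition.
    pdf = pd.DataFrame({"x": [1, 1, 2, 2, 3, 3], "y": list("aabbcc")})
    ddf = dd.from_pandas(pdf, npartitions=3)

    # Requesting an explicit shuffle method (here with split_out > 1)
    # routes through the strategy now implemented by
    # _shuffle_drop_duplicates.
    result = ddf.drop_duplicates(split_out=2, shuffle="tasks")
    print(result.compute())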
@@ -893,82 +954,24 @@ def drop_duplicates(
         if kwargs.get("keep", True) is False:
             raise NotImplementedError("drop_duplicates with keep=False")
 
-        chunk = M.drop_duplicates
-
         if isinstance(self, Index) and self.known_divisions:
-            # Simple case that we are acting on an Index
-            # with known divisions
-            repartition_npartitions = max(
-                self.npartitions // (split_every or self.npartitions),
-                split_out,
-            )
-            return self.map_partitions(
-                chunk,
-                token="drop-duplicates-chunk",
-                meta=self._meta,
-                transform_divisions=False,
-                **kwargs,
-            ).repartition(
-                npartitions=repartition_npartitions
-            ).map_partitions(
-                chunk,
-                token="drop-duplicates-agg",
-                meta=self._meta,
-                transform_divisions=False,
-                **kwargs,
-            ).repartition(npartitions=split_out)
+            return self._drop_duplicates_known_divisions(
+                split_out,
+                split_every,
+                **kwargs,
+            )
 
         shuffle = _determine_split_out_shuffle(shuffle, split_out)
         if shuffle:
-            # Make sure we have a DataFrame to shuffle
-            if isinstance(self, Index):
-                df = self.to_frame(name=self.name or "__index__")
-            elif isinstance(self, Series):
-                df = self.to_frame(name=self.name or "__series__")
-            else:
-                df = self
-
-            # Choose appropriate shuffle partitioning
-            split_every = 8 if split_every is None else split_every
-            shuffle_npartitions = max(
-                df.npartitions // (split_every or df.npartitions),
-                split_out,
-            )
-
-            # Deduplicate, then shuffle, then deduplicate again
-            deduplicated = df.map_partitions(
-                chunk,
-                token="drop-duplicates-chunk",
-                meta=df._meta,
-                enforce_metadata=False,
-                transform_divisions=False,
-                **kwargs,
-            ).shuffle(
-                subset or list(df.columns),
-                ignore_index=ignore_index,
-                npartitions=shuffle_npartitions,
-                shuffle=shuffle,
-            ).map_partitions(
-                chunk,
-                meta=df._meta,
-                token="drop-duplicates-agg",
-                transform_divisions=False,
-                **kwargs,
-            )
-
-            # Convert back to Series/Index if necessary
-            if isinstance(self, Index):
-                deduplicated = deduplicated.set_index(self.name or "__index__", sort=False).index
-                if deduplicated.name == "__index__":
-                    deduplicated.name = None
-            elif isinstance(self, Series):
-                deduplicated = deduplicated[self.name or "__series__"]
-                if deduplicated.name == "__series__":
-                    deduplicated.name = None
-
-            # Return `split_out` partitions
-            return deduplicated.repartition(npartitions=split_out)
+            return self._shuffle_drop_duplicates(
+                split_out,
+                split_every,
+                shuffle,
+                ignore_index,
+                **kwargs,
+            )
 
+        chunk = M.drop_duplicates
         return aca(
             self,
             chunk=chunk,
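Both helpers that this dispatch now delegates to follow the same two-stage pattern: deduplicate within each partition ("chunk"), bring potentially equal rows together, then deduplicate again ("agg"). A pure-pandas sketch of the idea (illustrative, not code from this commit):

    import pandas as pd

    # Two "partitions" with a duplicate value (2) crossing the boundary.
    partitions = [
        pd.DataFrame({"x": [1, 1, 2]}),
        pd.DataFrame({"x": [2, 3, 3]}),
    ]

    # Stage 1 ("chunk"): per-partition dedup shrinks the data before any
    # data movement.
    deduped = [part.drop_duplicates() for part in partitions]

    # Stage 2 ("agg"): concat stands in for dask's repartition/shuffle;
    # a second dedup catches the duplicates that crossed partitions.
    result = pd.concat(deduped).drop_duplicates()
    print(result)  # x: 1, 2, 3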
@@ -4832,6 +4835,39 @@ def memory_usage(self, deep=False):
             token=self._token_prefix + "memory-usage",
         )
 
+    def _drop_duplicates_known_divisions(
+        self,
+        split_out,
+        split_every,
+        **kwargs,
+    ):
+        # Simple `drop_duplicates` case that we are acting on
+        # an Index with known divisions
+        assert self.known_divisions, "Requires known divisions"
+        chunk = M.drop_duplicates
+        repartition_npartitions = max(
+            self.npartitions // (split_every or self.npartitions),
+            split_out,
+        )
+        return (
+            self.map_partitions(
+                chunk,
+                token="drop-duplicates-chunk",
+                meta=self._meta,
+                transform_divisions=False,
+                **kwargs,
+            )
+            .repartition(npartitions=repartition_npartitions)
+            .map_partitions(
+                chunk,
+                token="drop-duplicates-agg",
+                meta=self._meta,
+                transform_divisions=False,
+                **kwargs,
+            )
+            .repartition(npartitions=split_out)
+        )
+
 
 class DataFrame(_Frame):
     """
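The assertion above encodes the fast path's precondition: the repartition-based strategy is only valid when divisions are known. A small sketch of that precondition (illustrative, not from the commit):

    import pandas as pd
    import dask.dataframe as dd

    # from_pandas on a sorted index yields known divisions, so
    # Index.drop_duplicates can avoid a full shuffle and instead
    # deduplicate, repartition, and deduplicate again.
    pdf = pd.DataFrame({"y": range(6)}, index=[0, 0, 1, 1, 2, 2])
    ddf = dd.from_pandas(pdf, npartitions=3)

    assert ddf.index.known_divisions
    print(ddf.index.drop_duplicates().compute())  # Index([0, 1, 2], ...)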
dask/dataframe/tests/test_dataframe.py (7 changes: 7 additions & 0 deletions)

@@ -1161,6 +1161,13 @@ def test_drop_duplicates(shuffle):
     sol = full.index.drop_duplicates()
     assert_eq(res, sol)
     assert_eq(res2, sol)
 
+    _d = d.clear_divisions()
+    res = _d.index.drop_duplicates()
+    res2 = _d.index.drop_duplicates(split_every=2, shuffle=shuffle)
+    sol = full.index.drop_duplicates()
+    assert_eq(res, sol)
+    assert_eq(res2, sol)
+    assert res._name != res2._name
 
     with pytest.raises(NotImplementedError):
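The new test lines exercise the fallback paths: clear_divisions() discards division metadata, so the known-divisions branch is skipped. A sketch of the behavior the test relies on (illustrative, not from the commit):

    import pandas as pd
    import dask.dataframe as dd

    pdf = pd.DataFrame({"y": range(6)}, index=[0, 0, 1, 1, 2, 2])
    ddf = dd.from_pandas(pdf, npartitions=3)

    # Clearing divisions flips known_divisions to False, forcing
    # drop_duplicates onto the generic (ACA or shuffle) strategies.
    assert ddf.index.known_divisions
    assert not ddf.clear_divisions().index.known_divisions

    # Results agree across strategies; only the task graph (and thus its
    # keys) differs, which is what `res._name != res2._name` checks.
    print(ddf.clear_divisions().index.drop_duplicates().compute())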
