Commit: add partd_encode_dispatch
rjzamora committed Oct 9, 2023
1 parent 928a95a commit 7ef12d4
Showing 3 changed files with 29 additions and 7 deletions.
8 changes: 8 additions & 0 deletions dask/dataframe/backends.py
@@ -27,6 +27,7 @@
     make_meta_obj,
     meta_lib_from_array,
     meta_nonempty,
+    partd_encode_dispatch,
     pyarrow_schema_dispatch,
     to_pandas_dispatch,
     to_pyarrow_table_dispatch,
@@ -242,6 +243,13 @@ def default_types_mapper(pyarrow_dtype: pa.DataType) -> object:
     return table.to_pandas(types_mapper=types_mapper, **kwargs)


+@partd_encode_dispatch.register(pd.DataFrame)
+def partd_pandas_blocks(_):
+    from partd import PandasBlocks
+
+    return PandasBlocks
+
+
 @meta_nonempty.register(pd.DatetimeTZDtype)
 @make_meta_dispatch.register(pd.DatetimeTZDtype)
 def make_meta_pandas_datetime_tz(x, index=None):
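The registration above returns the encode class itself rather than an instance; shuffle.py later instantiates it around a partd store. As a minimal sketch (not part of this commit), a third-party backend could hook its own frame type into the same dispatch; MyBackendDataFrame is a hypothetical stand-in, and PandasBlocks is used only so the snippet runs as-is:

    from dask.dataframe.dispatch import partd_encode_dispatch


    class MyBackendDataFrame:  # hypothetical stand-in for a backend's frame type
        ...


    @partd_encode_dispatch.register(MyBackendDataFrame)
    def partd_mybackend_blocks(_):
        # A real backend would return its own partd.Interface subclass;
        # PandasBlocks stands in here so the sketch is runnable.
        from partd import PandasBlocks

        return PandasBlocks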
1 change: 1 addition & 0 deletions dask/dataframe/dispatch.py
@@ -25,6 +25,7 @@
 is_categorical_dtype_dispatch = Dispatch("is_categorical_dtype")
 union_categoricals_dispatch = Dispatch("union_categoricals")
 grouper_dispatch = Dispatch("grouper")
+partd_encode_dispatch = Dispatch("partd_encode_dispatch")
 pyarrow_schema_dispatch = Dispatch("pyarrow_schema_dispatch")
 from_pyarrow_table_dispatch = Dispatch("from_pyarrow_table_dispatch")
 to_pyarrow_table_dispatch = Dispatch("to_pyarrow_table_dispatch")
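Dispatch instances resolve on the type of their first argument, so once the pandas registration in backends.py has been imported, calling the new dispatch on a (meta) DataFrame should return the partd.PandasBlocks class. A quick sketch, assuming dask and partd are installed:

    import pandas as pd
    import partd
    import dask.dataframe  # noqa: F401  (imports the backend registrations)
    from dask.dataframe.dispatch import partd_encode_dispatch

    meta = pd.DataFrame({"x": [1]}).iloc[:0]  # empty meta-like frame
    encode = partd_encode_dispatch(meta)      # dispatches on type(meta)
    assert encode is partd.PandasBlocks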
27 changes: 20 additions & 7 deletions dask/dataframe/shuffle.py
@@ -19,7 +19,11 @@
 from dask.base import compute, compute_as_if_collection, is_dask_collection, tokenize
 from dask.dataframe import methods
 from dask.dataframe.core import DataFrame, Series, _Frame, map_partitions, new_dd_object
-from dask.dataframe.dispatch import group_split_dispatch, hash_object_dispatch
+from dask.dataframe.dispatch import (
+    group_split_dispatch,
+    hash_object_dispatch,
+    partd_encode_dispatch,
+)
 from dask.dataframe.utils import UNKNOWN_CATEGORIES
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import ShuffleLayer, SimpleShuffleLayer
@@ -521,16 +525,23 @@ class maybe_buffered_partd:
     If serialized, will return non-buffered partd. Otherwise returns a buffered partd
     """

-    def __init__(self, buffer=True, tempdir=None):
+    def __init__(self, encode=None, buffer=True, tempdir=None):
         self.tempdir = tempdir or config.get("temporary_directory", None)
         self.buffer = buffer
         self.compression = config.get("dataframe.shuffle.compression", None)
+        self.encode = encode

     def __reduce__(self):
         if self.tempdir:
-            return (maybe_buffered_partd, (False, self.tempdir))
+            return (maybe_buffered_partd, (self.encode, False, self.tempdir))
         else:
-            return (maybe_buffered_partd, (False,))
+            return (
+                maybe_buffered_partd,
+                (
+                    self.encode,
+                    False,
+                ),
+            )

     def __call__(self, *args, **kwargs):
         import partd
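The reworked __reduce__ threads encode through pickling while still forcing buffer=False on the reconstructed object, as the docstring promises. A small sketch of the intended roundtrip (assuming partd is installed):

    import pickle

    import partd
    from dask.dataframe.shuffle import maybe_buffered_partd

    p = maybe_buffered_partd(encode=partd.PandasBlocks, buffer=True)
    p2 = pickle.loads(pickle.dumps(p))
    assert p2.encode is partd.PandasBlocks  # encode survives serialization
    assert p2.buffer is False               # buffering is dropped after a roundtrip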
@@ -555,10 +566,11 @@ def __call__(self, *args, **kwargs):
         # Envelope partd file with compression, if set and available
         if partd_compression:
             file = partd_compression(file)
+        encode = self.encode or partd.PandasBlocks
         if self.buffer:
-            return partd.PandasBlocks(partd.Buffer(partd.Dict(), file))
+            return encode(partd.Buffer(partd.Dict(), file))
         else:
-            return partd.PandasBlocks(file)
+            return encode(file)


 def rearrange_by_column_disk(df, column, npartitions=None, compute=False):
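In __call__, the `self.encode or partd.PandasBlocks` fallback keeps the previous pandas behavior whenever no encode class was supplied (or when the dispatch returned None). A sketch of the default path, assuming partd is installed; the factory writes to a temporary directory:

    import partd
    from dask.dataframe.shuffle import maybe_buffered_partd

    # No encode supplied: falls back to partd.PandasBlocks, matching the old code
    store = maybe_buffered_partd()()
    assert isinstance(store, partd.PandasBlocks)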
@@ -577,7 +589,8 @@ def rearrange_by_column_disk(df, column, npartitions=None, compute=False):
     always_new_token = uuid.uuid1().hex

     p = ("zpartd-" + always_new_token,)
-    dsk1 = {p: (maybe_buffered_partd(),)}
+    encode = partd_encode_dispatch(df._meta)
+    dsk1 = {p: (maybe_buffered_partd(encode=encode),)}

     # Partition data on disk
     name = "shuffle-partition-" + always_new_token
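Putting the pieces together, the disk-based shuffle now derives the encode class from the collection's metadata instead of hard-coding pandas. Condensed from the hunk above (not the full function):

    # inside rearrange_by_column_disk
    encode = partd_encode_dispatch(df._meta)  # partd.PandasBlocks for pandas-backed frames
    dsk1 = {p: (maybe_buffered_partd(encode=encode),)}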
