diff --git a/dask/dataframe/backends.py b/dask/dataframe/backends.py
index 9ef5d1b360f..ee900c3bfef 100644
--- a/dask/dataframe/backends.py
+++ b/dask/dataframe/backends.py
@@ -27,6 +27,7 @@
     make_meta_obj,
     meta_lib_from_array,
     meta_nonempty,
+    partd_encode_dispatch,
     pyarrow_schema_dispatch,
     to_pandas_dispatch,
     to_pyarrow_table_dispatch,
@@ -242,6 +243,13 @@ def default_types_mapper(pyarrow_dtype: pa.DataType) -> object:
         return table.to_pandas(types_mapper=types_mapper, **kwargs)
 
 
+@partd_encode_dispatch.register(pd.DataFrame)
+def partd_pandas_blocks(_):
+    from partd import PandasBlocks
+
+    return PandasBlocks
+
+
 @meta_nonempty.register(pd.DatetimeTZDtype)
 @make_meta_dispatch.register(pd.DatetimeTZDtype)
 def make_meta_pandas_datetime_tz(x, index=None):
diff --git a/dask/dataframe/dispatch.py b/dask/dataframe/dispatch.py
index 307d5269ae8..9c70302dee8 100644
--- a/dask/dataframe/dispatch.py
+++ b/dask/dataframe/dispatch.py
@@ -25,6 +25,7 @@
 is_categorical_dtype_dispatch = Dispatch("is_categorical_dtype")
 union_categoricals_dispatch = Dispatch("union_categoricals")
 grouper_dispatch = Dispatch("grouper")
+partd_encode_dispatch = Dispatch("partd_encode_dispatch")
 pyarrow_schema_dispatch = Dispatch("pyarrow_schema_dispatch")
 from_pyarrow_table_dispatch = Dispatch("from_pyarrow_table_dispatch")
 to_pyarrow_table_dispatch = Dispatch("to_pyarrow_table_dispatch")
diff --git a/dask/dataframe/shuffle.py b/dask/dataframe/shuffle.py
index 4698358495b..556f982fc1f 100644
--- a/dask/dataframe/shuffle.py
+++ b/dask/dataframe/shuffle.py
@@ -19,7 +19,11 @@
 from dask.base import compute, compute_as_if_collection, is_dask_collection, tokenize
 from dask.dataframe import methods
 from dask.dataframe.core import DataFrame, Series, _Frame, map_partitions, new_dd_object
-from dask.dataframe.dispatch import group_split_dispatch, hash_object_dispatch
+from dask.dataframe.dispatch import (
+    group_split_dispatch,
+    hash_object_dispatch,
+    partd_encode_dispatch,
+)
 from dask.dataframe.utils import UNKNOWN_CATEGORIES
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import ShuffleLayer, SimpleShuffleLayer
@@ -521,16 +525,23 @@ class maybe_buffered_partd:
     If serialized, will return non-buffered partd. Otherwise returns a buffered partd
     """
 
-    def __init__(self, buffer=True, tempdir=None):
+    def __init__(self, encode=None, buffer=True, tempdir=None):
         self.tempdir = tempdir or config.get("temporary_directory", None)
         self.buffer = buffer
         self.compression = config.get("dataframe.shuffle.compression", None)
+        self.encode = encode
 
     def __reduce__(self):
         if self.tempdir:
-            return (maybe_buffered_partd, (False, self.tempdir))
+            return (maybe_buffered_partd, (self.encode, False, self.tempdir))
         else:
-            return (maybe_buffered_partd, (False,))
+            return (
+                maybe_buffered_partd,
+                (
+                    self.encode,
+                    False,
+                ),
+            )
 
     def __call__(self, *args, **kwargs):
         import partd
@@ -555,10 +566,11 @@ def __call__(self, *args, **kwargs):
         # Envelope partd file with compression, if set and available
         if partd_compression:
             file = partd_compression(file)
+        encode = self.encode or partd.PandasBlocks
         if self.buffer:
-            return partd.PandasBlocks(partd.Buffer(partd.Dict(), file))
+            return encode(partd.Buffer(partd.Dict(), file))
         else:
-            return partd.PandasBlocks(file)
+            return encode(file)
 
 
 def rearrange_by_column_disk(df, column, npartitions=None, compute=False):
@@ -577,7 +589,8 @@ def rearrange_by_column_disk(df, column, npartitions=None, compute=False):
     always_new_token = uuid.uuid1().hex
 
     p = ("zpartd-" + always_new_token,)
-    dsk1 = {p: (maybe_buffered_partd(),)}
+    encode = partd_encode_dispatch(df._meta)
+    dsk1 = {p: (maybe_buffered_partd(encode=encode),)}
 
     # Partition data on disk
     name = "shuffle-partition-" + always_new_token
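
For context, a minimal sketch of how a third-party dataframe backend could use the new hook. `MyDataFrame` is a hypothetical backend type, not part of this change; `partd.Pickle` is a real partd encode class, used here as a stand-in for a backend-specific serializer:

```python
# Sketch only: ``MyDataFrame`` is a placeholder for a third-party
# dataframe type; the real pieces added by this diff are
# ``partd_encode_dispatch`` and the pandas registration above.
import partd

from dask.dataframe.dispatch import partd_encode_dispatch


class MyDataFrame:  # hypothetical backend dataframe type
    pass


@partd_encode_dispatch.register(MyDataFrame)
def partd_pickle_blocks(_):
    # Mirror partd_pandas_blocks: return the encode *class*, not an
    # instance. rearrange_by_column_disk resolves it from ``df._meta``,
    # and maybe_buffered_partd wraps the partd.Buffer/partd.File store
    # with it inside ``__call__``.
    return partd.Pickle
```

Returning a class rather than an instance keeps `maybe_buffered_partd` cheap to pickle (see `__reduce__`), since the encode wrapper is only instantiated in `__call__` on the worker.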