diff --git a/daft/series.py b/daft/series.py index 2430dbf31d..ece133c060 100644 --- a/daft/series.py +++ b/daft/series.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import TypeVar import pyarrow as pa @@ -490,8 +491,41 @@ def image(self) -> SeriesImageNamespace: def __reduce__(self) -> tuple: if self.datatype()._is_python_type(): return (Series.from_pylist, (self.to_pylist(), self.name(), "force")) - else: + elif sys.platform == "win32": return (Series.from_arrow, (self.to_arrow(), self.name())) + else: + # Ray Special CloudPickling fast path. + # Only run for Linux and Mac, since windows runs slower for some reason + return ( + Series._from_arrow_table_to_series, + self._to_arrow_table_for_serdes(), + ) + + def _to_arrow_table_for_serdes(self) -> tuple[pa.Table, pa.ExtensionType | None]: + array = self.to_arrow() + if len(array) == 0: + # This is a workaround for: + # pyarrow.lib.ArrowIndexError: buffer slice would exceed buffer length + # when we have 0 length arrays + array = pa.array([], type=array.type) + + if isinstance(array.type, pa.BaseExtensionType): + stype = array.type.storage_type + ltype = array.type + storage_array = array.cast(stype) + return (pa.table({self.name(): storage_array}), ltype) + else: + return (pa.table({self.name(): array}), None) + + @classmethod + def _from_arrow_table_to_series(cls, table: pa.Table, extension_type: pa.ExtensionType | None) -> Series: + # So we can exploit ray's special pickling for arrow tables which doesn't work on pyarrow arrays + assert table.num_columns == 1 + [name] = table.column_names + [array] = table.columns + if extension_type is not None: + array = extension_type.wrap_array(array) + return cls.from_arrow(array, name) SomeSeriesNamespace = TypeVar("SomeSeriesNamespace", bound="SeriesNamespace")