diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index b9ebd788..3203be9c 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -123,6 +123,11 @@ jobs: run: | conda install -c conda-forge -c rapidsai ucx-proc=*=cpu ucx ucx-py + - name: Install fury + if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version == '3.9') }} + run: | + pip install pyfury + - name: Install on GPU if: ${{ matrix.module == 'gpu' }} run: | diff --git a/python/xoscar/serialization/core.pyx b/python/xoscar/serialization/core.pyx index c8f94a59..745276f9 100644 --- a/python/xoscar/serialization/core.pyx +++ b/python/xoscar/serialization/core.pyx @@ -51,12 +51,12 @@ except (ImportError, AttributeError): from .._utils import NamedType from .._utils cimport TypeDispatcher +from .pyfury import get_fury BUFFER_PICKLE_PROTOCOL = max(pickle.DEFAULT_PROTOCOL, 5) cdef bint HAS_PICKLE_BUFFER = pickle.HIGHEST_PROTOCOL >= 5 cdef bint _PANDAS_HAS_MGR = hasattr(pd.Series([0]), "_mgr") - cdef TypeDispatcher _serial_dispatcher = TypeDispatcher() cdef dict _deserializers = dict() @@ -218,27 +218,47 @@ def buffered(func): def pickle_buffers(obj): cdef list buffers = [None] - if HAS_PICKLE_BUFFER: - + fury = get_fury() + if fury is not None: def buffer_cb(x): - x = x.raw() - if x.ndim > 1: - # ravel n-d memoryview - x = x.cast(x.format) - buffers.append(memoryview(x)) + try: + buffers.append(memoryview(x)) + except TypeError: + buffers.append(x.to_buffer()) - buffers[0] = cloudpickle.dumps( + buffers[0] = b"__fury__" + buffers.append(None) + buffers[1] = fury.serialize( obj, buffer_callback=buffer_cb, - protocol=BUFFER_PICKLE_PROTOCOL, ) - else: # pragma: no cover - buffers[0] = cloudpickle.dumps(obj) + else: + if HAS_PICKLE_BUFFER: + def buffer_cb(x): + x = x.raw() + if x.ndim > 1: + # ravel n-d memoryview + x = x.cast(x.format) + buffers.append(memoryview(x)) + + buffers[0] = cloudpickle.dumps( + obj, + buffer_callback=buffer_cb, + protocol=BUFFER_PICKLE_PROTOCOL, + ) + else: + buffers[0] = cloudpickle.dumps(obj) return buffers def unpickle_buffers(list buffers): - result = cloudpickle.loads(buffers[0], buffers=buffers[1:]) + if buffers[0] == b"__fury__": + fury = get_fury() + if fury is None: + raise Exception("fury is not installed.") + result = fury.deserialize(buffers[1], buffers[2:]) + else: + result = cloudpickle.loads(buffers[0], buffers=buffers[1:]) # as pandas prior to 1.1.0 use _data instead of _mgr to hold BlockManager, # deserializing from high versions may produce mal-functioned pandas objects, diff --git a/python/xoscar/serialization/pyfury.py b/python/xoscar/serialization/pyfury.py new file mode 100644 index 00000000..c192f554 --- /dev/null +++ b/python/xoscar/serialization/pyfury.py @@ -0,0 +1,37 @@ +import os +import threading + +_fury = threading.local() +_fury_not_installed = object() +_register_class_list = set() + + +def register_classes(*args): + instance = get_fury() + if instance is not None: + _register_class_list.update(args) + for c in _register_class_list: + instance.register_class(c) + + +def get_fury(): + if os.environ.get("USE_FURY") in ("1", "true", "True"): + instance = getattr(_fury, "instance", None) + if instance is _fury_not_installed: # pragma: no cover + return None + if instance is not None: + return instance + else: + try: + import pyfury + + _fury.instance = instance = pyfury.Fury( + language=pyfury.Language.PYTHON, require_class_registration=False + ) + for c in _register_class_list: # pragma: no cover + instance.register_class(c) + print("pyfury is enabled.") + except ImportError: # pragma: no cover + print("pyfury is not installed.") + _fury.instance = _fury_not_installed + return instance diff --git a/python/xoscar/serialization/tests/test_serial.py b/python/xoscar/serialization/tests/test_serial.py index 3a7eedc6..cc293640 100644 --- a/python/xoscar/serialization/tests/test_serial.py +++ b/python/xoscar/serialization/tests/test_serial.py @@ -15,6 +15,7 @@ from __future__ import annotations +import os import threading from collections import OrderedDict, defaultdict from typing import Any, Dict, List, Tuple @@ -39,6 +40,7 @@ cupy = lazy_import("cupy") cudf = lazy_import("cudf") +pyfury = lazy_import("pyfury") class CustomList(list): @@ -179,6 +181,32 @@ def test_arrow(): np.testing.assert_equal(val, deserialized) +@pytest.mark.skipif(pyfury is None, reason="need pyfury to run the cases") +def test_arrow_fury(): + os.environ["USE_FURY"] = "1" + from ..pyfury import register_classes + + try: + test_df = pd.DataFrame( + { + "a": np.random.rand(1000), + "b": np.random.choice(list("abcd"), size=(1000,)), + "c": np.random.randint(0, 100, size=(1000,)), + } + ) + register_classes(pa.RecordBatch, pa.Table) + test_vals = [ + pa.RecordBatch.from_pandas(test_df), + pa.Table.from_pandas(test_df), + ] + for val in test_vals: + deserialized = deserialize(*serialize(val)) + assert type(val) is type(deserialized) + np.testing.assert_equal(val, deserialized) + finally: + os.environ.pop("USE_FURY") + + @pytest.mark.parametrize( "np_val", [np.random.rand(100, 100), np.random.rand(100, 100).T],