Skip to content

Commit

Permalink
FEAT: Support fury (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Sep 12, 2023
1 parent 2683f2e commit 17881dd
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ jobs:
run: |
conda install -c conda-forge -c rapidsai ucx-proc=*=cpu ucx ucx-py
- name: Install fury
if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version == '3.9') }}
run: |
pip install pyfury
- name: Install on GPU
if: ${{ matrix.module == 'gpu' }}
run: |
Expand Down
46 changes: 33 additions & 13 deletions python/xoscar/serialization/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ except (ImportError, AttributeError):
from .._utils import NamedType

from .._utils cimport TypeDispatcher
from .pyfury import get_fury

BUFFER_PICKLE_PROTOCOL = max(pickle.DEFAULT_PROTOCOL, 5)
cdef bint HAS_PICKLE_BUFFER = pickle.HIGHEST_PROTOCOL >= 5
cdef bint _PANDAS_HAS_MGR = hasattr(pd.Series([0]), "_mgr")


cdef TypeDispatcher _serial_dispatcher = TypeDispatcher()
cdef dict _deserializers = dict()

Expand Down Expand Up @@ -218,27 +218,47 @@ def buffered(func):
def pickle_buffers(obj):
cdef list buffers = [None]

if HAS_PICKLE_BUFFER:

fury = get_fury()
if fury is not None:
def buffer_cb(x):
x = x.raw()
if x.ndim > 1:
# ravel n-d memoryview
x = x.cast(x.format)
buffers.append(memoryview(x))
try:
buffers.append(memoryview(x))
except TypeError:
buffers.append(x.to_buffer())

buffers[0] = cloudpickle.dumps(
buffers[0] = b"__fury__"
buffers.append(None)
buffers[1] = fury.serialize(
obj,
buffer_callback=buffer_cb,
protocol=BUFFER_PICKLE_PROTOCOL,
)
else: # pragma: no cover
buffers[0] = cloudpickle.dumps(obj)
else:
if HAS_PICKLE_BUFFER:
def buffer_cb(x):
x = x.raw()
if x.ndim > 1:
# ravel n-d memoryview
x = x.cast(x.format)
buffers.append(memoryview(x))

buffers[0] = cloudpickle.dumps(
obj,
buffer_callback=buffer_cb,
protocol=BUFFER_PICKLE_PROTOCOL,
)
else:
buffers[0] = cloudpickle.dumps(obj)
return buffers


def unpickle_buffers(list buffers):
result = cloudpickle.loads(buffers[0], buffers=buffers[1:])
if buffers[0] == b"__fury__":
fury = get_fury()
if fury is None:
raise Exception("fury is not installed.")
result = fury.deserialize(buffers[1], buffers[2:])
else:
result = cloudpickle.loads(buffers[0], buffers=buffers[1:])

# as pandas prior to 1.1.0 use _data instead of _mgr to hold BlockManager,
# deserializing from high versions may produce mal-functioned pandas objects,
Expand Down
37 changes: 37 additions & 0 deletions python/xoscar/serialization/pyfury.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import threading

_fury = threading.local()
_fury_not_installed = object()
_register_class_list = set()


def register_classes(*args):
instance = get_fury()
if instance is not None:
_register_class_list.update(args)
for c in _register_class_list:
instance.register_class(c)


def get_fury():
if os.environ.get("USE_FURY") in ("1", "true", "True"):
instance = getattr(_fury, "instance", None)
if instance is _fury_not_installed: # pragma: no cover
return None
if instance is not None:
return instance
else:
try:
import pyfury

_fury.instance = instance = pyfury.Fury(
language=pyfury.Language.PYTHON, require_class_registration=False
)
for c in _register_class_list: # pragma: no cover
instance.register_class(c)
print("pyfury is enabled.")
except ImportError: # pragma: no cover
print("pyfury is not installed.")
_fury.instance = _fury_not_installed
return instance
28 changes: 28 additions & 0 deletions python/xoscar/serialization/tests/test_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from __future__ import annotations

import os
import threading
from collections import OrderedDict, defaultdict
from typing import Any, Dict, List, Tuple
Expand All @@ -39,6 +40,7 @@

cupy = lazy_import("cupy")
cudf = lazy_import("cudf")
pyfury = lazy_import("pyfury")


class CustomList(list):
Expand Down Expand Up @@ -179,6 +181,32 @@ def test_arrow():
np.testing.assert_equal(val, deserialized)


@pytest.mark.skipif(pyfury is None, reason="need pyfury to run the cases")
def test_arrow_fury():
os.environ["USE_FURY"] = "1"
from ..pyfury import register_classes

try:
test_df = pd.DataFrame(
{
"a": np.random.rand(1000),
"b": np.random.choice(list("abcd"), size=(1000,)),
"c": np.random.randint(0, 100, size=(1000,)),
}
)
register_classes(pa.RecordBatch, pa.Table)
test_vals = [
pa.RecordBatch.from_pandas(test_df),
pa.Table.from_pandas(test_df),
]
for val in test_vals:
deserialized = deserialize(*serialize(val))
assert type(val) is type(deserialized)
np.testing.assert_equal(val, deserialized)
finally:
os.environ.pop("USE_FURY")


@pytest.mark.parametrize(
"np_val",
[np.random.rand(100, 100), np.random.rand(100, 100).T],
Expand Down

0 comments on commit 17881dd

Please sign in to comment.