Skip to content

Commit

Permalink
correctly handle ufuncs with uncompressed arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
marco-neumann-by committed Aug 13, 2020
1 parent 4d6ff47 commit ca00500
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 29 deletions.
4 changes: 3 additions & 1 deletion rle_array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,9 @@ def __array_ufunc__(
return NotImplemented

# Defer to the implementation of the ufunc on unwrapped values.
inputs_has_ndarray = any(isinstance(x, np.ndarray) for x in inputs)
inputs = tuple(np.asarray(x) if isinstance(x, RLEArray) else x for x in inputs)

if out:
kwargs["out"] = tuple(
np.asarray(x) if isinstance(x, RLEArray) else x for x in out
Expand All @@ -717,7 +719,7 @@ def __array_ufunc__(
x[:] = y

def maybe_from_sequence(x: np.ndarray) -> Union[RLEArray, np.ndarray]:
if x.ndim == 1:
if (x.ndim == 1) and (not inputs_has_ndarray):
# suitable for RLE compression
return type(self)._from_sequence(x)
else:
Expand Down
33 changes: 5 additions & 28 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import operator
from typing import Any, Callable, Type, Union, cast
from typing import Any, Callable, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -235,25 +235,9 @@ def test_binary_operator_uncompressed_series(
binary_operator: FBinaryOperator,
) -> None:
actual = binary_operator(rle_series, uncompressed_series2)
if getattr(binary_operator, "__name__", "???") in (
"radd",
"rfloordiv",
"rmod",
"rmul",
"rpow",
"rsub",
"rtruediv",
):
# pd.Series does not implement these operations correctly
expected_dtype: Union[RLEDtype, Type[float]] = RLEDtype(float)
else:
expected_dtype = float

assert actual.dtype == expected_dtype
assert actual.dtype == float

expected = binary_operator(uncompressed_series, uncompressed_series2).astype(
expected_dtype
)
expected = binary_operator(uncompressed_series, uncompressed_series2)
pd.testing.assert_series_equal(actual, expected)


Expand Down Expand Up @@ -295,16 +279,9 @@ def test_binary_bool_operator_uncompressed_series(
binary_bool_operator: FBinaryBoolOperator,
) -> None:
actual = binary_bool_operator(rle_bool_series, uncompressed_bool_series2)
if getattr(binary_bool_operator, "__name__", "???") in ("rand_", "ror_", "rxor"):
# pd.Series does not implement these operations correctly
expected_dtype: Union[RLEDtype, Type[bool]] = RLEDtype(bool)
else:
expected_dtype = bool
assert actual.dtype == expected_dtype
assert actual.dtype == bool

expected = binary_bool_operator(
uncompressed_bool_series, uncompressed_bool_series2
).astype(expected_dtype)
expected = binary_bool_operator(uncompressed_bool_series, uncompressed_bool_series2)
pd.testing.assert_series_equal(actual, expected)


Expand Down
18 changes: 18 additions & 0 deletions tests/test_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,21 @@ def test_2d_broadcast_divmod(array_orig: np.ndarray, array_rle: RLEArray) -> Non
assert actual2.dtype == expected2.dtype
npt.assert_array_equal(actual1, expected1)
npt.assert_array_equal(actual2, expected2)


def test_mixed_typing_mul(array_orig: np.ndarray, array_rle: RLEArray) -> None:
actual = array_orig * array_rle

expected = array_orig * array_orig
assert actual.dtype == expected.dtype
npt.assert_array_equal(actual, expected)


def test_mixed_typing_divmod(array_orig: np.ndarray, array_rle: RLEArray) -> None:
actual1, actual2 = np.divmod(array_orig, array_rle)

expected1, expected2 = np.divmod(array_orig, array_orig)
assert actual1.dtype == expected1.dtype
assert actual2.dtype == expected2.dtype
npt.assert_array_equal(actual1, expected1)
npt.assert_array_equal(actual2, expected2)

0 comments on commit ca00500

Please sign in to comment.