Skip to content

Commit

Permalink
Implement __array_prepare__ and __array_wrap__ by __array_ufunc__
Browse files Browse the repository at this point in the history
For numpy 2.0 compatiblity
Still needs a bit of cleanup and more extensive tests
  • Loading branch information
mstimberg committed Feb 9, 2024
1 parent 1a20c8c commit 657c22b
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 414 deletions.
26 changes: 6 additions & 20 deletions brian2/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,8 @@ class VariableView:
``G.var_``).
"""

__array_priority__ = 10

def __init__(self, name, variable, group, dimensions=None):
self.name = name
self.variable = variable
Expand Down Expand Up @@ -1349,27 +1351,11 @@ def __array__(self, dtype=None):
)
return np.asanyarray(self[:], dtype=dtype)

def __array_prepare__(self, array, context=None):
if self.dim is None:
return array
else:
this = self[:]
if isinstance(this, Quantity):
return Quantity.__array_prepare__(this, array, context=context)
else:
return array

def __array_wrap__(self, out_arr, context=None, return_scalar=False):
if self.dim is None:
return out_arr
def __array__ufunc__(self, ufunc, method, *inputs, **kwargs):
if method == "__call__":
return ufunc(self[:], *inputs, **kwargs)
else:
this = self[:]
if isinstance(this, Quantity):
return Quantity.__array_wrap__(
self[:], out_arr, context=context, return_scalar=return_scalar
)
else:
return out_arr
return NotImplemented

def __len__(self):
return len(self.get_item(slice(None), level=1))
Expand Down
5 changes: 1 addition & 4 deletions brian2/tests/test_neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,10 +1100,7 @@ def test_state_variables():

G.v = -70 * mV
# Numpy methods should be able to deal with state variables
# (discarding units)
assert_allclose(np.mean(G.v), float(-70 * mV))
# Getting the content should return a Quantity object which then natively
# supports numpy functions that access a method
assert_allclose(np.mean(G.v), -70 * mV)
assert_allclose(np.mean(G.v[:]), -70 * mV)

# You should also be able to set variables with a string
Expand Down
42 changes: 19 additions & 23 deletions brian2/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,11 @@ def test_power():
value ** (3 * volt / volt), np.asarray(value) ** 3, kilogram**3
)
with pytest.raises(DimensionMismatchError):
value ** (2 * volt)
# FIXME: Not that if float(exponent) is a special value such as 1 or 2
# numpy will actually use a ufunc such as identity or square, which will
# not raise a DimensionMismatchError. This is a limitation of the current
# implementation.
value ** (2 * mV)
with pytest.raises(TypeError):
value ** np.array([2, 3])

Expand All @@ -710,20 +714,26 @@ def test_power():
def test_inplace_operations():
q = np.arange(10) * volt
q_orig = q.copy()
q_id = id(q)
q_ref = q

q *= 2
assert np.all(q == 2 * q_orig) and id(q) == q_id
assert np.array_equal(q, 2 * q_orig)
assert np.array_equal(q_ref, q)
q /= 2
assert np.all(q == q_orig) and id(q) == q_id
assert np.array_equal(q, q_orig)
assert np.array_equal(q_ref, q)
q += 1 * volt
assert np.all(q == q_orig + 1 * volt) and id(q) == q_id
assert np.array_equal(q, q_orig + 1 * volt)
assert np.array_equal(q_ref, q)
q -= 1 * volt
assert np.all(q == q_orig) and id(q) == q_id
assert np.array_equal(q, q_orig)
assert np.array_equal(q_ref, q)
q **= 2
assert np.all(q == q_orig**2) and id(q) == q_id
assert np.array_equal(q, q_orig**2)
assert np.array_equal(q_ref, q)
q **= 0.5
assert np.all(q == q_orig) and id(q) == q_id
assert np.array_equal(q, q_orig)
assert np.array_equal(q_ref, q)

def illegal_add(q2):
q = np.arange(10) * volt
Expand All @@ -748,7 +758,7 @@ def illegal_pow(q2):
q **= q2

with pytest.raises(DimensionMismatchError):
illegal_pow(1 * volt)
illegal_pow(1 * mV)
with pytest.raises(TypeError):
illegal_pow(np.arange(10))

Expand All @@ -757,7 +767,6 @@ def illegal_pow(q2):
q.__iadd__,
q.__isub__,
q.__imul__,
q.__idiv__,
q.__itruediv__,
q.__ifloordiv__,
q.__imod__,
Expand All @@ -775,7 +784,6 @@ def illegal_pow(q2):
volt.__iadd__,
volt.__isub__,
volt.__imul__,
volt.__idiv__,
volt.__itruediv__,
volt.__ifloordiv__,
volt.__imod__,
Expand All @@ -785,7 +793,6 @@ def illegal_pow(q2):
inplace_op(volt)
for inplace_op in [
volt.dimensions.__imul__,
volt.dimensions.__idiv__,
volt.dimensions.__itruediv__,
volt.dimensions.__ipow__,
]:
Expand Down Expand Up @@ -1217,17 +1224,6 @@ def test_numpy_functions_logical():
# two arguments
result_units = eval(f"np.{ufunc}(value1, value2)")
result_array = eval(f"np.{ufunc}(np.array(value1), np.array(value2))")
# assert that comparing to a string results in "NotImplemented" or an error
try:
result = eval(f'np.{ufunc}(value1, "a string")')
assert result == NotImplemented
except (ValueError, TypeError):
pass # raised on numpy >= 0.10
try:
result = eval(f'np.{ufunc}("a string", value1)')
assert result == NotImplemented
except (ValueError, TypeError):
pass # raised on numpy >= 0.10
assert not isinstance(result_units, Quantity)
assert_equal(result_units, result_array)

Expand Down
Loading

0 comments on commit 657c22b

Please sign in to comment.