Skip to content

Commit

Permalink
Numpy 2.0 compatibility (#42)
Browse files Browse the repository at this point in the history
Fixes `AttributeError: module 'numpy' has no attribute 'product'`
  • Loading branch information
dweindl authored Oct 8, 2024
1 parent 34f8f73 commit 0833d10
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@


def rosenbrock(input_value, output_shape):
size = np.product(output_shape)
size = np.prod(output_shape)
values = [rosen(input_value + i * 0.01) for i in range(size)]
output = np.array(values).reshape(output_shape)
return output


def rosenbrock_der(input_value, output_shape):
size = np.product(output_shape)
size = np.prod(output_shape)
input_shape = input_value.shape
values = [rosen_der(input_value + i * 0.01) for i in range(size)]
# The input shape is the "deepest" dimension(s), i.e. expect
Expand Down

0 comments on commit 0833d10

Please sign in to comment.