Skip to content

Commit

Permalink
Gradient updates for in-place add/sub using custom types
Browse files Browse the repository at this point in the history
  • Loading branch information
daedalus5 authored and mmacklin committed Nov 22, 2024
1 parent 37fe982 commit aa52f07
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Changed

### Fixed
- Fix gradient propagation for in-place addition/subtraction operations on custom vector type arrays.

## [1.5.0] - 2024-12-01

Expand Down
5 changes: 0 additions & 5 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,11 +2596,6 @@ def make_new_assign_statement():
target_type = strip_reference(target.type)

if is_array(target_type):
# target_type is not suitable for atomic array accumulation
if target_type.dtype not in warp.types.atomic_types:
make_new_assign_statement()
return

kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.lineno + adj.fun_lineno
Expand Down
22 changes: 22 additions & 0 deletions warp/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,16 @@ def inplace_add_rhs(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.ar
wp.atomic_add(z, 0, a)


vec9 = wp.vec(length=9, dtype=float)


@wp.kernel
def inplace_add_custom_vec(x: wp.array(dtype=vec9), y: wp.array(dtype=vec9)):
i = wp.tid()
x[i] += y[i]
x[i] += y[i]


def test_array_inplace_diff_ops(test, device):
N = 3
x1 = wp.ones(N, dtype=float, requires_grad=True, device=device)
Expand Down Expand Up @@ -2537,6 +2547,18 @@ def test_array_inplace_diff_ops(test, device):

assert_np_equal(x.grad.numpy(), np.ones(1, dtype=float))
assert_np_equal(y.grad.numpy(), np.ones(1, dtype=float))
tape.reset()

x = wp.zeros(1, dtype=vec9, requires_grad=True, device=device)
y = wp.ones(1, dtype=vec9, requires_grad=True, device=device)

with tape:
wp.launch(inplace_add_custom_vec, 1, inputs=[x, y], device=device)

tape.backward(grads={x: wp.ones_like(x)})

assert_np_equal(x.numpy(), np.full((1, 9), 2.0, dtype=float))
assert_np_equal(y.grad.numpy(), np.full((1, 9), 2.0, dtype=float))


@wp.kernel
Expand Down
37 changes: 0 additions & 37 deletions warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,43 +995,6 @@ class spatial_matrixd(matrix(shape=(6, 6), dtype=float64)):
spatial_matrixd,
)

atomic_vector_types = (
vec2i,
vec2ui,
vec2l,
vec2ul,
vec2h,
vec2f,
vec2d,
vec3i,
vec3ui,
vec3l,
vec3ul,
vec3h,
vec3f,
vec3d,
vec4i,
vec4ui,
vec4l,
vec4ul,
vec4h,
vec4f,
vec4d,
mat22h,
mat22f,
mat22d,
mat33h,
mat33f,
mat33d,
mat44h,
mat44f,
mat44d,
quath,
quatf,
quatd,
)
atomic_types = float_types + (int32, uint32, int64, uint64) + atomic_vector_types

np_dtype_to_warp_type = {
# Numpy scalar types
np.bool_: bool,
Expand Down

0 comments on commit aa52f07

Please sign in to comment.