diff --git a/CHANGELOG.md b/CHANGELOG.md index d462c0a0..7bdbad7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ### Changed ### Fixed +- Fix gradient propagation for in-place addition/subtraction operations on custom vector type arrays. ## [1.5.0] - 2024-12-01 diff --git a/warp/codegen.py b/warp/codegen.py index 5ab3834f..0bda4992 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -2596,11 +2596,6 @@ def make_new_assign_statement(): target_type = strip_reference(target.type) if is_array(target_type): - # target_type is not suitable for atomic array accumulation - if target_type.dtype not in warp.types.atomic_types: - make_new_assign_statement() - return - kernel_name = adj.fun_name filename = adj.filename lineno = adj.lineno + adj.fun_lineno diff --git a/warp/tests/test_array.py b/warp/tests/test_array.py index 66ad8f12..d3b299c6 100644 --- a/warp/tests/test_array.py +++ b/warp/tests/test_array.py @@ -2429,6 +2429,16 @@ def inplace_add_rhs(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.ar wp.atomic_add(z, 0, a) +vec9 = wp.vec(length=9, dtype=float) + + +@wp.kernel +def inplace_add_custom_vec(x: wp.array(dtype=vec9), y: wp.array(dtype=vec9)): + i = wp.tid() + x[i] += y[i] + x[i] += y[i] + + def test_array_inplace_diff_ops(test, device): N = 3 x1 = wp.ones(N, dtype=float, requires_grad=True, device=device) @@ -2537,6 +2547,18 @@ def test_array_inplace_diff_ops(test, device): assert_np_equal(x.grad.numpy(), np.ones(1, dtype=float)) assert_np_equal(y.grad.numpy(), np.ones(1, dtype=float)) + tape.reset() + + x = wp.zeros(1, dtype=vec9, requires_grad=True, device=device) + y = wp.ones(1, dtype=vec9, requires_grad=True, device=device) + + with tape: + wp.launch(inplace_add_custom_vec, 1, inputs=[x, y], device=device) + + tape.backward(grads={x: wp.ones_like(x)}) + + assert_np_equal(x.numpy(), np.full((1, 9), 2.0, dtype=float)) + assert_np_equal(y.grad.numpy(), np.full((1, 9), 2.0, dtype=float)) @wp.kernel diff --git a/warp/types.py b/warp/types.py index 688b3b43..64db672d 100644 --- a/warp/types.py +++ b/warp/types.py @@ -995,43 +995,6 @@ class spatial_matrixd(matrix(shape=(6, 6), dtype=float64)): spatial_matrixd, ) -atomic_vector_types = ( - vec2i, - vec2ui, - vec2l, - vec2ul, - vec2h, - vec2f, - vec2d, - vec3i, - vec3ui, - vec3l, - vec3ul, - vec3h, - vec3f, - vec3d, - vec4i, - vec4ui, - vec4l, - vec4ul, - vec4h, - vec4f, - vec4d, - mat22h, - mat22f, - mat22d, - mat33h, - mat33f, - mat33d, - mat44h, - mat44f, - mat44d, - quath, - quatf, - quatd, -) -atomic_types = float_types + (int32, uint32, int64, uint64) + atomic_vector_types - np_dtype_to_warp_type = { # Numpy scalar types np.bool_: bool,