Skip to content

Commit

Permalink
Fix for non-atomic type in-place adds
Browse files Browse the repository at this point in the history
  • Loading branch information
daedalus5 committed Nov 26, 2024
1 parent be78b78 commit fe0215b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
16 changes: 16 additions & 0 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,22 @@ def make_new_assign_statement():
target_type = strip_reference(target.type)

if is_array(target_type):
# target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
if target_type.dtype in warp.types.non_atomic_types:
make_new_assign_statement()
return

# the same holds true for vecs/mats/quats that are composed of these types
if (
type_is_vector(target_type.dtype)
or type_is_quaternion(target_type.dtype)
or type_is_matrix(target_type.dtype)
):
dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
if dtype in warp.types.non_atomic_types:
make_new_assign_statement()
return

kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.lineno + adj.fun_lineno
Expand Down
14 changes: 14 additions & 0 deletions warp/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import unittest
from typing import Any

import numpy as np

Expand Down Expand Up @@ -2573,6 +2574,12 @@ def inplace_div_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
x[i] /= y[i]


@wp.kernel
def inplace_add_non_atomic_types(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
i = wp.tid()
x[i] += y[i]


def test_array_inplace_non_diff_ops(test, device):
N = 3
x1 = wp.full(N, value=10.0, dtype=float, device=device)
Expand All @@ -2586,6 +2593,13 @@ def test_array_inplace_non_diff_ops(test, device):
wp.launch(inplace_div_1d, N, inputs=[x1, y1], device=device)
assert_np_equal(x1.numpy(), np.full(N, fill_value=2.0, dtype=float))

for dtype in wp.types.non_atomic_types + (wp.vec2b, wp.vec2ub, wp.vec2s, wp.vec2us):
x = wp.full(N, value=0, dtype=dtype, device=device)
y = wp.full(N, value=1, dtype=dtype, device=device)

wp.launch(inplace_add_non_atomic_types, N, inputs=[x, y], device=device)
assert_np_equal(x.numpy(), y.numpy())


@wp.kernel
def inc_scalar(a: wp.array(dtype=float)):
Expand Down
8 changes: 8 additions & 0 deletions warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,14 @@ class spatial_matrixd(matrix(shape=(6, 6), dtype=float64)):
float64: np.float64,
}

non_atomic_types = (
int8,
uint8,
int16,
uint16,
int64,
)


def dtype_from_numpy(numpy_dtype):
"""Return the Warp dtype corresponding to a NumPy dtype."""
Expand Down

0 comments on commit fe0215b

Please sign in to comment.