Skip to content

Commit

Permalink
Support using assert in kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
shi-eric authored and mmacklin committed Dec 9, 2024
1 parent ab887e3 commit 7933bbb
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 71 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

### Added

- Support `assert` statements in kernels ([GH-366](https://github.com/NVIDIA/warp/issues/336)).
Assertions can only be triggered in `"debug"` mode.

### Changed

### Fixed

- warp.sim: Fixed a bug in which the color-balancing algorithm was not updating the colorings.

## [1.5.0] - 2024-12-02
Expand Down
36 changes: 36 additions & 0 deletions docs/debugging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,42 @@ these issues, Warp supports a simple option to print out all launches and argume

wp.config.print_launches = True

Assertions
----------

``assert`` statements can be inserted into Warp kernels and user-defined functions to interrupt the program
execution when a provided Boolean expression evaluates to false. Assertions are only active for a module's kernels
when the module is compiled in debug mode (see the :doc:`/configuration` documentation for how to enable debug-mode
compilation).

The following example will raise an assertion when the kernel is run since the module is compiled
in debug mode and the ``assert`` statement expects that the array passed into the ``expect_ones`` kernel
is an array of ones, but we passed it a single-element array of zeros:

.. code-block:: python
import warp as wp
wp.config.mode = "debug"
@wp.kernel
def expect_ones(a: wp.array(dtype=int)):
i = wp.tid()
assert a[i] == 1, "Array element must be 1"
input_array = wp.zeros(1, dtype=int)
wp.launch(expect_ones, input_array.shape, inputs=[input_array])
wp.synchronize_device()
The output of the program will include a line like the following statement::

default_program:49: void expect_ones_133f9859_cuda_kernel_forward(wp::launch_bounds_t, wp::array_t<int>): block: [0,0,0], thread: [0,0,0] Assertion `("assert a[i] == 1, \"Array element must be 1\"",var_3)` failed.


Step-Through Debugging
----------------------
Expand Down
12 changes: 12 additions & 0 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,17 @@ def emit_Ellipsis(adj, node):
# stubbed @wp.native_func
return

def emit_Assert(adj, node):
# eval condition
cond = adj.eval(node.test)
cond = adj.load(cond)

source_segment = ast.get_source_segment(adj.source, node)
# If a message was provided with the assert, " marks can interfere with the generated code
escaped_segment = source_segment.replace('"', '\\"')

adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')

def emit_NameConstant(adj, node):
if node.value:
return adj.add_constant(node.value)
Expand Down Expand Up @@ -2684,6 +2695,7 @@ def emit_Pass(adj, node):
ast.Tuple: emit_Tuple,
ast.Pass: emit_Pass,
ast.Ellipsis: emit_Ellipsis,
ast.Assert: emit_Assert,
}

def eval(adj, node):
Expand Down
242 changes: 242 additions & 0 deletions warp/tests/test_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import unittest

import warp as wp
from warp.tests.unittest_utils import *


@wp.kernel
def expect_ones(a: wp.array(dtype=int)):
i = wp.tid()

assert a[i] == 1


@wp.kernel
def expect_ones_with_msg(a: wp.array(dtype=int)):
i = wp.tid()

assert a[i] == 1, "Array element must be 1"


@wp.kernel
def expect_ones_compound(a: wp.array(dtype=int)):
i = wp.tid()

assert a[i] > 0 and a[i] < 2


@wp.func
def expect_ones_function(value: int):
assert value == 1, "Array element must be 1"


@wp.kernel
def expect_ones_call_function(a: wp.array(dtype=int)):
i = wp.tid()
expect_ones_function(a[i])


class TestAssertRelease(unittest.TestCase):
"""Assert test cases that are to be run with Warp in release mode."""

@classmethod
def setUpClass(cls):
cls._saved_mode = wp.get_module_options()["mode"]
cls._saved_cache_kernels = wp.config.cache_kernels

wp.config.mode = "release"
wp.config.cache_kernels = False

@classmethod
def tearDownClass(cls):
wp.set_module_options({"mode": cls._saved_mode})
wp.config.cache_kernels = cls._saved_cache_kernels

def test_basic_assert_false_condition(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.zeros(1, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones, input_array.shape, inputs=[input_array])

output = capture.end()

self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}")

def test_basic_assert_with_msg(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.zeros(1, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones_with_msg, input_array.shape, inputs=[input_array])

output = capture.end()

self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}")

def test_compound_assert_false_condition(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.full(1, value=3, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones_compound, input_array.shape, inputs=[input_array])

output = capture.end()

self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}")

def test_basic_assert_false_condition_function(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.full(1, value=3, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones_call_function, input_array.shape, inputs=[input_array])

output = capture.end()

self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}")


# NOTE: Failed assertions on CUDA devices leaves the CUDA context in an unrecoverable state,
# so we currently do not test them.
class TestAssertDebug(unittest.TestCase):
"""Assert test cases that are to be run with Warp in debug mode."""

@classmethod
def setUpClass(cls):
cls._saved_mode = wp.get_module_options()["mode"]
cls._saved_cache_kernels = wp.config.cache_kernels

wp.set_module_options({"mode": "debug"})
wp.config.cache_kernels = False

@classmethod
def tearDownClass(cls):
wp.set_module_options({"mode": cls._saved_mode})
wp.config.cache_kernels = cls._saved_cache_kernels

def test_basic_assert_false_condition(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.zeros(1, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones, input_array.shape, inputs=[input_array])

output = capture.end()

# Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
if output != "" or sys.platform != "win32":
self.assertRegex(output, r"Assertion failed: .*assert a\[i\] == 1")

def test_basic_assert_true_condition(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.ones(1, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones, input_array.shape, inputs=[input_array])

output = capture.end()

self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}")

def test_basic_assert_with_msg(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.zeros(1, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones_with_msg, input_array.shape, inputs=[input_array])

output = capture.end()

# Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
if output != "" or sys.platform != "win32":
self.assertRegex(output, r"Assertion failed: .*assert a\[i\] == 1.*Array element must be 1")

def test_compound_assert_true_condition(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.ones(1, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones_compound, input_array.shape, inputs=[input_array])

output = capture.end()

self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}")

def test_compound_assert_false_condition(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.full(1, value=3, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones_compound, input_array.shape, inputs=[input_array])

output = capture.end()

# Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
if output != "" or sys.platform != "win32":
self.assertRegex(output, r"Assertion failed: .*assert a\[i\] > 0 and a\[i\] < 2")

def test_basic_assert_false_condition_function(self):
with wp.ScopedDevice("cpu"):
wp.load_module(device=wp.get_device())

input_array = wp.full(1, value=3, dtype=int)

capture = StdErrCapture()
capture.begin()

wp.launch(expect_ones_call_function, input_array.shape, inputs=[input_array])

output = capture.end()

# Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
if output != "" or sys.platform != "win32":
self.assertRegex(output, r"Assertion failed: .*assert value == 1.*Array element must be 1")


if __name__ == "__main__":
wp.clear_kernel_cache()
unittest.main(verbosity=2)
Loading

0 comments on commit 7933bbb

Please sign in to comment.