diff --git a/CHANGELOG.md b/CHANGELOG.md index bc8ee534..6f9431ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,9 +4,13 @@ ### Added +- Support `assert` statements in kernels ([GH-366](https://github.com/NVIDIA/warp/issues/336)). + Assertions can only be triggered in `"debug"` mode. + ### Changed ### Fixed + - warp.sim: Fixed a bug in which the color-balancing algorithm was not updating the colorings. ## [1.5.0] - 2024-12-02 diff --git a/docs/debugging.rst b/docs/debugging.rst index f2dc0bdc..c123fa4b 100644 --- a/docs/debugging.rst +++ b/docs/debugging.rst @@ -34,6 +34,42 @@ these issues, Warp supports a simple option to print out all launches and argume wp.config.print_launches = True +Assertions +---------- + +``assert`` statements can be inserted into Warp kernels and user-defined functions to interrupt the program +execution when a provided Boolean expression evaluates to false. Assertions are only active for a module's kernels +when the module is compiled in debug mode (see the :doc:`/configuration` documentation for how to enable debug-mode +compilation). + +The following example will raise an assertion when the kernel is run since the module is compiled +in debug mode and the ``assert`` statement expects that the array passed into the ``expect_ones`` kernel +is an array of ones, but we passed it a single-element array of zeros: + +.. code-block:: python + + import warp as wp + + wp.config.mode = "debug" + + + @wp.kernel + def expect_ones(a: wp.array(dtype=int)): + i = wp.tid() + + assert a[i] == 1, "Array element must be 1" + + + input_array = wp.zeros(1, dtype=int) + + wp.launch(expect_ones, input_array.shape, inputs=[input_array]) + + wp.synchronize_device() + +The output of the program will include a line like the following statement:: + + default_program:49: void expect_ones_133f9859_cuda_kernel_forward(wp::launch_bounds_t, wp::array_t): block: [0,0,0], thread: [0,0,0] Assertion `("assert a[i] == 1, \"Array element must be 1\"",var_3)` failed. + Step-Through Debugging ---------------------- diff --git a/warp/codegen.py b/warp/codegen.py index 7b0ff13f..dd022a2c 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -1855,6 +1855,17 @@ def emit_Ellipsis(adj, node): # stubbed @wp.native_func return + def emit_Assert(adj, node): + # eval condition + cond = adj.eval(node.test) + cond = adj.load(cond) + + source_segment = ast.get_source_segment(adj.source, node) + # If a message was provided with the assert, " marks can interfere with the generated code + escaped_segment = source_segment.replace('"', '\\"') + + adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));') + def emit_NameConstant(adj, node): if node.value: return adj.add_constant(node.value) @@ -2684,6 +2695,7 @@ def emit_Pass(adj, node): ast.Tuple: emit_Tuple, ast.Pass: emit_Pass, ast.Ellipsis: emit_Ellipsis, + ast.Assert: emit_Assert, } def eval(adj, node): diff --git a/warp/tests/test_assert.py b/warp/tests/test_assert.py new file mode 100644 index 00000000..47822710 --- /dev/null +++ b/warp/tests/test_assert.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import unittest + +import warp as wp +from warp.tests.unittest_utils import * + + +@wp.kernel +def expect_ones(a: wp.array(dtype=int)): + i = wp.tid() + + assert a[i] == 1 + + +@wp.kernel +def expect_ones_with_msg(a: wp.array(dtype=int)): + i = wp.tid() + + assert a[i] == 1, "Array element must be 1" + + +@wp.kernel +def expect_ones_compound(a: wp.array(dtype=int)): + i = wp.tid() + + assert a[i] > 0 and a[i] < 2 + + +@wp.func +def expect_ones_function(value: int): + assert value == 1, "Array element must be 1" + + +@wp.kernel +def expect_ones_call_function(a: wp.array(dtype=int)): + i = wp.tid() + expect_ones_function(a[i]) + + +class TestAssertRelease(unittest.TestCase): + """Assert test cases that are to be run with Warp in release mode.""" + + @classmethod + def setUpClass(cls): + cls._saved_mode = wp.get_module_options()["mode"] + cls._saved_cache_kernels = wp.config.cache_kernels + + wp.config.mode = "release" + wp.config.cache_kernels = False + + @classmethod + def tearDownClass(cls): + wp.set_module_options({"mode": cls._saved_mode}) + wp.config.cache_kernels = cls._saved_cache_kernels + + def test_basic_assert_false_condition(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.zeros(1, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones, input_array.shape, inputs=[input_array]) + + output = capture.end() + + self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}") + + def test_basic_assert_with_msg(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.zeros(1, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones_with_msg, input_array.shape, inputs=[input_array]) + + output = capture.end() + + self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}") + + def test_compound_assert_false_condition(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.full(1, value=3, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones_compound, input_array.shape, inputs=[input_array]) + + output = capture.end() + + self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}") + + def test_basic_assert_false_condition_function(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.full(1, value=3, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones_call_function, input_array.shape, inputs=[input_array]) + + output = capture.end() + + self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}") + + +# NOTE: Failed assertions on CUDA devices leaves the CUDA context in an unrecoverable state, +# so we currently do not test them. +class TestAssertDebug(unittest.TestCase): + """Assert test cases that are to be run with Warp in debug mode.""" + + @classmethod + def setUpClass(cls): + cls._saved_mode = wp.get_module_options()["mode"] + cls._saved_cache_kernels = wp.config.cache_kernels + + wp.set_module_options({"mode": "debug"}) + wp.config.cache_kernels = False + + @classmethod + def tearDownClass(cls): + wp.set_module_options({"mode": cls._saved_mode}) + wp.config.cache_kernels = cls._saved_cache_kernels + + def test_basic_assert_false_condition(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.zeros(1, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones, input_array.shape, inputs=[input_array]) + + output = capture.end() + + # Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed. + if output != "" or sys.platform != "win32": + self.assertRegex(output, r"Assertion failed: .*assert a\[i\] == 1") + + def test_basic_assert_true_condition(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.ones(1, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones, input_array.shape, inputs=[input_array]) + + output = capture.end() + + self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}") + + def test_basic_assert_with_msg(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.zeros(1, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones_with_msg, input_array.shape, inputs=[input_array]) + + output = capture.end() + + # Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed. + if output != "" or sys.platform != "win32": + self.assertRegex(output, r"Assertion failed: .*assert a\[i\] == 1.*Array element must be 1") + + def test_compound_assert_true_condition(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.ones(1, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones_compound, input_array.shape, inputs=[input_array]) + + output = capture.end() + + self.assertEqual(output, "", f"Kernel should not print anything to stderr, got {output}") + + def test_compound_assert_false_condition(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.full(1, value=3, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones_compound, input_array.shape, inputs=[input_array]) + + output = capture.end() + + # Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed. + if output != "" or sys.platform != "win32": + self.assertRegex(output, r"Assertion failed: .*assert a\[i\] > 0 and a\[i\] < 2") + + def test_basic_assert_false_condition_function(self): + with wp.ScopedDevice("cpu"): + wp.load_module(device=wp.get_device()) + + input_array = wp.full(1, value=3, dtype=int) + + capture = StdErrCapture() + capture.begin() + + wp.launch(expect_ones_call_function, input_array.shape, inputs=[input_array]) + + output = capture.end() + + # Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed. + if output != "" or sys.platform != "win32": + self.assertRegex(output, r"Assertion failed: .*assert value == 1.*Array element must be 1") + + +if __name__ == "__main__": + wp.clear_kernel_cache() + unittest.main(verbosity=2) diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index dc894327..fd63a6ef 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -396,48 +396,29 @@ def test_unresolved_symbol(test, device): def test_error_global_var(test, device): - arr = wp.array( - (1.0, 2.0, 3.0), - dtype=float, - device=device, - ) + arr = wp.array((1.0, 2.0, 3.0), dtype=float, device=device) - def kernel_1_fn( - out: wp.array(dtype=float), - ): + def kernel_1_fn(out: wp.array(dtype=float)): out[0] = arr[0] - def kernel_2_fn( - out: wp.array(dtype=float), - ): + def kernel_2_fn(out: wp.array(dtype=float)): out[0] = arr - def kernel_3_fn( - out: wp.array(dtype=float), - ): + def kernel_3_fn(out: wp.array(dtype=float)): out[0] = wp.lower_bound(arr, 2.0) out = wp.empty_like(arr) kernel = wp.Kernel(func=kernel_1_fn) - with test.assertRaisesRegex( - TypeError, - r"Invalid external reference type: ", - ): + with test.assertRaisesRegex(TypeError, r"Invalid external reference type: "): wp.launch(kernel, dim=out.shape, inputs=(), outputs=(out,), device=device) kernel = wp.Kernel(func=kernel_2_fn) - with test.assertRaisesRegex( - TypeError, - r"Invalid external reference type: ", - ): + with test.assertRaisesRegex(TypeError, r"Invalid external reference type: "): wp.launch(kernel, dim=out.shape, inputs=(), outputs=(out,), device=device) kernel = wp.Kernel(func=kernel_3_fn) - with test.assertRaisesRegex( - TypeError, - r"Invalid external reference type: ", - ): + with test.assertRaisesRegex(TypeError, r"Invalid external reference type: "): wp.launch(kernel, dim=out.shape, inputs=(), outputs=(out,), device=device) @@ -469,16 +450,12 @@ def kernel_4_fn(): wp.launch(kernel, dim=1, device=device) kernel = wp.Kernel(func=kernel_3_fn) - with test.assertRaisesRegex( - RuntimeError, - r"Construct `ast.Dict` not supported in kernels.", - ): + with test.assertRaisesRegex(RuntimeError, r"Construct `ast.Dict` not supported in kernels."): wp.launch(kernel, dim=1, device=device) kernel = wp.Kernel(func=kernel_4_fn) with test.assertRaisesRegex( - RuntimeError, - r"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3\(\)` instead.", + RuntimeError, r"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3\(\)` instead." ): wp.launch(kernel, dim=1, device=device) @@ -491,10 +468,7 @@ def kernel_2_fn(): x = wp.dot(wp.vec2(1.0, 2.0), wp.vec2h(wp.float16(1.0), wp.float16(2.0))) kernel = wp.Kernel(func=kernel_1_fn) - with test.assertRaisesRegex( - RuntimeError, - r"Input types must be the same, got \['int32', 'float32'\]", - ): + with test.assertRaisesRegex(RuntimeError, r"Input types must be the same, got \['int32', 'float32'\]"): wp.launch(kernel, dim=1, device=device) kernel = wp.Kernel(func=kernel_2_fn) @@ -704,12 +678,7 @@ class TestCodeGen(unittest.TestCase): TestCodeGen, name="test_dynamic_for_rename", kernel=test_dynamic_for_rename, inputs=[10], dim=1, devices=devices ) add_kernel_test( - TestCodeGen, - name="test_dynamic_for_inplace", - kernel=test_dynamic_for_inplace, - inputs=[10], - dim=1, - devices=devices, + TestCodeGen, name="test_dynamic_for_inplace", kernel=test_dynamic_for_inplace, inputs=[10], dim=1, devices=devices ) add_kernel_test(TestCodeGen, name="test_reassign", kernel=test_reassign, dim=1, devices=devices) add_kernel_test( @@ -754,12 +723,7 @@ class TestCodeGen(unittest.TestCase): ) add_kernel_test( - TestCodeGen, - name="test_range_static_sum", - kernel=test_range_static_sum, - dim=1, - expect=[10, 10, 10], - devices=devices, + TestCodeGen, name="test_range_static_sum", kernel=test_range_static_sum, dim=1, expect=[10, 10, 10], devices=devices ) add_kernel_test( TestCodeGen, @@ -789,20 +753,9 @@ class TestCodeGen(unittest.TestCase): devices=devices, ) add_kernel_test( - TestCodeGen, - name="test_range_dynamic_nested", - kernel=test_range_dynamic_nested, - dim=1, - inputs=[4], - devices=devices, -) -add_kernel_test( - TestCodeGen, - name="test_range_expression", - kernel=test_range_expression, - dim=1, - devices=devices, + TestCodeGen, name="test_range_dynamic_nested", kernel=test_range_dynamic_nested, dim=1, inputs=[4], devices=devices ) +add_kernel_test(TestCodeGen, name="test_range_expression", kernel=test_range_expression, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_while_zero", kernel=test_while, dim=1, inputs=[0], devices=devices) add_kernel_test(TestCodeGen, name="test_while_positive", kernel=test_while, dim=1, inputs=[16], devices=devices) diff --git a/warp/tests/unittest_utils.py b/warp/tests/unittest_utils.py index a94e6a36..6d668ce1 100644 --- a/warp/tests/unittest_utils.py +++ b/warp/tests/unittest_utils.py @@ -128,8 +128,13 @@ def get_cuda_test_devices(mode=None): return [d for d in devices if d.is_cuda] -# redirects and captures all stdout output (including from C-libs) -class StdOutCapture: +class StreamCapture: + def __init__(self, stream_name): + self.stream_name = stream_name # 'stdout' or 'stderr' + self.saved = None + self.target = None + self.tempfile = None + def begin(self): # Flush the stream buffers managed by libc. # This is needed at the moment due to Carbonite not flushing the logs @@ -137,14 +142,15 @@ def begin(self): if LIBC is not None: LIBC.fflush(None) - # save original - self.saved = sys.stdout + # Get the stream object (sys.stdout or sys.stderr) + self.saved = getattr(sys, self.stream_name) self.target = os.dup(self.saved.fileno()) # create temporary capture stream import io import tempfile + # Create temporary capture stream self.tempfile = io.TextIOWrapper( tempfile.TemporaryFile(buffering=0), encoding="utf-8", @@ -153,33 +159,46 @@ def begin(self): write_through=True, ) + # Redirect the stream os.dup2(self.tempfile.fileno(), self.saved.fileno()) - - sys.stdout = self.tempfile + setattr(sys, self.stream_name, self.tempfile) def end(self): # The following sleep doesn't seem to fix the test_print failure on Windows # if sys.platform == "win32": # # Workaround for what seems to be a Windows-specific bug where - # # the output of CUDA's `printf` is not being immediately flushed - # # despite the context synchronisation. + # # the output of CUDA's printf is not being immediately flushed + # # despite the context synchronization. # time.sleep(0.01) - if LIBC is not None: LIBC.fflush(None) + # Restore the original stream os.dup2(self.target, self.saved.fileno()) os.close(self.target) + # Read the captured output self.tempfile.seek(0) res = self.tempfile.buffer.read() self.tempfile.close() - sys.stdout = self.saved + # Restore the stream object + setattr(sys, self.stream_name, self.saved) return str(res.decode("utf-8")) +# Subclasses for specific streams +class StdErrCapture(StreamCapture): + def __init__(self): + super().__init__("stderr") + + +class StdOutCapture(StreamCapture): + def __init__(self): + super().__init__("stdout") + + class CheckOutput: def __init__(self, test): self.test = test