diff --git a/warp/context.py b/warp/context.py index 177641b7..5f7e1d3c 100644 --- a/warp/context.py +++ b/warp/context.py @@ -3062,6 +3062,9 @@ def __init__(self): self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int] self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int] + self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int] + self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int] + self.core.runlength_encode_int_host.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, diff --git a/warp/native/sort.cpp b/warp/native/sort.cpp index 86c1f97a..9ce42fff 100644 --- a/warp/native/sort.cpp +++ b/warp/native/sort.cpp @@ -77,12 +77,90 @@ void radix_sort_pairs_host(int* keys, int* values, int n) } } + //http://stereopsis.com/radix.html +inline unsigned int radix_float_to_int(float f) +{ + unsigned int i = reinterpret_cast(f); + unsigned int mask = (unsigned int)(-(int)(i >> 31)) | 0x80000000; + return i ^ mask; +} + +void radix_sort_pairs_host(float* keys, int* values, int n) +{ + static unsigned int tables[2][1 << 16]; + memset(tables, 0, sizeof(tables)); + + float* auxKeys = keys + n; + int* auxValues = values + n; + + // build histograms + for (int i=0; i < n; ++i) + { + const unsigned int k = radix_float_to_int(keys[i]); + const unsigned short low = k & 0xffff; + const unsigned short high = k >> 16; + + ++tables[0][low]; + ++tables[1][high]; + } + + // convert histograms to offset tables in-place + unsigned int offlow = 0; + unsigned int offhigh = 0; + + for (int i=0; i < 65536; ++i) + { + const unsigned int newofflow = offlow + tables[0][i]; + const unsigned int newoffhigh = offhigh + tables[1][i]; + + tables[0][i] = offlow; + tables[1][i] = offhigh; + + offlow = newofflow; + offhigh = newoffhigh; + } + + // pass 1 - sort by low 16 bits + for (int i=0; i < n; ++i) + { + // lookup offset of input + const float f = keys[i]; + const unsigned int k = radix_float_to_int(f); + const int v = values[i]; + const unsigned int b = k & 0xffff; + + // find offset and increment + const unsigned int offset = tables[0][b]++; + + auxKeys[offset] = f; + auxValues[offset] = v; + } + + // pass 2 - sort by high 16 bits + for (int i=0; i < n; ++i) + { + // lookup offset of input + const float f = auxKeys[i]; + const unsigned int k = radix_float_to_int(f); + const int v = auxValues[i]; + + const unsigned int b = k >> 16; + + const unsigned int offset = tables[1][b]++; + + keys[offset] = f; + values[offset] = v; + } +} + #if !WP_ENABLE_CUDA void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out) {} void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n) {} +void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n) {} + #endif // !WP_ENABLE_CUDA @@ -92,3 +170,10 @@ void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n) reinterpret_cast(keys), reinterpret_cast(values), n); } + +void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n) +{ + radix_sort_pairs_host( + reinterpret_cast(keys), + reinterpret_cast(values), n); +} \ No newline at end of file diff --git a/warp/native/sort.cu b/warp/native/sort.cu index fc78b5f0..8e373995 100644 --- a/warp/native/sort.cu +++ b/warp/native/sort.cu @@ -95,3 +95,37 @@ void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n) reinterpret_cast(keys), reinterpret_cast(values), n); } + +void radix_sort_pairs_device(void* context, float* keys, int* values, int n) +{ + ContextGuard guard(context); + + cub::DoubleBuffer d_keys(keys, keys + n); + cub::DoubleBuffer d_values(values, values + n); + + RadixSortTemp temp; + radix_sort_reserve(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size); + + // sort + check_cuda(cub::DeviceRadixSort::SortPairs( + temp.mem, + temp.size, + d_keys, + d_values, + n, 0, 32, + (cudaStream_t)cuda_stream_get_current())); + + if (d_keys.Current() != keys) + memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n); + + if (d_values.Current() != values) + memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n); +} + +void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n) +{ + radix_sort_pairs_device( + WP_CURRENT_CONTEXT, + reinterpret_cast(keys), + reinterpret_cast(values), n); +} diff --git a/warp/native/sort.h b/warp/native/sort.h index 9eded4d9..db68e909 100644 --- a/warp/native/sort.h +++ b/warp/native/sort.h @@ -12,4 +12,6 @@ void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL); void radix_sort_pairs_host(int* keys, int* values, int n); -void radix_sort_pairs_device(void* context, int* keys, int* values, int n); \ No newline at end of file +void radix_sort_pairs_host(float* keys, int* values, int n); +void radix_sort_pairs_device(void* context, int* keys, int* values, int n); +void radix_sort_pairs_device(void* context, float* keys, int* values, int n); \ No newline at end of file diff --git a/warp/native/warp.h b/warp/native/warp.h index 49310dc7..53758c40 100644 --- a/warp/native/warp.h +++ b/warp/native/warp.h @@ -159,6 +159,9 @@ extern "C" WP_API void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n); WP_API void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n); + WP_API void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n); + WP_API void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n); + WP_API void runlength_encode_int_host(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n); WP_API void runlength_encode_int_device(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n); diff --git a/warp/tests/test_utils.py b/warp/tests/test_utils.py index 74ce0bbe..4686a9b7 100644 --- a/warp/tests/test_utils.py +++ b/warp/tests/test_utils.py @@ -79,37 +79,49 @@ def test_array_scan_error_unsupported_dtype(test, device): def test_radix_sort_pairs(test, device): - keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device) - values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device) - wp.utils.radix_sort_pairs(keys, values, 8) - assert_np_equal(keys.numpy()[:8], np.array((1, 2, 3, 4, 5, 6, 7, 8))) - assert_np_equal(values.numpy()[:8], np.array((5, 2, 8, 4, 7, 6, 1, 3))) + keyTypes = [int, wp.float32] + + for keyType in keyTypes: + keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device) + values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device) + wp.utils.radix_sort_pairs(keys, values, 8) + assert_np_equal(keys.numpy()[:8], np.array((1, 2, 3, 4, 5, 6, 7, 8))) + assert_np_equal(values.numpy()[:8], np.array((5, 2, 8, 4, 7, 6, 1, 3))) def test_radix_sort_pairs_empty(test, device): - keys = wp.array((), dtype=int, device=device) - values = wp.array((), dtype=int, device=device) - wp.utils.radix_sort_pairs(keys, values, 0) + keyTypes = [int, wp.float32] + + for keyType in keyTypes: + keys = wp.array((), dtype=keyType, device=device) + values = wp.array((), dtype=int, device=device) + wp.utils.radix_sort_pairs(keys, values, 0) def test_radix_sort_pairs_error_insufficient_storage(test, device): - keys = wp.array((1, 2, 3), dtype=int, device=device) - values = wp.array((1, 2, 3), dtype=int, device=device) - with test.assertRaisesRegex( - RuntimeError, - r"Array storage must be large enough to contain 2\*count elements$", - ): - wp.utils.radix_sort_pairs(keys, values, 3) + keyTypes = [int, wp.float32] + + for keyType in keyTypes: + keys = wp.array((1, 2, 3), dtype=keyType, device=device) + values = wp.array((1, 2, 3), dtype=int, device=device) + with test.assertRaisesRegex( + RuntimeError, + r"Array storage must be large enough to contain 2\*count elements$", + ): + wp.utils.radix_sort_pairs(keys, values, 3) def test_radix_sort_pairs_error_unsupported_dtype(test, device): - keys = wp.array((1.0, 2.0, 3.0), dtype=float, device=device) - values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device) - with test.assertRaisesRegex( - RuntimeError, - r"Unsupported data type$", - ): - wp.utils.radix_sort_pairs(keys, values, 1) + keyTypes = [int, wp.float32] + + for keyType in keyTypes: + keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device) + values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device) + with test.assertRaisesRegex( + RuntimeError, + r"Unsupported data type$", + ): + wp.utils.radix_sort_pairs(keys, values, 1) def test_array_sum(test, device): diff --git a/warp/utils.py b/warp/utils.py index 63a31ad5..6600e85a 100644 --- a/warp/utils.py +++ b/warp/utils.py @@ -132,11 +132,15 @@ def radix_sort_pairs(keys, values, count: int): if keys.device.is_cpu: if keys.dtype == wp.int32 and values.dtype == wp.int32: runtime.core.radix_sort_pairs_int_host(keys.ptr, values.ptr, count) + elif keys.dtype == wp.float32 and values.dtype == wp.int32: + runtime.core.radix_sort_pairs_float_host(keys.ptr, values.ptr, count) else: raise RuntimeError("Unsupported data type") elif keys.device.is_cuda: if keys.dtype == wp.int32 and values.dtype == wp.int32: runtime.core.radix_sort_pairs_int_device(keys.ptr, values.ptr, count) + elif keys.dtype == wp.float32 and values.dtype == wp.int32: + runtime.core.radix_sort_pairs_float_device(keys.ptr, values.ptr, count) else: raise RuntimeError("Unsupported data type")