diff --git a/tests/extension/oneapi_non_uniform_groups/group_barrier.cpp b/tests/extension/oneapi_non_uniform_groups/group_barrier.cpp index a7e9b357b..8ccdd7332 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_barrier.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_barrier.cpp @@ -27,21 +27,10 @@ namespace non_uniform_groups::tests { template class test_fence; -TEST_CASE("Non-uniform-group barriers", - "[oneapi_non_uniform_groups][group_func]") { - auto queue = once_per_unit::get_queue(); - - non_uniform_group_barrier>(queue); - non_uniform_group_barrier>( - queue); - non_uniform_group_barrier>( - queue); - non_uniform_group_barrier>( - queue); - non_uniform_group_barrier>( - queue); - non_uniform_group_barrier>(queue); - non_uniform_group_barrier(queue); +TEMPLATE_LIST_TEST_CASE("Non-uniform-group barriers", + "[oneapi_non_uniform_groups][group_func]", + GroupPackTypes) { + for_all_combinations(TestType{}); } } // namespace non_uniform_groups::tests diff --git a/tests/extension/oneapi_non_uniform_groups/group_barrier.h b/tests/extension/oneapi_non_uniform_groups/group_barrier.h index 3c00a4898..26fe8555c 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_barrier.h +++ b/tests/extension/oneapi_non_uniform_groups/group_barrier.h @@ -27,190 +27,195 @@ class non_uniform_group_barrier_kernel; /** * @brief Provides test for arbitraty non-uniform group barriers * @tparam GroupT Type of the non-uniform group to test with - * @tparam T Type pointed by Ptr */ template -void non_uniform_group_barrier(sycl::queue& queue) { - const std::string group_name = NonUniformGroupHelper::get_name(); - - INFO("Testing group-of predicate function for " + group_name); - if (!NonUniformGroupHelper::is_supported(queue.get_device())) { - SKIP("Device does not support " + group_name); - } +struct non_uniform_group_barrier_test { + void operator()() { + auto queue = once_per_unit::get_queue(); + const std::string group_name = NonUniformGroupHelper::get_name(); + + INFO("Testing group-of predicate function for " + group_name); + if (!NonUniformGroupHelper::is_supported(queue.get_device())) { + SKIP("Device does not support " + group_name); + } - std::vector supported_barriers = - queue.get_context() - .get_info(); - - using sms = std::tuple; - // indices of the tuple components - enum s { scope = 0, support = 1, test = 2 }; - - constexpr int non_uniform_group_barrier_variants = 5; - std::array - non_uniform_group_barriers{{{sycl::memory_scope::sub_group, true, true}, - {sycl::memory_scope::sub_group, true, true}, - {sycl::memory_scope::work_group, true, true}, - {sycl::memory_scope::device, true, true}, - {sycl::memory_scope::system, true, true}}}; - std::array - non_uniform_group_barriers_names{ - "default", "sycl::memory_scope::sub_group", - "sycl::memory_scope::work_group", "sycl::memory_scope::device", - "sycl::memory_scope::system"}; - for (auto& barrier : non_uniform_group_barriers) { - auto& sb = supported_barriers; - if (std::find(sb.begin(), sb.end(), std::get(barrier)) == - sb.end()) { - std::get(barrier) = false; + std::vector supported_barriers = + queue.get_context() + .get_info(); + + using sms = std::tuple; + // indices of the tuple components + enum s { scope = 0, support = 1, test = 2 }; + + constexpr int non_uniform_group_barrier_variants = 5; + std::array + non_uniform_group_barriers{ + {{sycl::memory_scope::sub_group, true, true}, + {sycl::memory_scope::sub_group, true, true}, + {sycl::memory_scope::work_group, true, true}, + {sycl::memory_scope::device, true, true}, + {sycl::memory_scope::system, true, true}}}; + std::array + non_uniform_group_barriers_names{ + "default", "sycl::memory_scope::sub_group", + "sycl::memory_scope::work_group", "sycl::memory_scope::device", + "sycl::memory_scope::system"}; + for (auto& barrier : non_uniform_group_barriers) { + auto& sb = supported_barriers; + if (std::find(sb.begin(), sb.end(), std::get(barrier)) == + sb.end()) { + std::get(barrier) = false; + } } - } - using el_type = int32_t; - sycl::device device = queue.get_device(); - - // Check the maximum number elements of type "el_type" that can be - // placed in the device's global and local memory. Since the test - // tries to allocate local and global buffers with a size equal to - // the work group size, the latter must be limited by the allowed - // buffer size. - uint64_t global_mem_size_in_bytes = - device.get_info(); - uint64_t global_mem_size_in_elements = - global_mem_size_in_bytes / sizeof(el_type); - - uint64_t local_mem_size_in_bytes = - device.get_info(); - uint64_t local_mem_size_in_elements = - local_mem_size_in_bytes / sizeof(el_type); - - uint64_t work_items_limit = - std::min(global_mem_size_in_elements, local_mem_size_in_elements); - - sycl::range<1> work_group_range = - sycl_cts::util::work_group_range<1>(queue, work_items_limit); - size_t work_group_size = work_group_range.size(); - - for (size_t test_case = 0; - test_case < NonUniformGroupHelper::num_test_cases; ++test_case) { - const std::string test_case_name = - NonUniformGroupHelper::get_test_case_name(test_case); - INFO("Running test case (" + std::to_string(test_case) + ") with " + - test_case_name); - - std::vector v(work_group_size, 0); - sycl::buffer global_mem(v.data(), - sycl::range<1>(work_group_size)); - - sycl::buffer non_uniform_group_barriers_buf( - non_uniform_group_barriers.data(), - sycl::range<1>(non_uniform_group_barrier_variants)); - - queue.submit([&](sycl::handler& cgh) { - sycl::nd_range<1> executionRange(work_group_range, work_group_range); - - auto non_uniform_group_barriers_acc = - non_uniform_group_barriers_buf - .get_access(cgh); - - sycl::local_accessor local_acc( - sycl::range<1>(work_group_size), cgh); - sycl::accessor global_acc = - global_mem.get_access(cgh); - - cgh.parallel_for>( - executionRange, [=](sycl::nd_item<1> item) { - sycl::sub_group sub_group = item.get_sub_group(); - - // If this item is not participating in the group, leave early. - if (!NonUniformGroupHelper::should_participate(sub_group, - test_case)) - return; - - GroupT non_uniform_group = - NonUniformGroupHelper::create(sub_group, test_case); - - size_t llid = non_uniform_group.get_local_linear_id(); - size_t max_id = non_uniform_group.get_local_linear_range() - 1; - - static_assert(std::is_same_v, - "Return type of group_barrier(GroupT g) is wrong\n"); - static_assert( - std::is_same_v, - "Return type of group_barrier(GroupT g, " - "memory_scope fence_scope) is wrong\n"); - - // test of default barrier - local_acc[llid] = llid; - sycl::group_barrier(non_uniform_group); - - if (local_acc[max_id - llid] != max_id - llid) - std::get(non_uniform_group_barriers_acc[0]) = false; - sycl::group_barrier(non_uniform_group); - - local_acc[llid] = 1; - sycl::group_barrier(non_uniform_group); - - if (local_acc[max_id - llid] != 1) - std::get(non_uniform_group_barriers_acc[0]) = false; - sycl::group_barrier(non_uniform_group); - - // tests for other barriers - for (int i = 1; i < non_uniform_group_barrier_variants; ++i) { - auto& barrier = non_uniform_group_barriers_acc[i]; - - if ((sub_group.get_group_linear_id() == 0) && - (non_uniform_group.get_group_linear_id() == - NonUniformGroupHelper< - GroupT>::preferred_single_worker_group_id(test_case)) && - std::get(barrier)) { - local_acc[llid] = llid; - global_acc[llid] = llid; - sycl::group_barrier(non_uniform_group); - - if (local_acc[max_id - llid] != max_id - llid || - global_acc[max_id - llid] != max_id - llid) - std::get(barrier) = false; - sycl::group_barrier(non_uniform_group); - - switch (std::get(barrier)) { - case sycl::memory_scope::sub_group: - case sycl::memory_scope::work_group: - local_acc[llid] = 1; - sycl::group_barrier(non_uniform_group, - std::get(barrier)); - - if (local_acc[max_id - llid] != 1) - std::get(barrier) = false; - sycl::group_barrier(non_uniform_group); - - [[fallthrough]]; - default: - global_acc[llid] = 1; - sycl::group_barrier(non_uniform_group, - std::get(barrier)); - - if (global_acc[max_id - llid] != 1) - std::get(barrier) = false; - sycl::group_barrier(non_uniform_group); + using el_type = int32_t; + sycl::device device = queue.get_device(); + + // Check the maximum number elements of type "el_type" that can be + // placed in the device's global and local memory. Since the test + // tries to allocate local and global buffers with a size equal to + // the work group size, the latter must be limited by the allowed + // buffer size. + uint64_t global_mem_size_in_bytes = + device.get_info(); + uint64_t global_mem_size_in_elements = + global_mem_size_in_bytes / sizeof(el_type); + + uint64_t local_mem_size_in_bytes = + device.get_info(); + uint64_t local_mem_size_in_elements = + local_mem_size_in_bytes / sizeof(el_type); + + uint64_t work_items_limit = + std::min(global_mem_size_in_elements, local_mem_size_in_elements); + + sycl::range<1> work_group_range = + sycl_cts::util::work_group_range<1>(queue, work_items_limit); + size_t work_group_size = work_group_range.size(); + + for (size_t test_case = 0; + test_case < NonUniformGroupHelper::num_test_cases; + ++test_case) { + const std::string test_case_name = + NonUniformGroupHelper::get_test_case_name(test_case); + INFO("Running test case (" + std::to_string(test_case) + ") with " + + test_case_name); + + std::vector v(work_group_size, 0); + sycl::buffer global_mem(v.data(), + sycl::range<1>(work_group_size)); + + sycl::buffer non_uniform_group_barriers_buf( + non_uniform_group_barriers.data(), + sycl::range<1>(non_uniform_group_barrier_variants)); + + queue.submit([&](sycl::handler& cgh) { + sycl::nd_range<1> executionRange(work_group_range, work_group_range); + + auto non_uniform_group_barriers_acc = + non_uniform_group_barriers_buf + .get_access(cgh); + + sycl::local_accessor local_acc( + sycl::range<1>(work_group_size), cgh); + sycl::accessor global_acc = + global_mem.get_access(cgh); + + cgh.parallel_for>( + executionRange, [=](sycl::nd_item<1> item) { + sycl::sub_group sub_group = item.get_sub_group(); + + // If this item is not participating in the group, leave early. + if (!NonUniformGroupHelper::should_participate(sub_group, + test_case)) + return; + + GroupT non_uniform_group = + NonUniformGroupHelper::create(sub_group, test_case); + + size_t llid = non_uniform_group.get_local_linear_id(); + size_t max_id = non_uniform_group.get_local_linear_range() - 1; + + static_assert( + std::is_same_v, + "Return type of group_barrier(GroupT g) is wrong\n"); + static_assert( + std::is_same_v, + "Return type of group_barrier(GroupT g, " + "memory_scope fence_scope) is wrong\n"); + + // test of default barrier + local_acc[llid] = llid; + sycl::group_barrier(non_uniform_group); + + if (local_acc[max_id - llid] != max_id - llid) + std::get(non_uniform_group_barriers_acc[0]) = false; + sycl::group_barrier(non_uniform_group); + + local_acc[llid] = 1; + sycl::group_barrier(non_uniform_group); + + if (local_acc[max_id - llid] != 1) + std::get(non_uniform_group_barriers_acc[0]) = false; + sycl::group_barrier(non_uniform_group); + + // tests for other barriers + for (int i = 1; i < non_uniform_group_barrier_variants; ++i) { + auto& barrier = non_uniform_group_barriers_acc[i]; + + if ((sub_group.get_group_linear_id() == 0) && + (non_uniform_group.get_group_linear_id() == + NonUniformGroupHelper:: + preferred_single_worker_group_id(test_case)) && + std::get(barrier)) { + local_acc[llid] = llid; + global_acc[llid] = llid; + sycl::group_barrier(non_uniform_group); + + if (local_acc[max_id - llid] != max_id - llid || + global_acc[max_id - llid] != max_id - llid) + std::get(barrier) = false; + sycl::group_barrier(non_uniform_group); + + switch (std::get(barrier)) { + case sycl::memory_scope::sub_group: + case sycl::memory_scope::work_group: + local_acc[llid] = 1; + sycl::group_barrier(non_uniform_group, + std::get(barrier)); + + if (local_acc[max_id - llid] != 1) + std::get(barrier) = false; + sycl::group_barrier(non_uniform_group); + + [[fallthrough]]; + default: + global_acc[llid] = 1; + sycl::group_barrier(non_uniform_group, + std::get(barrier)); + + if (global_acc[max_id - llid] != 1) + std::get(barrier) = false; + sycl::group_barrier(non_uniform_group); + } } } - } - }); - }); - - for (int i = 0; i < non_uniform_group_barrier_variants; ++i) { - bool result = std::get(non_uniform_group_barriers[i]); - std::string work_group = - sycl_cts::util::work_group_print(work_group_range); - CAPTURE(group_name, work_group); - INFO("Result of group_barrier invocation for sub-group and " - << non_uniform_group_barriers_names[i] << " memory scope is " - << (result ? "right" : "wrong")); - CHECK(result); + }); + }); + + for (int i = 0; i < non_uniform_group_barrier_variants; ++i) { + bool result = std::get(non_uniform_group_barriers[i]); + std::string work_group = + sycl_cts::util::work_group_print(work_group_range); + CAPTURE(group_name, work_group); + INFO("Result of group_barrier invocation for sub-group and " + << non_uniform_group_barriers_names[i] << " memory scope is " + << (result ? "right" : "wrong")); + CHECK(result); + } } } -} +}; diff --git a/tests/extension/oneapi_non_uniform_groups/group_broadcast.cpp b/tests/extension/oneapi_non_uniform_groups/group_broadcast.cpp index 410578a94..d7b0f58ac 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_broadcast.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_broadcast.cpp @@ -22,25 +22,13 @@ namespace non_uniform_groups::tests { -using BroadcastTypes = CustomTypes; - TEMPLATE_LIST_TEST_CASE("Non-uniform group broadcast and select", "[oneapi_non_uniform_groups][group_func][type_list]", - BroadcastTypes) { + GroupPackTypes) { auto queue = once_per_unit::get_queue(); - broadcast_non_uniform_group, - TestType>(queue); - broadcast_non_uniform_group, - TestType>(queue); - broadcast_non_uniform_group, - TestType>(queue); - broadcast_non_uniform_group, - TestType>(queue); - broadcast_non_uniform_group, - TestType>(queue); - broadcast_non_uniform_group, - TestType>(queue); - broadcast_non_uniform_group(queue); + + for_all_combinations( + TestType{}, CustomTypePack{}, queue); } } // namespace non_uniform_groups::tests diff --git a/tests/extension/oneapi_non_uniform_groups/group_broadcast.h b/tests/extension/oneapi_non_uniform_groups/group_broadcast.h index c68e142ab..b742d8012 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_broadcast.h +++ b/tests/extension/oneapi_non_uniform_groups/group_broadcast.h @@ -30,170 +30,177 @@ class broadcast_non_uniform_group_kernel; * @tparam T Type pointed by Ptr */ template -void broadcast_non_uniform_group(sycl::queue& queue) { - const std::string group_name = NonUniformGroupHelper::get_name(); +struct broadcast_non_uniform_group_test { + void operator()(sycl::queue& queue) { + const std::string group_name = NonUniformGroupHelper::get_name(); - INFO("Testing broadcast and select for " + group_name); - if (!NonUniformGroupHelper::is_supported(queue.get_device())) { - SKIP("Device does not support " + group_name); - } - - // 4 functions - constexpr int test_matrix = 4; - const std::string test_names[test_matrix] = { - "T group_broadcast(GroupT g, T x)", - "T group_broadcast(GroupT g, T x, GroupT::linear_id_type " - "local_linear_id)", - "T group_broadcast(GroupT g, T x, GroupT::id_type local_id)", - "T select_from_group(GroupT g, T x, GroupT::id_type local_id)"}; - - sycl::range<1> work_group_range = sycl_cts::util::work_group_range<1>(queue); - - for (size_t test_case = 0; - test_case < NonUniformGroupHelper::num_test_cases; ++test_case) { - const std::string test_case_name = - NonUniformGroupHelper::get_test_case_name(test_case); - INFO("Running test case (" + std::to_string(test_case) + ") with " + - test_case_name); - // array to return results - T origin_values[test_matrix] = {splat_init(0)}; - T broadcasted_values[test_matrix] = {splat_init(0)}; - { - sycl::buffer origin_values_buf(origin_values, - sycl::range<1>(test_matrix)); - sycl::buffer broadcasted_values_buf(broadcasted_values, - sycl::range<1>(test_matrix)); - - queue.submit([&](sycl::handler& cgh) { - auto origin_values_acc = - origin_values_buf - .template get_access(cgh); - auto broadcasted_values_acc = - broadcasted_values_buf - .template get_access(cgh); - - sycl::nd_range<1> executionRange(work_group_range, work_group_range); - // Values computed in a kernel depend on global linear id. We need to - // make sure that there are no overflows - REQUIRE(executionRange.get_global_range().size() < - std::numeric_limits::max() / 100); - - cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { - sycl::sub_group sub_group = item.get_sub_group(); - - // If this item is not participating in the group, leave early. - if (!NonUniformGroupHelper::should_participate(sub_group, - test_case)) - return; - - GroupT non_uniform_group = - NonUniformGroupHelper::create(sub_group, test_case); - - // Each work-item computes a unique value - T value_to_broadcast(splat_init( - static_cast(item.get_global_linear_id() * 100 + - non_uniform_group.get_local_id()))); - - // To simplify the test, we are only checking the first group in - // the first sub-group. - size_t preferred_group_id = - NonUniformGroupHelper::preferred_single_worker_group_id( - test_case); - if (item.get_sub_group().get_group_id()[0] == 0 && - non_uniform_group.get_group_id()[0] == preferred_group_id) { - // Find local id of first, last and some third sub-group item in - // between. Will be used to check different combinations of - // broadcasting and receiving work-items - sycl::id<1> first_id = 0; - sycl::id<1> mid_id = non_uniform_group.get_local_range() / 2; - sycl::id<1> last_id = non_uniform_group.get_local_range() - 1; - - // Broadcast from the first work-item - static_assert( - std::is_same_v, - "Return type of group_broadcast(GroupT g, T x) is wrong\n"); - - if (non_uniform_group.leader()) { - // Work-item which does the broadcast stores value to - // broadcast to use it later as a reference - origin_values_acc[0] = value_to_broadcast; - } - auto broadcasted_value = - sycl::group_broadcast(non_uniform_group, value_to_broadcast); - // We read broadcasted value in another work-item - if (non_uniform_group.get_local_id() == last_id) - broadcasted_values_acc[0] = broadcasted_value; - - // Broadcast from the last work-item - static_assert(std::is_same_v, - "Return type of group_broadcast(GroupT g, T x, " - "GroupT::linear_id_type local_linear_id) is wrong\n"); - - if (non_uniform_group.get_local_id() == last_id) { - // Work-item which does the broadcast stores value to - // broadcast to use it later as a reference - origin_values_acc[1] = value_to_broadcast; - } + INFO("Testing broadcast and select for " + group_name); + if (!NonUniformGroupHelper::is_supported(queue.get_device())) { + SKIP("Device does not support " + group_name); + } - broadcasted_value = sycl::group_broadcast( - non_uniform_group, value_to_broadcast, - non_uniform_group.get_local_linear_range() - 1); - // We read broadcasted value in another work-item - if (non_uniform_group.get_local_id() == mid_id) - broadcasted_values_acc[1] = broadcasted_value; - - // Broadcast from a mid work-item - static_assert(std::is_same_v, - "Return type of group_broadcast(GroupT g, T x, " - "GroupT::id_type local_id) is wrong\n"); - - if (non_uniform_group.get_local_id() == mid_id) { - // Work-item which does the broadcast stores value to - // broadcast to use it later as a reference - origin_values_acc[2] = value_to_broadcast; - } - broadcasted_value = sycl::group_broadcast( - non_uniform_group, value_to_broadcast, mid_id); - // We read broadcasted value in another work-item - if (non_uniform_group.get_local_id() == first_id) - broadcasted_values_acc[2] = broadcasted_value; - - // Select from the first work-item - static_assert(std::is_same_v, - "Return type of select_from_group(GroupT g, T x, " - "GroupT::id_type local_id) is wrong\n"); - - if (non_uniform_group.get_local_id() == first_id) { - // Work-item which does the broadcast stores value to - // broadcast to use it later as a reference - origin_values_acc[3] = value_to_broadcast; + // 4 functions + constexpr int test_matrix = 4; + const std::string test_names[test_matrix] = { + "T group_broadcast(GroupT g, T x)", + "T group_broadcast(GroupT g, T x, GroupT::linear_id_type " + "local_linear_id)", + "T group_broadcast(GroupT g, T x, GroupT::id_type local_id)", + "T select_from_group(GroupT g, T x, GroupT::id_type local_id)"}; + + sycl::range<1> work_group_range = + sycl_cts::util::work_group_range<1>(queue); + + for (size_t test_case = 0; + test_case < NonUniformGroupHelper::num_test_cases; + ++test_case) { + const std::string test_case_name = + NonUniformGroupHelper::get_test_case_name(test_case); + INFO("Running test case (" + std::to_string(test_case) + ") with " + + test_case_name); + // array to return results + T origin_values[test_matrix] = {splat_init(0)}; + T broadcasted_values[test_matrix] = {splat_init(0)}; + { + sycl::buffer origin_values_buf(origin_values, + sycl::range<1>(test_matrix)); + sycl::buffer broadcasted_values_buf(broadcasted_values, + sycl::range<1>(test_matrix)); + + queue.submit([&](sycl::handler& cgh) { + auto origin_values_acc = + origin_values_buf + .template get_access(cgh); + auto broadcasted_values_acc = + broadcasted_values_buf + .template get_access(cgh); + + sycl::nd_range<1> executionRange(work_group_range, work_group_range); + // Values computed in a kernel depend on global linear id. We need to + // make sure that there are no overflows + REQUIRE(executionRange.get_global_range().size() < + std::numeric_limits::max() / 100); + + cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { + sycl::sub_group sub_group = item.get_sub_group(); + + // If this item is not participating in the group, leave early. + if (!NonUniformGroupHelper::should_participate(sub_group, + test_case)) + return; + + GroupT non_uniform_group = + NonUniformGroupHelper::create(sub_group, test_case); + + // Each work-item computes a unique value + T value_to_broadcast(splat_init( + static_cast(item.get_global_linear_id() * 100 + + non_uniform_group.get_local_id()))); + + // To simplify the test, we are only checking the first group in + // the first sub-group. + size_t preferred_group_id = + NonUniformGroupHelper::preferred_single_worker_group_id( + test_case); + if (item.get_sub_group().get_group_id()[0] == 0 && + non_uniform_group.get_group_id()[0] == preferred_group_id) { + // Find local id of first, last and some third sub-group item in + // between. Will be used to check different combinations of + // broadcasting and receiving work-items + sycl::id<1> first_id = 0; + sycl::id<1> mid_id = non_uniform_group.get_local_range() / 2; + sycl::id<1> last_id = non_uniform_group.get_local_range() - 1; + + // Broadcast from the first work-item + static_assert( + std::is_same_v, + "Return type of group_broadcast(GroupT g, T x) is wrong\n"); + + if (non_uniform_group.leader()) { + // Work-item which does the broadcast stores value to + // broadcast to use it later as a reference + origin_values_acc[0] = value_to_broadcast; + } + auto broadcasted_value = + sycl::group_broadcast(non_uniform_group, value_to_broadcast); + // We read broadcasted value in another work-item + if (non_uniform_group.get_local_id() == last_id) + broadcasted_values_acc[0] = broadcasted_value; + + // Broadcast from the last work-item + static_assert( + std::is_same_v, + "Return type of group_broadcast(GroupT g, T x, " + "GroupT::linear_id_type local_linear_id) is wrong\n"); + + if (non_uniform_group.get_local_id() == last_id) { + // Work-item which does the broadcast stores value to + // broadcast to use it later as a reference + origin_values_acc[1] = value_to_broadcast; + } + + broadcasted_value = sycl::group_broadcast( + non_uniform_group, value_to_broadcast, + non_uniform_group.get_local_linear_range() - 1); + // We read broadcasted value in another work-item + if (non_uniform_group.get_local_id() == mid_id) + broadcasted_values_acc[1] = broadcasted_value; + + // Broadcast from a mid work-item + static_assert(std::is_same_v, + "Return type of group_broadcast(GroupT g, T x, " + "GroupT::id_type local_id) is wrong\n"); + + if (non_uniform_group.get_local_id() == mid_id) { + // Work-item which does the broadcast stores value to + // broadcast to use it later as a reference + origin_values_acc[2] = value_to_broadcast; + } + broadcasted_value = sycl::group_broadcast( + non_uniform_group, value_to_broadcast, mid_id); + // We read broadcasted value in another work-item + if (non_uniform_group.get_local_id() == first_id) + broadcasted_values_acc[2] = broadcasted_value; + + // Select from the first work-item + static_assert( + std::is_same_v, + "Return type of select_from_group(GroupT g, T x, " + "GroupT::id_type local_id) is wrong\n"); + + if (non_uniform_group.get_local_id() == first_id) { + // Work-item which does the broadcast stores value to + // broadcast to use it later as a reference + origin_values_acc[3] = value_to_broadcast; + } + broadcasted_value = sycl::select_from_group( + non_uniform_group, value_to_broadcast, first_id); + // We read broadcasted value in another work-item + if (non_uniform_group.get_local_id() == mid_id) + broadcasted_values_acc[3] = broadcasted_value; } - broadcasted_value = sycl::select_from_group( - non_uniform_group, value_to_broadcast, first_id); - // We read broadcasted value in another work-item - if (non_uniform_group.get_local_id() == mid_id) - broadcasted_values_acc[3] = broadcasted_value; - } + }); }); - }); - } - for (int i = 0; i < test_matrix; ++i) { - std::string work_group = - sycl_cts::util::work_group_print(work_group_range); - CAPTURE(group_name, work_group); - INFO("Return value of " - << test_names[i] << " with T = " << type_name() << " is " - << (equal(broadcasted_values[i], origin_values[i]) ? "right" - : "wrong")); - CHECK(equal(broadcasted_values[i], origin_values[i])); + } + for (int i = 0; i < test_matrix; ++i) { + std::string work_group = + sycl_cts::util::work_group_print(work_group_range); + CAPTURE(group_name, work_group); + INFO("Return value of " + << test_names[i] << " with T = " << type_name() << " is " + << (equal(broadcasted_values[i], origin_values[i]) ? "right" + : "wrong")); + CHECK(equal(broadcasted_values[i], origin_values[i])); + } } } -} +}; diff --git a/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp16.cpp b/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp16.cpp index ecf78db75..f28162274 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp16.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp16.cpp @@ -22,24 +22,14 @@ namespace non_uniform_groups::tests { -TEST_CASE("Non-uniform group broadcast and select", - "[oneapi_non_uniform_groups][group_func][fp16]") { +TEMPLATE_LIST_TEST_CASE("Non-uniform group broadcast and select", + "[oneapi_non_uniform_groups][group_func][fp16]", + GroupPackTypes) { auto queue = once_per_unit::get_queue(); + if (queue.get_device().has(sycl::aspect::fp16)) { - broadcast_non_uniform_group, - sycl::half>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<1, sycl::sub_group>, sycl::half>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<2, sycl::sub_group>, sycl::half>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<4, sycl::sub_group>, sycl::half>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<8, sycl::sub_group>, sycl::half>(queue); - broadcast_non_uniform_group, - sycl::half>(queue); - broadcast_non_uniform_group( - queue); + for_all_combinations( + TestType{}, unnamed_type_pack{}, queue); } else { WARN("Device does not support half precision floating point operations."); } diff --git a/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp64.cpp b/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp64.cpp index 1a55529e6..355f7116c 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp64.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_broadcast_fp64.cpp @@ -22,22 +22,13 @@ namespace non_uniform_groups::tests { -TEST_CASE("Non-uniform group broadcast and select", "[group_func][fp64]") { +TEMPLATE_LIST_TEST_CASE("Non-uniform group broadcast and select", + "[group_func][fp64]", GroupPackTypes) { auto queue = once_per_unit::get_queue(); + if (queue.get_device().has(sycl::aspect::fp64)) { - broadcast_non_uniform_group, - double>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<1, sycl::sub_group>, double>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<2, sycl::sub_group>, double>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<4, sycl::sub_group>, double>(queue); - broadcast_non_uniform_group< - oneapi_ext::fixed_size_group<8, sycl::sub_group>, double>(queue); - broadcast_non_uniform_group, - double>(queue); - broadcast_non_uniform_group(queue); + for_all_combinations( + TestType{}, unnamed_type_pack{}, queue); } else { WARN("Device does not support double precision floating point operations."); } diff --git a/tests/extension/oneapi_non_uniform_groups/group_joint_reduce.cpp.in b/tests/extension/oneapi_non_uniform_groups/group_joint_reduce.cpp.in index c8ff919dd..781f12986 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_joint_reduce.cpp.in +++ b/tests/extension/oneapi_non_uniform_groups/group_joint_reduce.cpp.in @@ -28,19 +28,12 @@ namespace non_uniform_groups::tests { // clang-format on using ReduceTypes = Types; -TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint reduce functions", - "[oneapi_non_uniform_groups][group_func][type_list]") { +TEMPLATE_LIST_TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint reduce functions", + "[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes) { auto queue = once_per_unit::get_queue(); const auto Operators = get_op_types(); const auto RetType = unnamed_type_pack(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); + const auto GroupTypes = TestType{}; if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) @@ -60,20 +53,13 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint reduce functions", TEMPLATE_LIST_TEST_CASE( CTS_TYPE_NAME + " non-uniform group joint reduce functions with init", - "[oneapi_non_uniform_groups][group_func][type_list]", ReduceTypes) { + "[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes) { auto queue = once_per_unit::get_queue(); const auto Operators = get_op_types(); const auto RetType = unnamed_type_pack(); - const auto ReducedType = unnamed_type_pack(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); + const auto ReducedType = Types{}; + const auto GroupTypes = TestType{}; if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) diff --git a/tests/extension/oneapi_non_uniform_groups/group_joint_scan.cpp.in b/tests/extension/oneapi_non_uniform_groups/group_joint_scan.cpp.in index c5351edf8..763840ecc 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_joint_scan.cpp.in +++ b/tests/extension/oneapi_non_uniform_groups/group_joint_scan.cpp.in @@ -28,21 +28,14 @@ namespace non_uniform_groups::tests { -using TestType = unnamed_type_pack; +using CurrentType = unnamed_type_pack; using ScanTypes = Types; #endif // !SYCL_CTS_COMPILING_WITH_HIPSYCL -TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions", - "[oneapi_non_uniform_groups][group_func][type_list]"){ +TEMPLATE_LIST_TEST_CASE( + CTS_TYPE_NAME + " non-uniform group joint scan functions", + "[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes){ auto queue = once_per_unit::get_queue(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) @@ -56,21 +49,14 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions", "operations."); } - for_all_combinations(GroupTypes, TestType{}, + for_all_combinations(TestType{}, CurrentType{}, ScanTypes{}, queue); }; -TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions with init", - "[oneapi_non_uniform_groups][group_func][type_list]"){ +TEMPLATE_LIST_TEST_CASE( + CTS_TYPE_NAME + " non-uniform group joint scan functions with init", + "[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes){ auto queue = once_per_unit::get_queue(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) @@ -85,7 +71,7 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions with init", } for_all_combinations( - GroupTypes, TestType{}, ScanTypes{}, ScanTypes{}, queue); + TestType{}, CurrentType{}, ScanTypes{}, ScanTypes{}, queue); }; } // namespace non_uniform_groups::tests diff --git a/tests/extension/oneapi_non_uniform_groups/group_of.cpp b/tests/extension/oneapi_non_uniform_groups/group_of.cpp index 17787b09c..78536e5d0 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_of.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_of.cpp @@ -23,61 +23,27 @@ namespace non_uniform_groups::tests { // use wide types to exclude truncation of init values -using WideTypes = std::tuple; +static const auto wide_types = + named_type_pack::generate( + "int32_t", "uint32_t", "int64_t", "uint64_t", "float"); TEMPLATE_LIST_TEST_CASE("Non-uniform group joint of bool functions", "[oneapi_non_uniform_groups][group_func][type_list]", - WideTypes) { - auto queue = once_per_unit::get_queue(); - joint_of_group, TestType>(queue); - joint_of_group, TestType>( - queue); - joint_of_group, TestType>( - queue); - joint_of_group, TestType>( - queue); - joint_of_group, TestType>( - queue); - joint_of_group, TestType>(queue); - joint_of_group(queue); + GroupPackTypes) { + for_all_combinations(TestType{}, wide_types); } TEMPLATE_LIST_TEST_CASE( "Non-uniform group of bool functions with predicate functions", - "[oneapi_non_uniform_groups][group_func][type_list]", WideTypes) { - auto queue = once_per_unit::get_queue(); - predicate_function_of_non_uniform_group< - oneapi_ext::ballot_group, TestType>(queue); - predicate_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<1, sycl::sub_group>, TestType>(queue); - predicate_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<2, sycl::sub_group>, TestType>(queue); - predicate_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<4, sycl::sub_group>, TestType>(queue); - predicate_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<8, sycl::sub_group>, TestType>(queue); - predicate_function_of_non_uniform_group< - oneapi_ext::tangle_group, TestType>(queue); - predicate_function_of_non_uniform_group(queue); + "[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes) { + for_all_combinations( + TestType{}, wide_types); } -TEST_CASE("Non-uniform group of bool functions", - "[oneapi_non_uniform_groups][group_func]") { - auto queue = once_per_unit::get_queue(); - bool_function_of_non_uniform_group>( - queue); - bool_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<1, sycl::sub_group>>(queue); - bool_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<2, sycl::sub_group>>(queue); - bool_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<4, sycl::sub_group>>(queue); - bool_function_of_non_uniform_group< - oneapi_ext::fixed_size_group<8, sycl::sub_group>>(queue); - bool_function_of_non_uniform_group>( - queue); - bool_function_of_non_uniform_group(queue); +TEMPLATE_LIST_TEST_CASE("Non-uniform group of bool functions", + "[oneapi_non_uniform_groups][group_func]", + GroupPackTypes) { + for_all_combinations(TestType{}); } } // namespace non_uniform_groups::tests diff --git a/tests/extension/oneapi_non_uniform_groups/group_of.h b/tests/extension/oneapi_non_uniform_groups/group_of.h index 3bd5fcb0a..de73fac36 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_of.h +++ b/tests/extension/oneapi_non_uniform_groups/group_of.h @@ -31,127 +31,290 @@ class joint_of_group_kernel; * @tparam T Type pointed by Ptr */ template -void joint_of_group(sycl::queue& queue) { - const std::string group_name = NonUniformGroupHelper::get_name(); +struct joint_of_group_test { + void operator()(const std::string& type_name) { + auto queue = once_per_unit::get_queue(); - INFO("Testing group-of predicate function for " + group_name); - if (!NonUniformGroupHelper::is_supported(queue.get_device())) { - SKIP("Device does not support " + group_name); + const std::string group_name = NonUniformGroupHelper::get_name(); + + INFO("Testing group-of predicate function for " + group_name); + if (!NonUniformGroupHelper::is_supported(queue.get_device())) { + SKIP("Device does not support " + group_name); + } + + // 3 functions * 4 predicates + constexpr int test_matrix = 3; + const std::string test_names[test_matrix] = { + "bool joint_any_of(GroupT g, Ptr first, Ptr last, Predicate pred)", + "bool joint_all_of(GroupT g, Ptr first, Ptr last, Predicate pred)", + "bool joint_none_of(GroupT g, Ptr first, Ptr last, Predicate pred)"}; + constexpr int test_cases = 4; + const std::string test_cases_names[test_cases] = {"none true", "one true", + "some true", "all true"}; + + sycl::range<1> work_group_range = + sycl_cts::util::work_group_range<1>(queue); + size_t work_group_size = work_group_range.size(); + + const size_t sizes[3] = {5, work_group_size / 2, 3 * work_group_size}; + for (size_t test_case = 0; + test_case < NonUniformGroupHelper::num_test_cases; + ++test_case) { + const std::string test_case_name = + NonUniformGroupHelper::get_test_case_name(test_case); + INFO("Running test case (" + std::to_string(test_case) + ") with " + + test_case_name); + + for (size_t size : sizes) { + std::vector v(size); + std::iota(v.begin(), v.end(), 1); + + sycl::buffer v_sycl(v.data(), sycl::range<1>(size)); + + sycl::buffer res_sycl( + sycl::range<2>(work_group_size, test_matrix * test_cases)); + + queue.submit([&](sycl::handler& cgh) { + auto v_acc = + v_sycl.template get_access(cgh); + auto res_acc = + res_sycl.get_access(cgh); + + sycl::nd_range<1> executionRange(work_group_range, work_group_range); + + cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { + size_t gid = item.get_global_linear_id(); + sycl::sub_group sub_group = item.get_sub_group(); + + // If this item is not participating in the group, leave early. + if (!NonUniformGroupHelper::should_participate(sub_group, + test_case)) { + // If an item is not participating, its results are trivially + // correct. + for (unsigned i = 0; i < test_matrix * test_cases; ++i) + res_acc[gid][i] = true; + return; + } + + GroupT non_uniform_group = + NonUniformGroupHelper::create(sub_group, test_case); + + T* v_begin = v_acc.get_pointer(); + T* v_end = v_begin + v_acc.size(); + + // predicates + auto none_true = [&](T i) { return i == 0; }; + auto one_true = [&](T i) { return i == 1; }; + auto some_true = [&](T i) { return i > size / 2; }; + auto all_true = [&](T i) { return i <= size; }; + + static_assert( + std::is_same_v, + "Return type of joint_any_of(GroupT g, Ptr first, Ptr last, " + "Predicate pred) is wrong\n"); + res_acc[gid][0] = !sycl::joint_any_of(non_uniform_group, v_begin, + v_end, none_true); + res_acc[gid][1] = + sycl::joint_any_of(non_uniform_group, v_begin, v_end, one_true); + res_acc[gid][2] = sycl::joint_any_of(non_uniform_group, v_begin, + v_end, some_true); + res_acc[gid][3] = + sycl::joint_any_of(non_uniform_group, v_begin, v_end, all_true); + + static_assert( + std::is_same_v, + "Return type of joint_all_of(GroupT g, Ptr first, Ptr last, " + "Predicate pred) is wrong\n"); + res_acc[gid][4] = !sycl::joint_all_of(non_uniform_group, v_begin, + v_end, none_true); + res_acc[gid][5] = !sycl::joint_all_of(non_uniform_group, v_begin, + v_end, one_true); + res_acc[gid][6] = !sycl::joint_all_of(non_uniform_group, v_begin, + v_end, some_true); + res_acc[gid][7] = + sycl::joint_all_of(non_uniform_group, v_begin, v_end, all_true); + + static_assert( + std::is_same_v, + "Return type of joint_none_of(GroupT g, Ptr first, Ptr last, " + "Predicate pred) is wrong\n"); + res_acc[gid][8] = sycl::joint_none_of(non_uniform_group, v_begin, + v_end, none_true); + res_acc[gid][9] = !sycl::joint_none_of(non_uniform_group, v_begin, + v_end, one_true); + res_acc[gid][10] = !sycl::joint_none_of(non_uniform_group, v_begin, + v_end, some_true); + res_acc[gid][11] = !sycl::joint_none_of(non_uniform_group, v_begin, + v_end, all_true); + }); + }); + { + sycl::host_accessor res_host{res_sycl}; + for (size_t gid = 0; gid < work_group_size; ++gid) { + int index = 0; + for (int i = 0; i < test_matrix; ++i) + for (int j = 0; j < test_cases; ++j) { + std::string work_group = + sycl_cts::util::work_group_print(work_group_range); + CAPTURE(group_name, work_group); + INFO("Value of " << test_names[i] << " with " + << test_cases_names[j] << " for item " << gid + << " predicate is " + << (res_host[gid][index] ? "right" : "wrong")); + CHECK(res_host[gid][index++]); + } + } + } + } + } } +}; - // 3 functions * 4 predicates - constexpr int test_matrix = 3; - const std::string test_names[test_matrix] = { - "bool joint_any_of(GroupT g, Ptr first, Ptr last, Predicate pred)", - "bool joint_all_of(GroupT g, Ptr first, Ptr last, Predicate pred)", - "bool joint_none_of(GroupT g, Ptr first, Ptr last, Predicate pred)"}; - constexpr int test_cases = 4; - const std::string test_cases_names[test_cases] = {"none true", "one true", - "some true", "all true"}; - - sycl::range<1> work_group_range = sycl_cts::util::work_group_range<1>(queue); - size_t work_group_size = work_group_range.size(); - - const size_t sizes[3] = {5, work_group_size / 2, 3 * work_group_size}; - for (size_t test_case = 0; - test_case < NonUniformGroupHelper::num_test_cases; ++test_case) { - const std::string test_case_name = - NonUniformGroupHelper::get_test_case_name(test_case); - INFO("Running test case (" + std::to_string(test_case) + ") with " + - test_case_name); - - for (size_t size : sizes) { - std::vector v(size); - std::iota(v.begin(), v.end(), 1); - - sycl::buffer v_sycl(v.data(), sycl::range<1>(size)); +template +class predicate_function_of_non_uniform_group_kernel; +/** + * @brief Provides test for arbitraty non-uniform group bool of operations with + * predicate functions + * @tparam GroupT Type of the non-uniform group to test with + * @tparam T Type pointed by Ptr + */ +template +struct predicate_function_of_non_uniform_group_test { + void operator()(const std::string& type_name) { + auto queue = once_per_unit::get_queue(); + + const std::string group_name = NonUniformGroupHelper::get_name(); + + // 3 functions * 4 predicates + constexpr int test_matrix = 3; + const std::string test_names[test_matrix] = { + "bool any_of_group(GroupT g, T x, Predicate pred)", + "bool all_of_group(GroupT g, T x, Predicate pred)", + "bool none_of_group(GroupT g, T x, Predicate pred)"}; + constexpr int test_cases = 4; + const std::string test_cases_names[test_cases] = {"none true", "one true", + "some true", "all true"}; + + sycl::range<1> work_group_range = + sycl_cts::util::work_group_range<1>(queue); + + for (size_t test_case = 0; + test_case < NonUniformGroupHelper::num_test_cases; + ++test_case) { + const std::string test_case_name = + NonUniformGroupHelper::get_test_case_name(test_case); + INFO("Running test case (" + std::to_string(test_case) + ") with " + + test_case_name); + + // test cases: 4 predicates * 3 functions + constexpr int total_case_count = test_matrix * test_cases; sycl::buffer res_sycl( - sycl::range<2>(work_group_size, test_matrix * test_cases)); + sycl::range<2>(work_group_range.size(), total_case_count)); queue.submit([&](sycl::handler& cgh) { - auto v_acc = - v_sycl.template get_access(cgh); auto res_acc = res_sycl.get_access(cgh); sycl::nd_range<1> executionRange(work_group_range, work_group_range); - cgh.parallel_for>( - executionRange, [=](sycl::nd_item<1> item) { - size_t gid = item.get_global_linear_id(); - sycl::sub_group sub_group = item.get_sub_group(); - - // If this item is not participating in the group, leave early. - if (!NonUniformGroupHelper::should_participate( - sub_group, test_case)) { - // If an item is not participating, its results are trivially - // correct. - for (unsigned i = 0; i < test_matrix * test_cases; ++i) - res_acc[gid][i] = true; - return; - } - - GroupT non_uniform_group = - NonUniformGroupHelper::create(sub_group, test_case); - - T* v_begin = v_acc.get_pointer(); - T* v_end = v_begin + v_acc.size(); - - // predicates - auto none_true = [&](T i) { return i == 0; }; - auto one_true = [&](T i) { return i == 1; }; - auto some_true = [&](T i) { return i > size / 2; }; - auto all_true = [&](T i) { return i <= size; }; - - static_assert( - std::is_same_v, - "Return type of joint_any_of(GroupT g, Ptr first, Ptr last, " - "Predicate pred) is wrong\n"); - res_acc[gid][0] = !sycl::joint_any_of(non_uniform_group, v_begin, - v_end, none_true); - res_acc[gid][1] = sycl::joint_any_of(non_uniform_group, v_begin, - v_end, one_true); - res_acc[gid][2] = sycl::joint_any_of(non_uniform_group, v_begin, - v_end, some_true); - res_acc[gid][3] = sycl::joint_any_of(non_uniform_group, v_begin, - v_end, all_true); - - static_assert( - std::is_same_v, - "Return type of joint_all_of(GroupT g, Ptr first, Ptr last, " - "Predicate pred) is wrong\n"); - res_acc[gid][4] = !sycl::joint_all_of(non_uniform_group, v_begin, - v_end, none_true); - res_acc[gid][5] = !sycl::joint_all_of(non_uniform_group, v_begin, - v_end, one_true); - res_acc[gid][6] = !sycl::joint_all_of(non_uniform_group, v_begin, - v_end, some_true); - res_acc[gid][7] = sycl::joint_all_of(non_uniform_group, v_begin, - v_end, all_true); - - static_assert( - std::is_same_v, - "Return type of joint_none_of(GroupT g, Ptr first, Ptr last, " - "Predicate pred) is wrong\n"); - res_acc[gid][8] = sycl::joint_none_of(non_uniform_group, v_begin, - v_end, none_true); - res_acc[gid][9] = !sycl::joint_none_of(non_uniform_group, v_begin, - v_end, one_true); - res_acc[gid][10] = !sycl::joint_none_of( - non_uniform_group, v_begin, v_end, some_true); - res_acc[gid][11] = !sycl::joint_none_of(non_uniform_group, - v_begin, v_end, all_true); - }); + cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { + size_t gid = item.get_global_linear_id(); + sycl::sub_group sub_group = item.get_sub_group(); + + // If this item is not participating in the group, leave early. + if (!NonUniformGroupHelper::should_participate(sub_group, + test_case)) + return; + + GroupT non_uniform_group = + NonUniformGroupHelper::create(sub_group, test_case); + + size_t size = non_uniform_group.get_local_linear_range(); + + // Use the non-uniform group local ID (plus 1) as a variable against + // which to test our predicates. Note that this has a well-defined set + // of values [1,2,...,N] where N is the non-uniform group size. Note + // that the non-uniform group could also just be of size 1. + T local_var(non_uniform_group.get_local_linear_id() + 1); + + // predicates + // The variable is never 1 for any member of the non-uniform group + auto none_true = [&](T i) { return i == 0; }; + // Exactly one member of the non-uniform group has value 1 (the first) + auto one_true = [&](T i) { return i == 1; }; + // Some (or all, for non-uniform groups of size 1) members of the + // non-uniform group have this value + auto some_true = [&](T i) { return i > size / 2; }; + // The variable is less than or equal to the non-uniform group size + // for all members of the non-uniform group. + auto all_true = [&](T i) { return i <= size; }; + + { + static_assert( + std::is_same_v, + "Return type of any_of_group(GroupT g, bool pred) is wrong\n"); + res_acc[gid][0] = + !sycl::any_of_group(non_uniform_group, local_var, none_true); + res_acc[gid][1] = + sycl::any_of_group(non_uniform_group, local_var, one_true); + res_acc[gid][2] = + sycl::any_of_group(non_uniform_group, local_var, some_true); + res_acc[gid][3] = + sycl::any_of_group(non_uniform_group, local_var, all_true); + + static_assert( + std::is_same_v, + "Return type of all_of_group(GroupT g, bool pred) is wrong\n"); + res_acc[gid][4] = + !sycl::all_of_group(non_uniform_group, local_var, none_true); + // Note that 'one_true' returns true for the first item. Thus in the + // case that the non-uniform group size is 1, check that all items + // match; otherwise check that not all items match. + res_acc[gid][5] = + sycl::all_of_group(non_uniform_group, local_var, one_true) ^ + (size != 1); + // Note that 'some_true' returns true for the first item if the + // non-uniform group size is 1. In that case, check that all items + // match; otherwise check that not all items match. + res_acc[gid][6] = + sycl::all_of_group(non_uniform_group, local_var, some_true) ^ + (size != 1); + res_acc[gid][7] = + sycl::all_of_group(non_uniform_group, local_var, all_true); + + static_assert( + std::is_same_v, + "Return type of none_of_group(GroupT g, bool pred) is " + "wrong\n"); + res_acc[gid][8] = + sycl::none_of_group(non_uniform_group, local_var, none_true); + res_acc[gid][9] = + !sycl::none_of_group(non_uniform_group, local_var, one_true); + res_acc[gid][10] = + !sycl::none_of_group(non_uniform_group, local_var, some_true); + res_acc[gid][11] = + !sycl::none_of_group(non_uniform_group, local_var, all_true); + } + }); }); + { sycl::host_accessor res_host{res_sycl}; - for (size_t gid = 0; gid < work_group_size; ++gid) { + for (size_t gid = 0; gid < work_group_range.size(); ++gid) { int index = 0; for (int i = 0; i < test_matrix; ++i) for (int j = 0; j < test_cases; ++j) { @@ -168,160 +331,7 @@ void joint_of_group(sycl::queue& queue) { } } } -} - -template -class predicate_function_of_non_uniform_group_kernel; - -/** - * @brief Provides test for arbitraty non-uniform group bool of operations with - * predicate functions - * @tparam GroupT Type of the non-uniform group to test with - * @tparam T Type pointed by Ptr - */ -template -void predicate_function_of_non_uniform_group(sycl::queue& queue) { - const std::string group_name = NonUniformGroupHelper::get_name(); - - INFO("Testing group-of predicate function for " + group_name); - if (!NonUniformGroupHelper::is_supported(queue.get_device())) { - SKIP("Device does not support " + group_name); - } - - // 3 functions * 4 predicates - constexpr int test_matrix = 3; - const std::string test_names[test_matrix] = { - "bool any_of_group(GroupT g, T x, Predicate pred)", - "bool all_of_group(GroupT g, T x, Predicate pred)", - "bool none_of_group(GroupT g, T x, Predicate pred)"}; - constexpr int test_cases = 4; - const std::string test_cases_names[test_cases] = {"none true", "one true", - "some true", "all true"}; - - sycl::range<1> work_group_range = sycl_cts::util::work_group_range<1>(queue); - - for (size_t test_case = 0; - test_case < NonUniformGroupHelper::num_test_cases; ++test_case) { - const std::string test_case_name = - NonUniformGroupHelper::get_test_case_name(test_case); - INFO("Running test case (" + std::to_string(test_case) + ") with " + - test_case_name); - - // test cases: 4 predicates * 3 functions - constexpr int total_case_count = test_matrix * test_cases; - sycl::buffer res_sycl( - sycl::range<2>(work_group_range.size(), total_case_count)); - - queue.submit([&](sycl::handler& cgh) { - auto res_acc = res_sycl.get_access(cgh); - - sycl::nd_range<1> executionRange(work_group_range, work_group_range); - - cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { - size_t gid = item.get_global_linear_id(); - sycl::sub_group sub_group = item.get_sub_group(); - - // If this item is not participating in the group, leave early. - if (!NonUniformGroupHelper::should_participate(sub_group, - test_case)) - return; - - GroupT non_uniform_group = - NonUniformGroupHelper::create(sub_group, test_case); - - size_t size = non_uniform_group.get_local_linear_range(); - - // Use the non-uniform group local ID (plus 1) as a variable against - // which to test our predicates. Note that this has a well-defined set - // of values [1,2,...,N] where N is the non-uniform group size. Note - // that the non-uniform group could also just be of size 1. - T local_var(non_uniform_group.get_local_linear_id() + 1); - - // predicates - // The variable is never 1 for any member of the non-uniform group - auto none_true = [&](T i) { return i == 0; }; - // Exactly one member of the non-uniform group has value 1 (the first) - auto one_true = [&](T i) { return i == 1; }; - // Some (or all, for non-uniform groups of size 1) members of the - // non-uniform group have this value - auto some_true = [&](T i) { return i > size / 2; }; - // The variable is less than or equal to the non-uniform group size - // for all members of the non-uniform group. - auto all_true = [&](T i) { return i <= size; }; - - { - static_assert( - std::is_same_v, - "Return type of any_of_group(GroupT g, bool pred) is wrong\n"); - res_acc[gid][0] = - !sycl::any_of_group(non_uniform_group, local_var, none_true); - res_acc[gid][1] = - sycl::any_of_group(non_uniform_group, local_var, one_true); - res_acc[gid][2] = - sycl::any_of_group(non_uniform_group, local_var, some_true); - res_acc[gid][3] = - sycl::any_of_group(non_uniform_group, local_var, all_true); - - static_assert( - std::is_same_v, - "Return type of all_of_group(GroupT g, bool pred) is wrong\n"); - res_acc[gid][4] = - !sycl::all_of_group(non_uniform_group, local_var, none_true); - // Note that 'one_true' returns true for the first item. Thus in the - // case that the non-uniform group size is 1, check that all items - // match; otherwise check that not all items match. - res_acc[gid][5] = - sycl::all_of_group(non_uniform_group, local_var, one_true) ^ - (size != 1); - // Note that 'some_true' returns true for the first item if the - // non-uniform group size is 1. In that case, check that all items - // match; otherwise check that not all items match. - res_acc[gid][6] = - sycl::all_of_group(non_uniform_group, local_var, some_true) ^ - (size != 1); - res_acc[gid][7] = - sycl::all_of_group(non_uniform_group, local_var, all_true); - - static_assert(std::is_same_v, - "Return type of none_of_group(GroupT g, bool pred) is " - "wrong\n"); - res_acc[gid][8] = - sycl::none_of_group(non_uniform_group, local_var, none_true); - res_acc[gid][9] = - !sycl::none_of_group(non_uniform_group, local_var, one_true); - res_acc[gid][10] = - !sycl::none_of_group(non_uniform_group, local_var, some_true); - res_acc[gid][11] = - !sycl::none_of_group(non_uniform_group, local_var, all_true); - } - }); - }); - - { - sycl::host_accessor res_host{res_sycl}; - for (size_t gid = 0; gid < work_group_range.size(); ++gid) { - int index = 0; - for (int i = 0; i < test_matrix; ++i) - for (int j = 0; j < test_cases; ++j) { - std::string work_group = - sycl_cts::util::work_group_print(work_group_range); - CAPTURE(group_name, work_group); - INFO("Value of " << test_names[i] << " with " << test_cases_names[j] - << " for item " << gid << " predicate is " - << (res_host[gid][index] ? "right" : "wrong")); - CHECK(res_host[gid][index++]); - } - } - } - } -} +}; template class predicate_function_of_non_uniform_group_bool_kernel; @@ -331,147 +341,154 @@ class predicate_function_of_non_uniform_group_bool_kernel; * @tparam GroupT Type of the non-uniform group to test with */ template -void bool_function_of_non_uniform_group(sycl::queue& queue) { - const std::string group_name = NonUniformGroupHelper::get_name(); +struct bool_function_of_non_uniform_group_test { + void operator()() { + auto queue = once_per_unit::get_queue(); + const std::string group_name = NonUniformGroupHelper::get_name(); + + INFO("Testing group-of bool function for " + group_name); + if (!NonUniformGroupHelper::is_supported(queue.get_device())) { + SKIP("Device does not support " + group_name); + } - INFO("Testing group-of bool function for " + group_name); - if (!NonUniformGroupHelper::is_supported(queue.get_device())) { - SKIP("Device does not support " + group_name); - } + // 3 functions * 4 predicates + constexpr int test_matrix = 3; + const std::string test_names[test_matrix] = { + "bool any_of_group(GroupT g, bool pred)", + "bool all_of_group(GroupT g, bool pred)", + "bool none_of_group(GroupT g, bool pred)"}; + constexpr int test_cases = 4; + const std::string test_cases_names[test_cases] = {"none true", "one true", + "some true", "all true"}; + + using T = size_t; + + sycl::range<1> work_group_range = + sycl_cts::util::work_group_range<1>(queue); + + for (size_t test_case = 0; + test_case < NonUniformGroupHelper::num_test_cases; + ++test_case) { + const std::string test_case_name = + NonUniformGroupHelper::get_test_case_name(test_case); + INFO("Running test case (" + std::to_string(test_case) + ") with " + + test_case_name); + + // test cases: 4 predicates * 3 functions + constexpr int total_case_count = test_matrix * test_cases; + sycl::buffer res_sycl( + sycl::range<2>(work_group_range.size(), total_case_count)); - // 3 functions * 4 predicates - constexpr int test_matrix = 3; - const std::string test_names[test_matrix] = { - "bool any_of_group(GroupT g, bool pred)", - "bool all_of_group(GroupT g, bool pred)", - "bool none_of_group(GroupT g, bool pred)"}; - constexpr int test_cases = 4; - const std::string test_cases_names[test_cases] = {"none true", "one true", - "some true", "all true"}; - - using T = size_t; - - sycl::range<1> work_group_range = sycl_cts::util::work_group_range<1>(queue); - - for (size_t test_case = 0; - test_case < NonUniformGroupHelper::num_test_cases; ++test_case) { - const std::string test_case_name = - NonUniformGroupHelper::get_test_case_name(test_case); - INFO("Running test case (" + std::to_string(test_case) + ") with " + - test_case_name); - - // test cases: 4 predicates * 3 functions - constexpr int total_case_count = test_matrix * test_cases; - sycl::buffer res_sycl( - sycl::range<2>(work_group_range.size(), total_case_count)); - - queue.submit([&](sycl::handler& cgh) { - auto res_acc = res_sycl.get_access(cgh); - - sycl::nd_range<1> executionRange(work_group_range, work_group_range); - - cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { - size_t gid = item.get_global_linear_id(); - sycl::sub_group sub_group = item.get_sub_group(); - - // If this item is not participating in the group, leave early. - if (!NonUniformGroupHelper::should_participate(sub_group, - test_case)) - return; - - GroupT non_uniform_group = - NonUniformGroupHelper::create(sub_group, test_case); - - size_t size = non_uniform_group.get_local_linear_range(); - - // Use the non-uniform group local ID (plus 1) as a variable against - // which to test our predicates. Note that this has a well-defined set - // of values [1,2,...,N] where N is the non-uniform group size. Note - // that the non-uniform group could also just be of size 1. - T local_var(non_uniform_group.get_local_linear_id() + 1); - - // predicates - // The variable is never 1 for any member of the non-uniform group - auto none_true = [&](T i) { return i == 0; }; - // Exactly one member of the non-uniform group has value 1 (the first) - auto one_true = [&](T i) { return i == 1; }; - // Some (or all, for non-uniform groups of size 1) members of the - // non-uniform group have this value - auto some_true = [&](T i) { return i > size / 2; }; - // The variable is less than or equal to the non-uniform group size - // for all members of the non-uniform group. - auto all_true = [&](T i) { return i <= size; }; + queue.submit([&](sycl::handler& cgh) { + auto res_acc = res_sycl.get_access(cgh); - { - static_assert( - std::is_same_v, - "Return type of any_of_group(GroupT g, bool pred) is wrong\n"); - res_acc[gid][0] = - !sycl::any_of_group(non_uniform_group, none_true(local_var)); - res_acc[gid][1] = - sycl::any_of_group(non_uniform_group, one_true(local_var)); - res_acc[gid][2] = - sycl::any_of_group(non_uniform_group, some_true(local_var)); - res_acc[gid][3] = - sycl::any_of_group(non_uniform_group, all_true(local_var)); - - static_assert( - std::is_same_v, - "Return type of all_of_group(GroupT g, bool pred) is wrong\n"); - res_acc[gid][4] = - !sycl::all_of_group(non_uniform_group, none_true(local_var)); - // Note that 'one_true' returns true for the first item. Thus in the - // case that the non-uniform group size is 1, check that all items - // match; otherwise check that not all items match. - res_acc[gid][5] = - sycl::all_of_group(non_uniform_group, one_true(local_var)) ^ - (size != 1); - // Note that 'some_true' returns true for the first item if the - // non-uniform group size is 1. In that case, check that all items - // match; otherwise check that not all items match. - res_acc[gid][6] = - sycl::all_of_group(non_uniform_group, some_true(local_var)) ^ - (size != 1); - res_acc[gid][7] = - sycl::all_of_group(non_uniform_group, all_true(local_var)); - - static_assert(std::is_same_v, - "Return type of none_of_group(GroupT g, bool pred) is " - "wrong\n"); - res_acc[gid][8] = - sycl::none_of_group(non_uniform_group, none_true(local_var)); - res_acc[gid][9] = - !sycl::none_of_group(non_uniform_group, one_true(local_var)); - res_acc[gid][10] = - !sycl::none_of_group(non_uniform_group, some_true(local_var)); - res_acc[gid][11] = - !sycl::none_of_group(non_uniform_group, all_true(local_var)); - } - }); - }); - - { - sycl::host_accessor res_host{res_sycl}; - for (size_t gid = 0; gid < work_group_range.size(); ++gid) { - int index = 0; - for (int i = 0; i < test_matrix; ++i) - for (int j = 0; j < test_cases; ++j) { - std::string work_group = - sycl_cts::util::work_group_print(work_group_range); - CAPTURE(group_name, work_group); - INFO("Value of " << test_names[i] << " with " << test_cases_names[j] - << " for item " << gid << " predicate is " - << (res_host[gid][index] ? "right" : "wrong")); - CHECK(res_host[gid][index++]); + sycl::nd_range<1> executionRange(work_group_range, work_group_range); + + cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { + size_t gid = item.get_global_linear_id(); + sycl::sub_group sub_group = item.get_sub_group(); + + // If this item is not participating in the group, leave early. + if (!NonUniformGroupHelper::should_participate(sub_group, + test_case)) + return; + + GroupT non_uniform_group = + NonUniformGroupHelper::create(sub_group, test_case); + + size_t size = non_uniform_group.get_local_linear_range(); + + // Use the non-uniform group local ID (plus 1) as a variable against + // which to test our predicates. Note that this has a well-defined set + // of values [1,2,...,N] where N is the non-uniform group size. Note + // that the non-uniform group could also just be of size 1. + T local_var(non_uniform_group.get_local_linear_id() + 1); + + // predicates + // The variable is never 1 for any member of the non-uniform group + auto none_true = [&](T i) { return i == 0; }; + // Exactly one member of the non-uniform group has value 1 (the first) + auto one_true = [&](T i) { return i == 1; }; + // Some (or all, for non-uniform groups of size 1) members of the + // non-uniform group have this value + auto some_true = [&](T i) { return i > size / 2; }; + // The variable is less than or equal to the non-uniform group size + // for all members of the non-uniform group. + auto all_true = [&](T i) { return i <= size; }; + + { + static_assert( + std::is_same_v, + "Return type of any_of_group(GroupT g, bool pred) is wrong\n"); + res_acc[gid][0] = + !sycl::any_of_group(non_uniform_group, none_true(local_var)); + res_acc[gid][1] = + sycl::any_of_group(non_uniform_group, one_true(local_var)); + res_acc[gid][2] = + sycl::any_of_group(non_uniform_group, some_true(local_var)); + res_acc[gid][3] = + sycl::any_of_group(non_uniform_group, all_true(local_var)); + + static_assert( + std::is_same_v, + "Return type of all_of_group(GroupT g, bool pred) is wrong\n"); + res_acc[gid][4] = + !sycl::all_of_group(non_uniform_group, none_true(local_var)); + // Note that 'one_true' returns true for the first item. Thus in the + // case that the non-uniform group size is 1, check that all items + // match; otherwise check that not all items match. + res_acc[gid][5] = + sycl::all_of_group(non_uniform_group, one_true(local_var)) ^ + (size != 1); + // Note that 'some_true' returns true for the first item if the + // non-uniform group size is 1. In that case, check that all items + // match; otherwise check that not all items match. + res_acc[gid][6] = + sycl::all_of_group(non_uniform_group, some_true(local_var)) ^ + (size != 1); + res_acc[gid][7] = + sycl::all_of_group(non_uniform_group, all_true(local_var)); + + static_assert( + std::is_same_v, + "Return type of none_of_group(GroupT g, bool pred) is " + "wrong\n"); + res_acc[gid][8] = + sycl::none_of_group(non_uniform_group, none_true(local_var)); + res_acc[gid][9] = + !sycl::none_of_group(non_uniform_group, one_true(local_var)); + res_acc[gid][10] = + !sycl::none_of_group(non_uniform_group, some_true(local_var)); + res_acc[gid][11] = + !sycl::none_of_group(non_uniform_group, all_true(local_var)); } + }); + }); + + { + sycl::host_accessor res_host{res_sycl}; + for (size_t gid = 0; gid < work_group_range.size(); ++gid) { + int index = 0; + for (int i = 0; i < test_matrix; ++i) + for (int j = 0; j < test_cases; ++j) { + std::string work_group = + sycl_cts::util::work_group_print(work_group_range); + CAPTURE(group_name, work_group); + INFO("Value of " << test_names[i] << " with " + << test_cases_names[j] << " for item " << gid + << " predicate is " + << (res_host[gid][index] ? "right" : "wrong")); + CHECK(res_host[gid][index++]); + } + } } } } -} +}; diff --git a/tests/extension/oneapi_non_uniform_groups/group_permute.cpp b/tests/extension/oneapi_non_uniform_groups/group_permute.cpp index e187671e2..edaeeddf8 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_permute.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_permute.cpp @@ -25,22 +25,11 @@ namespace non_uniform_groups::tests { // hipSYCL does not permute right 8-bit types inside groups TEMPLATE_LIST_TEST_CASE("Non-uniform-group permute", "[oneapi_non_uniform_groups][group_func][type_list]", - CustomTypes) { + GroupPackTypes) { auto queue = once_per_unit::get_queue(); - permute_non_uniform_group, - TestType>(queue); - permute_non_uniform_group, - TestType>(queue); - permute_non_uniform_group, - TestType>(queue); - permute_non_uniform_group, - TestType>(queue); - permute_non_uniform_group, - TestType>(queue); - permute_non_uniform_group, - TestType>(queue); - permute_non_uniform_group(queue); + for_all_combinations(TestType{}, + CustomTypePack{}, queue); } } // namespace non_uniform_groups::tests diff --git a/tests/extension/oneapi_non_uniform_groups/group_permute.h b/tests/extension/oneapi_non_uniform_groups/group_permute.h index eff0fd4ca..6b12b8578 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_permute.h +++ b/tests/extension/oneapi_non_uniform_groups/group_permute.h @@ -27,84 +27,92 @@ template class permute_non_uniform_group_kernel; template -void permute_non_uniform_group(sycl::queue& queue) { - const std::string group_name = NonUniformGroupHelper::get_name(); +struct permute_non_uniform_group_test { + void operator()(sycl::queue& queue) { + const std::string group_name = NonUniformGroupHelper::get_name(); - INFO("Testing permute for " + group_name); - if (!NonUniformGroupHelper::is_supported(queue.get_device())) { - SKIP("Device does not support " + group_name); - } + INFO("Testing permute for " + group_name); + if (!NonUniformGroupHelper::is_supported(queue.get_device())) { + SKIP("Device does not support " + group_name); + } - const std::string test_name = - "T permute_group_by_xor(GroupT g, T x, GroupT::linear_id_type mask)"; - - sycl::range<1> work_group_range = sycl_cts::util::work_group_range<1>(queue); - size_t work_group_size = work_group_range.size(); - - for (size_t test_case = 0; - test_case < NonUniformGroupHelper::num_test_cases; ++test_case) { - const std::string test_case_name = - NonUniformGroupHelper::get_test_case_name(test_case); - INFO("Running test case (" + std::to_string(test_case) + ") with " + - test_case_name); - - // array to return results: - std::valarray res(false, work_group_size); - { - sycl::buffer res_sycl(std::begin(res), - sycl::range<1>(work_group_size)); - - queue.submit([&](sycl::handler& cgh) { - auto res_acc = res_sycl.get_access(cgh); - - sycl::nd_range<1> executionRange(work_group_range, work_group_range); - - cgh.parallel_for>( - executionRange, [=](sycl::nd_item<1> item) { - sycl::sub_group sub_group = item.get_sub_group(); - - // If this item is not participating in the group, they fill their - // elements in the result with true and leave early. - if (!NonUniformGroupHelper::should_participate( - sub_group, test_case)) { - res_acc[item.get_local_linear_id()] = true; - return; - } - - GroupT non_uniform_group = - NonUniformGroupHelper::create(sub_group, test_case); - - using lin_id_type = typename GroupT::linear_id_type; - const lin_id_type llid = non_uniform_group.get_local_linear_id(); - - T local_var(splat_init(llid + 1)); - T permuted_var(splat_init(llid + 1)); - - static_assert( - std::is_same_v, - "Return type of permute_group_by_xor(GroupT g, T x, " - "GroupT::linear_id_type mask) is wrong\n"); - - bool res = true; - for (lin_id_type mask = 1u; mask > 0; mask <<= 1) { - permuted_var = sycl::permute_group_by_xor(non_uniform_group, - local_var, mask); - res &= equal(permuted_var, splat_init((llid ^ mask) + 1)) || - ((llid ^ mask) >= - non_uniform_group.get_local_linear_range()); - } - res_acc[item.get_local_linear_id()] = res; - }); - }); + const std::string test_name = + "T permute_group_by_xor(GroupT g, T x, GroupT::linear_id_type mask)"; + + sycl::range<1> work_group_range = + sycl_cts::util::work_group_range<1>(queue); + size_t work_group_size = work_group_range.size(); + + for (size_t test_case = 0; + test_case < NonUniformGroupHelper::num_test_cases; + ++test_case) { + const std::string test_case_name = + NonUniformGroupHelper::get_test_case_name(test_case); + INFO("Running test case (" + std::to_string(test_case) + ") with " + + test_case_name); + + // array to return results: + std::valarray res(false, work_group_size); + { + sycl::buffer res_sycl(std::begin(res), + sycl::range<1>(work_group_size)); + + queue.submit([&](sycl::handler& cgh) { + auto res_acc = + res_sycl.get_access(cgh); + + sycl::nd_range<1> executionRange(work_group_range, work_group_range); + + cgh.parallel_for>( + executionRange, [=](sycl::nd_item<1> item) { + sycl::sub_group sub_group = item.get_sub_group(); + + // If this item is not participating in the group, they fill + // their elements in the result with true and leave early. + if (!NonUniformGroupHelper::should_participate( + sub_group, test_case)) { + res_acc[item.get_local_linear_id()] = true; + return; + } + + GroupT non_uniform_group = + NonUniformGroupHelper::create(sub_group, test_case); + + using lin_id_type = typename GroupT::linear_id_type; + const lin_id_type llid = + non_uniform_group.get_local_linear_id(); + + T local_var(splat_init(llid + 1)); + T permuted_var(splat_init(llid + 1)); + + static_assert( + std::is_same_v, + "Return type of permute_group_by_xor(GroupT g, T x, " + "GroupT::linear_id_type mask) is wrong\n"); + + bool res = true; + for (lin_id_type mask = 1u; mask > 0; mask <<= 1) { + permuted_var = sycl::permute_group_by_xor(non_uniform_group, + local_var, mask); + res &= + equal(permuted_var, splat_init((llid ^ mask) + 1)) || + ((llid ^ mask) >= + non_uniform_group.get_local_linear_range()); + } + res_acc[item.get_local_linear_id()] = res; + }); + }); + } + bool result = res[0]; + for (size_t j = 1; j < work_group_size; ++j) result &= res[j]; + + std::string work_group = + sycl_cts::util::work_group_print(work_group_range); + CAPTURE(group_name, work_group); + INFO("Value of " << test_name << " with T = " << type_name() << " is " + << (result ? "right" : "wrong")); + CHECK(result); } - bool result = res[0]; - for (size_t j = 1; j < work_group_size; ++j) result &= res[j]; - - std::string work_group = sycl_cts::util::work_group_print(work_group_range); - CAPTURE(group_name, work_group); - INFO("Value of " << test_name << " with T = " << type_name() << " is " - << (result ? "right" : "wrong")); - CHECK(result); } -} +}; diff --git a/tests/extension/oneapi_non_uniform_groups/group_permute_fp16.cpp b/tests/extension/oneapi_non_uniform_groups/group_permute_fp16.cpp index 16cbac422..00763dc4f 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_permute_fp16.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_permute_fp16.cpp @@ -22,25 +22,14 @@ namespace non_uniform_groups::tests { -TEST_CASE("Non-uniform-group permute", - "[oneapi_non_uniform_groups][group_func][fp16]") { +TEMPLATE_LIST_TEST_CASE("Non-uniform-group permute", + "[oneapi_non_uniform_groups][group_func][fp16]", + GroupPackTypes) { auto queue = once_per_unit::get_queue(); if (queue.get_device().has(sycl::aspect::fp16)) { - permute_non_uniform_group, - sycl::half>(queue); - permute_non_uniform_group, - sycl::half>(queue); - permute_non_uniform_group, - sycl::half>(queue); - permute_non_uniform_group, - sycl::half>(queue); - permute_non_uniform_group, - sycl::half>(queue); - permute_non_uniform_group, - sycl::half>(queue); - permute_non_uniform_group( - queue); + for_all_combinations( + TestType{}, unnamed_type_pack{}, queue); } else { WARN("Device does not support half precision floating point operations."); } diff --git a/tests/extension/oneapi_non_uniform_groups/group_permute_fp64.cpp b/tests/extension/oneapi_non_uniform_groups/group_permute_fp64.cpp index 5b56a070e..6564bc1ce 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_permute_fp64.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_permute_fp64.cpp @@ -22,24 +22,14 @@ namespace non_uniform_groups::tests { -TEST_CASE("Non-uniform-group permute", - "[oneapi_non_uniform_groups][group_func][fp64]") { +TEMPLATE_LIST_TEST_CASE("Non-uniform-group permute", + "[oneapi_non_uniform_groups][group_func][fp64]", + GroupPackTypes) { auto queue = once_per_unit::get_queue(); if (queue.get_device().has(sycl::aspect::fp64)) { - permute_non_uniform_group, - double>(queue); - permute_non_uniform_group, - double>(queue); - permute_non_uniform_group, - double>(queue); - permute_non_uniform_group, - double>(queue); - permute_non_uniform_group, - double>(queue); - permute_non_uniform_group, - double>(queue); - permute_non_uniform_group(queue); + for_all_combinations( + TestType{}, unnamed_type_pack{}, queue); } else { WARN("Device does not support double precision floating point operations."); } diff --git a/tests/extension/oneapi_non_uniform_groups/group_reduce_over_group.cpp.in b/tests/extension/oneapi_non_uniform_groups/group_reduce_over_group.cpp.in index 5b3885deb..a34eb25b4 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_reduce_over_group.cpp.in +++ b/tests/extension/oneapi_non_uniform_groups/group_reduce_over_group.cpp.in @@ -28,20 +28,14 @@ namespace non_uniform_groups::tests { // clang-format on using ReduceTypes = Types; -TEST_CASE(CTS_TYPE_NAME + " non-uniform group reduce functions", - "[oneapi_non_uniform_groups][group_func][type_list]") { +TEMPLATE_LIST_TEST_CASE(CTS_TYPE_NAME + " non-uniform group reduce functions", + "[oneapi_non_uniform_groups][group_func][type_list]", + GroupPackTypes) { auto queue = once_per_unit::get_queue(); // Get binary operators from TestType const auto Operators = get_op_types(); const auto RetType = unnamed_type_pack(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); + const auto GroupTypes = TestType{}; if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) @@ -62,21 +56,14 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group reduce functions", TEMPLATE_LIST_TEST_CASE(CTS_TYPE_NAME + " non-uniform group reduce functions with init", "[oneapi_non_uniform_groups][group_func][type_list]", - ReduceTypes) { + GroupPackTypes) { auto queue = once_per_unit::get_queue(); // Get binary operators from T const auto Operators = get_op_types(); const auto RetType = unnamed_type_pack(); - const auto ReducedType = unnamed_type_pack(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); + const auto ReducedType = ReduceTypes{}; + const auto GroupTypes = TestType{}; if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) diff --git a/tests/extension/oneapi_non_uniform_groups/group_scan_over_group.cpp.in b/tests/extension/oneapi_non_uniform_groups/group_scan_over_group.cpp.in index c4c1e214a..58c90534d 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_scan_over_group.cpp.in +++ b/tests/extension/oneapi_non_uniform_groups/group_scan_over_group.cpp.in @@ -27,20 +27,13 @@ namespace non_uniform_groups::tests { -using TestType = unnamed_type_pack; +using CurrentType = unnamed_type_pack; using ScanTypes = Types; -TEST_CASE(CTS_TYPE_NAME + " non-uniform group scan functions", - "[oneapi_non_uniform_groups][group_func][type_list]"){ +TEMPLATE_LIST_TEST_CASE(CTS_TYPE_NAME + " non-uniform group scan functions", + "[oneapi_non_uniform_groups][group_func][type_list]", + GroupPackTypes){ auto queue = once_per_unit::get_queue(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) @@ -53,20 +46,14 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group scan functions", "Device does not support double precision floating point " "operations."); } - for_all_combinations(GroupTypes, TestType{}, queue); + for_all_combinations(TestType{}, CurrentType{}, queue); }; -TEST_CASE(CTS_TYPE_NAME + " non-uniform group scan functions with init", - "[oneapi_non_uniform_groups][group_func][type_list]"){ +TEMPLATE_LIST_TEST_CASE( + CTS_TYPE_NAME + " non-uniform group scan functions with init", + "[oneapi_non_uniform_groups][group_func][type_list]", + GroupPackTypes){ auto queue = once_per_unit::get_queue(); - const auto GroupTypes = unnamed_type_pack< - oneapi_ext::ballot_group, - oneapi_ext::fixed_size_group<1, sycl::sub_group>, - oneapi_ext::fixed_size_group<2, sycl::sub_group>, - oneapi_ext::fixed_size_group<4, sycl::sub_group>, - oneapi_ext::fixed_size_group<8, sycl::sub_group>, - oneapi_ext::tangle_group, - oneapi_ext::opportunistic_group>(); if constexpr (std::is_same_v, sycl::half>) { if (!queue.get_device().has(sycl::aspect::fp16)) @@ -79,7 +66,7 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group scan functions with init", "Device does not support double precision floating point " "operations."); } - for_all_combinations(GroupTypes, TestType{}, + for_all_combinations(TestType{}, CurrentType{}, ScanTypes{}, queue); }; diff --git a/tests/extension/oneapi_non_uniform_groups/group_shift.cpp b/tests/extension/oneapi_non_uniform_groups/group_shift.cpp index 8b474e15b..e90a674f4 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_shift.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_shift.cpp @@ -25,22 +25,11 @@ namespace non_uniform_groups::tests { // errors in hipSYCL with bool and 8-bit types - only in group shifts TEMPLATE_LIST_TEST_CASE("Non-uniform-group shift", "[oneapi_non_uniform_groups][group_func][type_list]", - CustomTypes) { + GroupPackTypes) { auto queue = once_per_unit::get_queue(); - shift_non_uniform_group, TestType>( - queue); - shift_non_uniform_group, - TestType>(queue); - shift_non_uniform_group, - TestType>(queue); - shift_non_uniform_group, - TestType>(queue); - shift_non_uniform_group, - TestType>(queue); - shift_non_uniform_group, TestType>( - queue); - shift_non_uniform_group(queue); + for_all_combinations(TestType{}, + CustomTypePack{}, queue); } } // namespace non_uniform_groups::tests diff --git a/tests/extension/oneapi_non_uniform_groups/group_shift.h b/tests/extension/oneapi_non_uniform_groups/group_shift.h index 45cbbe499..a904749bb 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_shift.h +++ b/tests/extension/oneapi_non_uniform_groups/group_shift.h @@ -27,127 +27,126 @@ template class shift_non_uniform_group_kernel; template -void shift_non_uniform_group(sycl::queue& queue) { - const std::string group_name = NonUniformGroupHelper::get_name(); +struct shift_non_uniform_group_test { + void operator()(sycl::queue& queue) { + const std::string group_name = NonUniformGroupHelper::get_name(); - INFO("Testing permute for " + group_name); - if (!NonUniformGroupHelper::is_supported(queue.get_device())) { - SKIP("Device does not support " + group_name); - } - - // 4 functions - constexpr int test_matrix = 4; - const std::string test_names[test_matrix] = { - "T shift_group_left(GroupT g, T x)", - "T shift_group_left(GroupT g, T x, GroupT::linear_id_type delta)", - "T shift_group_right(GroupT g, T x)", - "T shift_group_right(GroupT g, T x, GroupT::linear_id_type delta)"}; - - sycl::range<1> work_group_range = sycl_cts::util::work_group_range<1>(queue); - size_t work_group_size = work_group_range.size(); - - for (size_t test_case = 0; - test_case < NonUniformGroupHelper::num_test_cases; ++test_case) { - const std::string test_case_name = - NonUniformGroupHelper::get_test_case_name(test_case); - INFO("Running test case (" + std::to_string(test_case) + ") with " + - test_case_name); - - // array to return results: - std::valarray res(false, test_matrix * work_group_size); - { - sycl::buffer res_sycl( - std::begin(res), sycl::range<1>(test_matrix * work_group_size)); - - queue.submit([&](sycl::handler& cgh) { - auto res_acc = res_sycl.get_access(cgh); - - sycl::nd_range<1> executionRange(work_group_range, work_group_range); - - cgh.parallel_for>( - executionRange, [=](sycl::nd_item<1> item) { - sycl::sub_group sub_group = item.get_sub_group(); - - // If this item is not participating in the group, they fill their - // elements in the result with true and leave early. - if (!NonUniformGroupHelper::should_participate( - sub_group, test_case)) { - res_acc[0 * work_group_size + item.get_local_linear_id()] = - true; - res_acc[1 * work_group_size + item.get_local_linear_id()] = - true; - res_acc[2 * work_group_size + item.get_local_linear_id()] = - true; - res_acc[3 * work_group_size + item.get_local_linear_id()] = - true; - return; - } - - GroupT non_uniform_group = - NonUniformGroupHelper::create(sub_group, test_case); - const typename GroupT::linear_id_type llid = - non_uniform_group.get_local_linear_id(); - - T local_var(splat_init(llid + 1)); - T shifted_var(splat_init(llid + 1)); - - static_assert( - std::is_same_v, - "Return type of shift_group_left(GroupT g, T x) is wrong\n"); - - shifted_var = - sycl::shift_group_left(non_uniform_group, local_var); - res_acc[0 * work_group_size + item.get_local_linear_id()] = - equal(shifted_var, splat_init(llid + 2)) || - (llid + 1 >= non_uniform_group.get_local_linear_range()); - - static_assert( - std::is_same_v, - "Return type of shift_group_left(GroupT g, T x, " - "GroupT::linear_id_type delta) is wrong\n"); - - shifted_var = - sycl::shift_group_left(non_uniform_group, local_var, 3); - res_acc[1 * work_group_size + item.get_local_linear_id()] = - equal(shifted_var, splat_init(llid + 4)) || - (llid + 3 >= non_uniform_group.get_local_linear_range()); - - static_assert( - std::is_same_v, - "Return type of shift_group_right(GroupT g, T x) is wrong\n"); - - shifted_var = - sycl::shift_group_right(non_uniform_group, local_var); - res_acc[2 * work_group_size + item.get_local_linear_id()] = - equal(shifted_var, splat_init(llid)) || (llid < 1); - - static_assert( - std::is_same_v, - "Return type of shift_group_right(GroupT g, T x, " - "GroupT::linear_id_type delta) is wrong\n"); - - shifted_var = - sycl::shift_group_right(non_uniform_group, local_var, 2); - res_acc[3 * work_group_size + item.get_local_linear_id()] = - equal(shifted_var, splat_init(llid - 1)) || (llid < 2); - }); - }); + INFO("Testing permute for " + group_name); + if (!NonUniformGroupHelper::is_supported(queue.get_device())) { + SKIP("Device does not support " + group_name); } - for (int i = 0; i < test_matrix; ++i) { - bool result = res[i * work_group_size]; - for (size_t j = 1; j < work_group_size; ++j) - result &= res[i * work_group_size + j]; - - std::string work_group = - sycl_cts::util::work_group_print(work_group_range); - CAPTURE(group_name, work_group); - INFO("Value of " << test_names[i] << " with T = " << type_name() - << " is " << (result ? "right" : "wrong")); - CHECK(result); + + // 4 functions + constexpr int test_matrix = 4; + const std::string test_names[test_matrix] = { + "T shift_group_left(GroupT g, T x)", + "T shift_group_left(GroupT g, T x, GroupT::linear_id_type delta)", + "T shift_group_right(GroupT g, T x)", + "T shift_group_right(GroupT g, T x, GroupT::linear_id_type delta)"}; + + sycl::range<1> work_group_range = + sycl_cts::util::work_group_range<1>(queue); + size_t work_group_size = work_group_range.size(); + + for (size_t test_case = 0; + test_case < NonUniformGroupHelper::num_test_cases; + ++test_case) { + const std::string test_case_name = + NonUniformGroupHelper::get_test_case_name(test_case); + INFO("Running test case (" + std::to_string(test_case) + ") with " + + test_case_name); + + // array to return results: + std::valarray res(false, test_matrix * work_group_size); + { + sycl::buffer res_sycl( + std::begin(res), sycl::range<1>(test_matrix * work_group_size)); + + queue.submit([&](sycl::handler& cgh) { + auto res_acc = + res_sycl.get_access(cgh); + + sycl::nd_range<1> executionRange(work_group_range, work_group_range); + + cgh.parallel_for>(executionRange, [=](sycl::nd_item<1> item) { + sycl::sub_group sub_group = item.get_sub_group(); + + // If this item is not participating in the group, they fill their + // elements in the result with true and leave early. + if (!NonUniformGroupHelper::should_participate(sub_group, + test_case)) { + res_acc[0 * work_group_size + item.get_local_linear_id()] = true; + res_acc[1 * work_group_size + item.get_local_linear_id()] = true; + res_acc[2 * work_group_size + item.get_local_linear_id()] = true; + res_acc[3 * work_group_size + item.get_local_linear_id()] = true; + return; + } + + GroupT non_uniform_group = + NonUniformGroupHelper::create(sub_group, test_case); + const typename GroupT::linear_id_type llid = + non_uniform_group.get_local_linear_id(); + + T local_var(splat_init(llid + 1)); + T shifted_var(splat_init(llid + 1)); + + static_assert( + std::is_same_v, + "Return type of shift_group_left(GroupT g, T x) is wrong\n"); + + shifted_var = sycl::shift_group_left(non_uniform_group, local_var); + res_acc[0 * work_group_size + item.get_local_linear_id()] = + equal(shifted_var, splat_init(llid + 2)) || + (llid + 1 >= non_uniform_group.get_local_linear_range()); + + static_assert( + std::is_same_v, + "Return type of shift_group_left(GroupT g, T x, " + "GroupT::linear_id_type delta) is wrong\n"); + + shifted_var = + sycl::shift_group_left(non_uniform_group, local_var, 3); + res_acc[1 * work_group_size + item.get_local_linear_id()] = + equal(shifted_var, splat_init(llid + 4)) || + (llid + 3 >= non_uniform_group.get_local_linear_range()); + + static_assert( + std::is_same_v, + "Return type of shift_group_right(GroupT g, T x) is wrong\n"); + + shifted_var = sycl::shift_group_right(non_uniform_group, local_var); + res_acc[2 * work_group_size + item.get_local_linear_id()] = + equal(shifted_var, splat_init(llid)) || (llid < 1); + + static_assert( + std::is_same_v, + "Return type of shift_group_right(GroupT g, T x, " + "GroupT::linear_id_type delta) is wrong\n"); + + shifted_var = + sycl::shift_group_right(non_uniform_group, local_var, 2); + res_acc[3 * work_group_size + item.get_local_linear_id()] = + equal(shifted_var, splat_init(llid - 1)) || (llid < 2); + }); + }); + } + for (int i = 0; i < test_matrix; ++i) { + bool result = res[i * work_group_size]; + for (size_t j = 1; j < work_group_size; ++j) + result &= res[i * work_group_size + j]; + + std::string work_group = + sycl_cts::util::work_group_print(work_group_range); + CAPTURE(group_name, work_group); + INFO("Value of " << test_names[i] << " with T = " << type_name() + << " is " << (result ? "right" : "wrong")); + CHECK(result); + } } } -} +}; diff --git a/tests/extension/oneapi_non_uniform_groups/group_shift_fp16.cpp b/tests/extension/oneapi_non_uniform_groups/group_shift_fp16.cpp index 483c8daeb..f5794dbb4 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_shift_fp16.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_shift_fp16.cpp @@ -22,24 +22,14 @@ namespace non_uniform_groups::tests { -TEST_CASE("Non-uniform-group shift", - "[oneapi_non_uniform_groups][group_func][fp16]") { +TEMPLATE_LIST_TEST_CASE("Non-uniform-group shift", + "[oneapi_non_uniform_groups][group_func][fp16]", + GroupPackTypes) { auto queue = once_per_unit::get_queue(); if (queue.get_device().has(sycl::aspect::fp16)) { - shift_non_uniform_group, - sycl::half>(queue); - shift_non_uniform_group, - sycl::half>(queue); - shift_non_uniform_group, - sycl::half>(queue); - shift_non_uniform_group, - sycl::half>(queue); - shift_non_uniform_group, - sycl::half>(queue); - shift_non_uniform_group, - sycl::half>(queue); - shift_non_uniform_group(queue); + for_all_combinations( + TestType{}, unnamed_type_pack{}, queue); } else { WARN("Device does not support half precision floating point operations."); } diff --git a/tests/extension/oneapi_non_uniform_groups/group_shift_fp64.cpp b/tests/extension/oneapi_non_uniform_groups/group_shift_fp64.cpp index 047f10764..0dc5e5f5d 100644 --- a/tests/extension/oneapi_non_uniform_groups/group_shift_fp64.cpp +++ b/tests/extension/oneapi_non_uniform_groups/group_shift_fp64.cpp @@ -22,24 +22,14 @@ namespace non_uniform_groups::tests { -TEST_CASE("Non-uniform-group shift", - "[oneapi_non_uniform_groups][group_func][fp64]") { +TEMPLATE_LIST_TEST_CASE("Non-uniform-group shift", + "[oneapi_non_uniform_groups][group_func][fp64]", + GroupPackTypes) { auto queue = once_per_unit::get_queue(); if (queue.get_device().has(sycl::aspect::fp64)) { - shift_non_uniform_group, double>( - queue); - shift_non_uniform_group, - double>(queue); - shift_non_uniform_group, - double>(queue); - shift_non_uniform_group, - double>(queue); - shift_non_uniform_group, - double>(queue); - shift_non_uniform_group, double>( - queue); - shift_non_uniform_group(queue); + for_all_combinations( + TestType{}, unnamed_type_pack{}, queue); } else { WARN("Device does not support double precision floating point operations."); } diff --git a/tests/extension/oneapi_non_uniform_groups/non_uniform_group_common.h b/tests/extension/oneapi_non_uniform_groups/non_uniform_group_common.h index b011fb563..16af8e4d3 100644 --- a/tests/extension/oneapi_non_uniform_groups/non_uniform_group_common.h +++ b/tests/extension/oneapi_non_uniform_groups/non_uniform_group_common.h @@ -22,6 +22,16 @@ namespace oneapi_ext = sycl::ext::oneapi::experimental; +// Group packs to test for. +using GroupPackTypes = std::tuple< + unnamed_type_pack>, + unnamed_type_pack, + oneapi_ext::fixed_size_group<2, sycl::sub_group>, + oneapi_ext::fixed_size_group<4, sycl::sub_group>, + oneapi_ext::fixed_size_group<8, sycl::sub_group>>, + unnamed_type_pack>, + unnamed_type_pack>; + // Helper class for working with non-uniform group of type GroupT. If the // result is empty the work-item does not participate in the execution. template diff --git a/tests/group_functions/group_functions_common.h b/tests/group_functions/group_functions_common.h index 2c91b71cb..d91a98526 100644 --- a/tests/group_functions/group_functions_common.h +++ b/tests/group_functions/group_functions_common.h @@ -234,14 +234,23 @@ using ExtendedTypes = concatenation, sycl::vec>>::type; +using ExtendedTypePack = + concatenation, + sycl::vec>>::type; #else using ExtendedTypes = concatenation< FundamentalTypes, std::tuple, sycl::vec, sycl::marray, sycl::marray>>::type; +using ExtendedTypePack = concatenation< + Types, unnamed_type_pack< + bool, sycl::vec, sycl::vec, + sycl::marray, sycl::marray>>::type; #endif using CustomTypes = concatenation::type; +using CustomTypePack = concatenation::type; template inline auto get_op_types() {