Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve kernel name generation for unnamed lambda (kernel templates) #1524

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions test/kt/esimd_radix_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,13 @@ template <typename T, bool IsAscending, std::uint8_t RadixBits, typename KernelP
void
test_general_cases(sycl::queue q, std::size_t size, KernelParam param)
{
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::get_new_kernel_params<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::get_new_kernel_params<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<3>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::create_new_kernel_param_idx<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::create_new_kernel_param_idx<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<3>(param));
#if _ENABLE_RANGES_TESTING
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<5>(param));
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<5>(param));
#endif // _ENABLE_RANGES_TESTING
}

Expand All @@ -217,11 +217,11 @@ main()
for (auto size : sort_sizes)
{
test_general_cases<TEST_KEY_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_general_cases<TEST_KEY_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
}
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::get_new_kernel_params<3>(params));
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::create_new_kernel_param_idx<3>(params));
}
catch (const ::std::exception& exc)
{
Expand Down
8 changes: 4 additions & 4 deletions test/kt/esimd_radix_sort_by_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ int main()
for (auto size : sort_sizes)
{
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<2>(params));
q, size, TestUtils::create_new_kernel_param_idx<2>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<3>(params));
q, size, TestUtils::create_new_kernel_param_idx<3>(params));
}
}
catch (const ::std::exception& exc)
Expand Down
8 changes: 4 additions & 4 deletions test/kt/esimd_radix_sort_by_key_out_of_place.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ main()
for (auto size : sort_sizes)
{
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<2>(params));
q, size, TestUtils::create_new_kernel_param_idx<2>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<3>(params));
q, size, TestUtils::create_new_kernel_param_idx<3>(params));
}
}
catch (const ::std::exception& exc)
Expand Down
18 changes: 9 additions & 9 deletions test/kt/esimd_radix_sort_out_of_place.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ template <typename T, bool IsAscending, std::uint8_t RadixBits, typename KernelP
void
test_general_cases(sycl::queue q, std::size_t size, KernelParam param)
{
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::get_new_kernel_params<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::get_new_kernel_params<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<3>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::create_new_kernel_param_idx<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::create_new_kernel_param_idx<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<3>(param));
#if _ENABLE_RANGES_TESTING
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<5>(param));
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<5>(param));
#endif // _ENABLE_RANGES_TESTING
}

Expand All @@ -242,11 +242,11 @@ main()
for (auto size : sort_sizes)
{
test_general_cases<TEST_KEY_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_general_cases<TEST_KEY_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
}
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::get_new_kernel_params<3>(params));
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::create_new_kernel_param_idx<3>(params));
}
catch (const ::std::exception& exc)
{
Expand Down
14 changes: 7 additions & 7 deletions test/kt/single_pass_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,28 +175,28 @@ template <typename T, typename BinOp, typename KernelParam>
void
test_general_cases(sycl::queue q, std::size_t size, BinOp bin_op, KernelParam param)
{
test_usm<T, sycl::usm::alloc::shared>(q, size, bin_op, TestUtils::get_new_kernel_params<0>(param));
test_usm<T, sycl::usm::alloc::device>(q, size, bin_op, TestUtils::get_new_kernel_params<1>(param));
test_sycl_iterators<T>(q, size, bin_op, TestUtils::get_new_kernel_params<2>(param));
test_usm<T, sycl::usm::alloc::shared>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<0>(param));
test_usm<T, sycl::usm::alloc::device>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<1>(param));
test_sycl_iterators<T>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<2>(param));
#if _ENABLE_RANGES_TESTING
test_all_view<T>(q, size, bin_op, TestUtils::get_new_kernel_params<3>(param));
test_buffer<T>(q, size, bin_op, TestUtils::get_new_kernel_params<4>(param));
test_all_view<T>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<3>(param));
test_buffer<T>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<4>(param));
#endif
}

template <typename T, typename KernelParam>
void
test_all_cases(sycl::queue q, std::size_t size, KernelParam param)
{
test_general_cases<T>(q, size, std::plus<T>{}, TestUtils::get_new_kernel_params<0>(param));
test_general_cases<T>(q, size, std::plus<T>{}, TestUtils::create_new_kernel_param_idx<0>(param));
#if _PSTL_GROUP_REDUCTION_MULT_INT64_BROKEN
static constexpr bool int64_mult_broken = std::is_integral_v<T> && (sizeof(T) == 8);
#else
static constexpr bool int64_mult_broken = 0;
#endif
if constexpr (!int64_mult_broken)
{
test_general_cases<T>(q, size, std::multiplies<T>{}, TestUtils::get_new_kernel_params<1>(param));
test_general_cases<T>(q, size, std::multiplies<T>{}, TestUtils::create_new_kernel_param_idx<1>(param));
}
}

Expand Down
16 changes: 10 additions & 6 deletions test/support/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -961,17 +961,21 @@ create_new_policy_idx(Policy&& policy)

#if TEST_DPCPP_BACKEND_PRESENT
template <typename KernelName, int idx>
struct __kernel_name_with_idx
struct kernel_name_with_idx
{
};

template <int idx, typename KernelParams>
template <int idx, typename KernelParam>
constexpr auto
get_new_kernel_params(KernelParams)
create_new_kernel_param_idx(KernelParam)
{
return oneapi::dpl::experimental::kt::kernel_param<
KernelParams::data_per_workitem, KernelParams::workgroup_size,
__kernel_name_with_idx<typename KernelParams::kernel_name, idx>>{};
#if TEST_EXPLICIT_KERNEL_NAMES
return oneapi::dpl::experimental::kt::kernel_param<KernelParam::data_per_workitem,
KernelParam::workgroup_size,
kernel_name_with_idx<typename KernelParam::kernel_name, idx>>{};
#else
return KernelParam{};
#endif // TEST_EXPLICIT_KERNEL_NAMES
}
#endif //TEST_DPCPP_BACKEND_PRESENT

Expand Down
Loading