balance data load into SLM cache
Signed-off-by: Sergey Kopienko <[email protected]>
SergeyKopienko committed Nov 19, 2024
1 parent 00dcb1d commit 144de4a
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
@@ -303,10 +303,9 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
assert(__chunk > 0);

// Pessimistically use only 2/3 of the memory to account for memory used by the compiled kernel
const std::size_t __max_slm_size_adj =
std::max((std::size_t)__chunk,
std::min((std::size_t)__n, oneapi::dpl::__internal::__slm_adjusted_work_group_size(
__exec, sizeof(_RangeValueType)))) * 2 / 3;
const auto __slm_adjusted_work_group_size = oneapi::dpl::__internal::__slm_adjusted_work_group_size(__exec, sizeof(_RangeValueType));
const auto __slm_adjusted_work_group_size_x_part = __slm_adjusted_work_group_size * 2 / 3;
const std::size_t __max_slm_size_adj = __slm_adjusted_work_group_size_x_part;

// The amount of data must be a multiple of the chunk size.
const std::size_t __max_source_data_items_fit_into_slm = __max_slm_size_adj - __max_slm_size_adj % __chunk;
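
A minimal standalone sketch of the sizing arithmetic above; the chunk size and the SLM-adjusted work-group size are assumed values for illustration only (the real values come from the device query via oneapi::dpl::__internal::__slm_adjusted_work_group_size), and the double underscores are dropped:

#include <cstddef>
#include <iostream>

int main()
{
    // Assumed inputs, illustration only.
    const std::size_t chunk = 4;                          // data items processed per work-item
    const std::size_t slm_adjusted_work_group_size = 512; // stands in for the device query result

    // Pessimistically keep only 2/3 of the SLM budget for the source data.
    const std::size_t max_slm_size_adj = slm_adjusted_work_group_size * 2 / 3; // 341

    // Round down to a multiple of the chunk size.
    const std::size_t max_source_data_items_fit_into_slm =
        max_slm_size_adj - max_slm_size_adj % chunk;                           // 340

    std::cout << max_slm_size_adj << ' ' << max_source_data_items_fit_into_slm << '\n';
    return 0;
}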
@@ -319,7 +318,7 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,

// The number of base diagonals equals the number of work-groups
// - the distance between two adjacent base diagonals equals the number of work-items in each work-group
const std::size_t __wg_count = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __max_source_data_items_fit_into_slm);
const std::size_t __wg_count = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk * __wi_in_one_wg);

// Create storage to save the split-points on each base diagonal + 1 (for the right base diagonal in the last work-group)
// - in GLOBAL coordinates
@@ -385,39 +384,43 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,

const _IdType __rng1_wg_data_size = __sp_base_right_global.first - __sp_base_left_global.first;
const _IdType __rng2_wg_data_size = __sp_base_right_global.second - __sp_base_left_global.second;
const _IdType __rng_wg_data_size = __rng1_wg_data_size + __rng2_wg_data_size;

_RangeValueType* __rng1_cache_slm = std::addressof(__loc_acc[0]);
_RangeValueType* __rng2_cache_slm = std::addressof(__loc_acc[0]) + __rng1_wg_data_size;

constexpr std::size_t __slm_bank_size = 32;
const std::size_t __chunk_of_data_reading = oneapi::dpl::__internal::__dpl_ceiling_div(__rng1_wg_data_size + __rng2_wg_data_size, __wi_in_one_wg);

const std::size_t __chunk_of_data_reading = std::max(
oneapi::dpl::__internal::__dpl_ceiling_div(__rng_wg_data_size, __wi_in_one_wg),
oneapi::dpl::__internal::__dpl_ceiling_div(__slm_bank_size, 2 * sizeof(_RangeValueType)));
const std::size_t __idx_begin = __local_id * __chunk_of_data_reading;
if (__idx_begin < __rng_wg_data_size)
const std::size_t __how_many_wi_reads_rng1 = oneapi::dpl::__internal::__dpl_ceiling_div(__rng1_wg_data_size, __chunk_of_data_reading);
const std::size_t __how_many_wi_reads_rng2 = oneapi::dpl::__internal::__dpl_ceiling_div(__rng2_wg_data_size, __chunk_of_data_reading);

// Only the first __how_many_wi_reads_rng1 work-items read data from __rng1
if (__local_id < __how_many_wi_reads_rng1)
{
const _IdType __idx_end = std::min(__idx_begin + __chunk_of_data_reading, (std::size_t)__rng_wg_data_size);
const std::size_t __idx_begin = __local_id * __chunk_of_data_reading;

// Cooperative data load from __rng1 to __rng1_cache_slm
if (__idx_begin < __rng1_wg_data_size)
{
const _IdType __idx_begin_rng1 = __idx_begin;
const _IdType __idx_end_rng1 = std::min(__idx_end, __rng1_wg_data_size);
const std::size_t __idx_end = std::min(__idx_begin + __chunk_of_data_reading, (std::size_t)__rng1_wg_data_size);

_ONEDPL_PRAGMA_UNROLL
for (_IdType __idx = __idx_begin_rng1; __idx < __idx_end_rng1; ++__idx)
for (_IdType __idx = __idx_begin; __idx < __idx_end; ++__idx)
__rng1_cache_slm[__idx] = __rng1[__sp_base_left_global.first + __idx];
}
}

// Cooperative data load from __rng2 to __rng1_cache_slm
if (__idx_end > __rng1_wg_data_size)
{
const _IdType __idx_begin_rng2 = 0;
const _IdType __idx_end_rng2 = __idx_end - __rng1_wg_data_size;
const std::size_t __first_wi_local_id_for_read_rng2 = __wi_in_one_wg - __how_many_wi_reads_rng2 - 1;
if (__local_id >= __first_wi_local_id_for_read_rng2)
{
const std::size_t __idx_begin = (__local_id - __first_wi_local_id_for_read_rng2) * __chunk_of_data_reading;

// Cooperative data load from __rng2 to __rng2_cache_slm
if (__idx_begin < __rng2_wg_data_size)
{
const std::size_t __idx_end = std::min(__idx_begin + __chunk_of_data_reading, (std::size_t)__rng2_wg_data_size);

_ONEDPL_PRAGMA_UNROLL
for (_IdType __idx = __idx_begin_rng2; __idx < __idx_end_rng2; ++__idx)
for (_IdType __idx = __idx_begin; __idx < __idx_end; ++__idx)
__rng2_cache_slm[__idx] = __rng2[__sp_base_left_global.second + __idx];
}
}
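
The new balanced load can be checked with a small host-side model of the same index arithmetic; the work-group and per-range sizes below are assumptions for illustration, the double underscores are dropped, and the loop over local_id stands in for the work-items that execute concurrently in the kernel:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

std::size_t ceil_div(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

int main()
{
    // Assumed, illustration-only sizes: 16 work-items per work-group and an
    // uneven split of the cached data between the two source ranges.
    const std::size_t wi_in_one_wg = 16;
    const std::size_t rng1_wg_data_size = 37;
    const std::size_t rng2_wg_data_size = 11;

    // Every work-item copies the same number of items...
    const std::size_t chunk_of_data_reading =
        ceil_div(rng1_wg_data_size + rng2_wg_data_size, wi_in_one_wg);

    // ...but the first work-items are assigned to rng1 and the last ones to rng2,
    // so both loads proceed in parallel instead of rng2 waiting for rng1.
    const std::size_t how_many_wi_reads_rng1 = ceil_div(rng1_wg_data_size, chunk_of_data_reading);
    const std::size_t how_many_wi_reads_rng2 = ceil_div(rng2_wg_data_size, chunk_of_data_reading);
    const std::size_t first_wi_local_id_for_read_rng2 = wi_in_one_wg - how_many_wi_reads_rng2 - 1;

    std::vector<int> rng1_loads(rng1_wg_data_size, 0), rng2_loads(rng2_wg_data_size, 0);

    for (std::size_t local_id = 0; local_id < wi_in_one_wg; ++local_id)
    {
        // First group of work-items: cooperative load of rng1 into its SLM cache.
        if (local_id < how_many_wi_reads_rng1)
        {
            const std::size_t idx_begin = local_id * chunk_of_data_reading;
            if (idx_begin < rng1_wg_data_size)
            {
                const std::size_t idx_end = std::min(idx_begin + chunk_of_data_reading, rng1_wg_data_size);
                for (std::size_t idx = idx_begin; idx < idx_end; ++idx)
                    ++rng1_loads[idx];
            }
        }
        // Last group of work-items: cooperative load of rng2 into its SLM cache.
        if (local_id >= first_wi_local_id_for_read_rng2)
        {
            const std::size_t idx_begin = (local_id - first_wi_local_id_for_read_rng2) * chunk_of_data_reading;
            if (idx_begin < rng2_wg_data_size)
            {
                const std::size_t idx_end = std::min(idx_begin + chunk_of_data_reading, rng2_wg_data_size);
                for (std::size_t idx = idx_begin; idx < idx_end; ++idx)
                    ++rng2_loads[idx];
            }
        }
    }

    const bool ok =
        std::count(rng1_loads.begin(), rng1_loads.end(), 1) == static_cast<std::ptrdiff_t>(rng1_wg_data_size) &&
        std::count(rng2_loads.begin(), rng2_loads.end(), 1) == static_cast<std::ptrdiff_t>(rng2_wg_data_size);
    std::cout << (ok ? "every SLM slot written exactly once\n" : "coverage error\n");
    return 0;
}

As in the kernel, the two conditions are evaluated independently, so a work-item near the boundary may take part in both loads; the guards on idx_begin keep every index in range.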
@@ -477,10 +480,10 @@ __parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy
constexpr bool __same_merge_types = std::is_same_v<_Range1ValueType, _Range2ValueType>;

const std::size_t __n = __rng1.size() + __rng2.size();
if (false)//if (__n < __starting_size_limit_for_large_submitter || !__same_merge_types)
if (__n < __starting_size_limit_for_large_submitter || !__same_merge_types)
{
static_assert(__starting_size_limit_for_large_submitter < std::numeric_limits<std::uint32_t>::max());

using _WiIndex = std::uint32_t;
using _MergeKernelName = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
