diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp
index e8254ce1..7c6a34c1 100644
--- a/src/avx512-64bit-argsort.hpp
+++ b/src/avx512-64bit-argsort.hpp
@@ -85,7 +85,7 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N)
     typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01;
     argzmm_t argzmm1 = argtype::loadu(arg);
     argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8);
-    reg_t arrzmm1 = vtype::template i64gather<sizeof(type_t)>(argzmm1, arr);
+    reg_t arrzmm1 = vtype::i64gather(arr, arg);
     reg_t arrzmm2 = vtype::template mask_i64gather<sizeof(type_t)>(
             vtype::zmm_max(), load_mask, argzmm2, arr);
     arrzmm1 = sort_zmm_64bit<vtype, argtype>(arrzmm1, argzmm1);
@@ -111,7 +111,7 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N)
 X86_SIMD_SORT_UNROLL_LOOP(2)
     for (int ii = 0; ii < 2; ++ii) {
         argzmm[ii] = argtype::loadu(arg + 8 * ii);
-        arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
+        arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
         arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
     }
 
@@ -154,7 +154,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
 X86_SIMD_SORT_UNROLL_LOOP(4)
     for (int ii = 0; ii < 4; ++ii) {
         argzmm[ii] = argtype::loadu(arg + 8 * ii);
-        arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
+        arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
         arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
     }
 
@@ -206,7 +206,7 @@ X86_SIMD_SORT_UNROLL_LOOP(4)
 //X86_SIMD_SORT_UNROLL_LOOP(8)
 //    for (int ii = 0; ii < 8; ++ii) {
 //        argzmm[ii] = argtype::loadu(arg + 8*ii);
-//        arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
+//        arrzmm[ii] = vtype::i64gather(argzmm[ii], arr);
 //        arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
 //    }
 //
@@ -257,17 +257,14 @@ type_t get_pivot_64bit(type_t *arr,
         // median of 8
         int64_t size = (right - left) / 8;
         using reg_t = typename vtype::reg_t;
-        // TODO: Use gather here too:
-        __m512i rand_index = _mm512_set_epi64(arg[left + size],
-                                              arg[left + 2 * size],
-                                              arg[left + 3 * size],
-                                              arg[left + 4 * size],
-                                              arg[left + 5 * size],
-                                              arg[left + 6 * size],
-                                              arg[left + 7 * size],
-                                              arg[left + 8 * size]);
-        reg_t rand_vec
-                = vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
+        reg_t rand_vec = vtype::set(arr[arg[left + size]],
+                                    arr[arg[left + 2 * size]],
+                                    arr[arg[left + 3 * size]],
+                                    arr[arg[left + 4 * size]],
+                                    arr[arg[left + 5 * size]],
+                                    arr[arg[left + 6 * size]],
+                                    arr[arg[left + 7 * size]],
+                                    arr[arg[left + 8 * size]]);
         // pivot will never be a nan, since there are no nan's!
         reg_t sort = sort_zmm_64bit<vtype>(rand_vec);
         return ((type_t *)&sort)[4];
diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h
index 0353507e..bbbd440b 100644
--- a/src/avx512-64bit-common.h
+++ b/src/avx512-64bit-common.h
@@ -45,12 +45,22 @@ struct ymm_vector {
     {
         return _mm256_set1_ps(type_max());
     }
-
     static zmmi_t
     seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static reg_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm256_set_ps(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -86,10 +96,16 @@ struct ymm_vector {
     {
         return _mm512_mask_i64gather_ps(src, mask, index, base, scale);
     }
-    template <int scale>
-    static reg_t i64gather(__m512i index, void const *base)
+    static reg_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_ps(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static reg_t loadu(void const *mem)
     {
@@ -195,6 +211,17 @@ struct ymm_vector {
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static reg_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -221,10 +248,16 @@ struct ymm_vector {
     {
         return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
     }
-    template <int scale>
-    static reg_t i64gather(__m512i index, void const *base)
+    static reg_t i64gather(type_t *arr, int64_t *ind)
    {
-        return _mm512_i64gather_epi32(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static reg_t loadu(void const *mem)
     {
@@ -324,6 +357,17 @@ struct ymm_vector {
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static reg_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -350,10 +394,16 @@ struct ymm_vector {
     {
         return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
     }
-    template <int scale>
-    static reg_t i64gather(__m512i index, void const *base)
+    static reg_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_epi32(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static reg_t loadu(void const *mem)
     {
@@ -456,6 +506,17 @@ struct zmm_vector {
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static reg_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -482,10 +543,16 @@ struct zmm_vector {
     {
         return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
     }
-    template <int scale>
-    static reg_t i64gather(__m512i index, void const *base)
+    static reg_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_epi64(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static reg_t loadu(void const *mem)
     {
@@ -589,16 +656,33 @@ struct zmm_vector {
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static reg_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     template <int scale>
     static reg_t
     mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base)
     {
         return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
     }
-    template <int scale>
-    static reg_t i64gather(__m512i index, void const *base)
+    static reg_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_epi64(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static opmask_t knot_opmask(opmask_t x)
     {
@@ -704,13 +788,22 @@
     {
         return _mm512_set1_pd(type_max());
     }
-
     static zmmi_t
     seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
-
+    static reg_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static reg_t maskz_loadu(opmask_t mask, void const *mem)
     {
         return _mm512_maskz_loadu_pd(mask, mem);
@@ -742,10 +835,16 @@
     {
         return _mm512_mask_i64gather_pd(src, mask, index, base, scale);
     }
-    template <int scale>
-    static reg_t i64gather(__m512i index, void const *base)
+    static reg_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_pd(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static reg_t loadu(void const *mem)
     {
@@ -841,7 +940,6 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm)
 template <typename vtype, typename reg_t = typename vtype::reg_t>
 X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t zmm)
 {
-
     // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
     zmm = cmp_merge<vtype>(
             zmm,
diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h
index 015a6bd7..5c73aede 100644
--- a/src/avx512-common-argsort.h
+++ b/src/avx512-common-argsort.h
@@ -75,7 +75,7 @@ static inline int64_t partition_avx512(type_t *arr,
 
     if (right - left == vtype::numlanes) {
         argzmm_t argvec = argtype::loadu(arg + left);
-        reg_t vec = vtype::template i64gather<sizeof(type_t)>(argvec, arr);
+        reg_t vec = vtype::i64gather(arr, arg + left);
         int32_t amount_gt_pivot = partition_vec<vtype>(arg,
                                                        left,
                                                        left + vtype::numlanes,
@@ -91,11 +91,9 @@
 
     // first and last vtype::numlanes values are partitioned at the end
     argzmm_t argvec_left = argtype::loadu(arg + left);
-    reg_t vec_left
-            = vtype::template i64gather<sizeof(type_t)>(argvec_left, arr);
+    reg_t vec_left = vtype::i64gather(arr, arg + left);
     argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes));
-    reg_t vec_right
-            = vtype::template i64gather<sizeof(type_t)>(argvec_right, arr);
+    reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes));
     // store points of the vectors
     int64_t r_store = right - vtype::numlanes;
     int64_t l_store = left;
@@ -113,11 +111,11 @@
         if ((r_store + vtype::numlanes) - right < left - l_store) {
             right -= vtype::numlanes;
             arg_vec = argtype::loadu(arg + right);
-            curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
+            curr_vec = vtype::i64gather(arr, arg + right);
         }
         else {
             arg_vec = argtype::loadu(arg + left);
-            curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
+            curr_vec = vtype::i64gather(arr, arg + left);
             left += vtype::numlanes;
         }
         // partition the current vector and save it on both sides of the array
@@ -201,12 +199,11 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
 X86_SIMD_SORT_UNROLL_LOOP(8)
     for (int ii = 0; ii < num_unroll; ++ii) {
         argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii);
-        vec_left[ii] = vtype::template i64gather<sizeof(type_t)>(
-                argvec_left[ii], arr);
+        vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii);
         argvec_right[ii] = argtype::loadu(
                 arg + (right - vtype::numlanes * (num_unroll - ii)));
-        vec_right[ii] = vtype::template i64gather<sizeof(type_t)>(
-                argvec_right[ii], arr);
+        vec_right[ii] = vtype::i64gather(
+                arr, arg + (right - vtype::numlanes * (num_unroll - ii)));
     }
     // store points of the vectors
     int64_t r_store = right - vtype::numlanes;
@@ -228,16 +225,16 @@
             for (int ii = 0; ii < num_unroll; ++ii) {
                 arg_vec[ii]
                         = argtype::loadu(arg + right + ii * vtype::numlanes);
-                curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
-                        arg_vec[ii], arr);
+                curr_vec[ii] = vtype::i64gather(
+                        arr, arg + right + ii * vtype::numlanes);
             }
         }
         else {
 X86_SIMD_SORT_UNROLL_LOOP(8)
             for (int ii = 0; ii < num_unroll; ++ii) {
                 arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes);
-                curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
-                        arg_vec[ii], arr);
+                curr_vec[ii] = vtype::i64gather(
+                        arr, arg + left + ii * vtype::numlanes);
             }
             left += num_unroll * vtype::numlanes;
         }
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index 70b63af3..4ebd9475 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -758,28 +758,23 @@ X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
     // median of 16
     int64_t size = (right - left) / 16;
     using zmm_t = typename vtype::reg_t;
-    using ymm_t = typename vtype::halfreg_t;
-    __m512i rand_index1 = _mm512_set_epi64(left + size,
-                                           left + 2 * size,
-                                           left + 3 * size,
-                                           left + 4 * size,
-                                           left + 5 * size,
-                                           left + 6 * size,
-                                           left + 7 * size,
-                                           left + 8 * size);
-    __m512i rand_index2 = _mm512_set_epi64(left + 9 * size,
-                                           left + 10 * size,
-                                           left + 11 * size,
-                                           left + 12 * size,
-                                           left + 13 * size,
-                                           left + 14 * size,
-                                           left + 15 * size,
-                                           left + 16 * size);
-    ymm_t rand_vec1
-            = vtype::template i64gather<sizeof(type_t)>(rand_index1, arr);
-    ymm_t rand_vec2
-            = vtype::template i64gather<sizeof(type_t)>(rand_index2, arr);
-    zmm_t rand_vec = vtype::merge(rand_vec1, rand_vec2);
+    type_t vec_arr[16] = {arr[left + size],
+                          arr[left + 2 * size],
+                          arr[left + 3 * size],
+                          arr[left + 4 * size],
+                          arr[left + 5 * size],
+                          arr[left + 6 * size],
+                          arr[left + 7 * size],
+                          arr[left + 8 * size],
+                          arr[left + 9 * size],
+                          arr[left + 10 * size],
+                          arr[left + 11 * size],
+                          arr[left + 12 * size],
+                          arr[left + 13 * size],
+                          arr[left + 14 * size],
+                          arr[left + 15 * size],
+                          arr[left + 16 * size]};
+    zmm_t rand_vec = vtype::loadu(vec_arr);
     zmm_t sort = vtype::sort_vec(rand_vec);
     // pivot will never be a nan, since there are no nan's!
     return ((type_t *)&sort)[8];
@@ -793,15 +788,14 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
     // median of 8
     int64_t size = (right - left) / 8;
     using zmm_t = typename vtype::reg_t;
-    __m512i rand_index = _mm512_set_epi64(left + size,
-                                          left + 2 * size,
-                                          left + 3 * size,
-                                          left + 4 * size,
-                                          left + 5 * size,
-                                          left + 6 * size,
-                                          left + 7 * size,
-                                          left + 8 * size);
-    zmm_t rand_vec = vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
+    zmm_t rand_vec = vtype::set(arr[left + size],
+                                arr[left + 2 * size],
+                                arr[left + 3 * size],
+                                arr[left + 4 * size],
+                                arr[left + 5 * size],
+                                arr[left + 6 * size],
+                                arr[left + 7 * size],
+                                arr[left + 8 * size]);
    // pivot will never be a nan, since there are no nan's!
     zmm_t sort = vtype::sort_vec(rand_vec);
     return ((type_t *)&sort)[4];
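
Illustrative sketch (not part of the patch): the replacement i64gather above emulates a hardware gather by reading the indices on the scalar side and packing the loaded values with a set-style intrinsic, relying on _mm512_set_epi64 taking its arguments from the highest lane (e7) down to the lowest (e0). The helper name, test data, and printing below are made up for the example; compile with -mavx512f on an AVX-512 capable machine.

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Scalar emulation of a 64-bit gather: read each index on the scalar side
// and pack the loaded values into a ZMM register. Passing arr[ind[7]] first
// and arr[ind[0]] last puts arr[ind[i]] into lane i of the result, matching
// what a hardware gather with those indices would produce.
static inline __m512i emulated_i64gather_epi64(const int64_t *arr,
                                               const int64_t *ind)
{
    return _mm512_set_epi64(arr[ind[7]],
                            arr[ind[6]],
                            arr[ind[5]],
                            arr[ind[4]],
                            arr[ind[3]],
                            arr[ind[2]],
                            arr[ind[1]],
                            arr[ind[0]]);
}

int main()
{
    int64_t data[16];
    for (int64_t i = 0; i < 16; ++i) {
        data[i] = 100 + i;
    }
    int64_t idx[8] = {15, 3, 7, 0, 9, 12, 1, 6};

    __m512i v = emulated_i64gather_epi64(data, idx);

    alignas(64) int64_t out[8];
    _mm512_store_epi64(out, v);
    for (int i = 0; i < 8; ++i) {
        // Each out[i] equals data[idx[i]].
        std::printf("out[%d] = %lld\n", i, (long long)out[i]);
    }
    return 0;
}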