diff --git a/neural_speed/core/layers/mha_dense.cpp b/neural_speed/core/layers/mha_dense.cpp index c32807166..586fd7dfd 100644 --- a/neural_speed/core/layers/mha_dense.cpp +++ b/neural_speed/core/layers/mha_dense.cpp @@ -43,14 +43,21 @@ constexpr bool MHA_PREFER_AVX512FP16 = true; #ifdef __GNUC__ -#pragma GCC push_options -#pragma GCC target("avx512f", "avx512bw", "avx512vl") +#define ADD_TARGET(T, ...) __VA_ARGS__, T +#define TARGETS_512_0() "avx512f", "avx512bw", "avx512vl" #if CompileBF16() -#pragma GCC target("avx512bf16") +#define TARGETS_512_1() ADD_TARGET("avx512bf16", TARGETS_512_0()) +#else +#define TARGETS_512_1() TARGETS_512_0() #endif #if CompileFP16() -#pragma GCC target("avx512fp16") +#define TARGETS_512_2() ADD_TARGET("avx512fp16", TARGETS_512_1()) +#else +#define TARGETS_512_2() TARGETS_512_1() #endif +#define TARGET_512 __attribute__((target(TARGETS_512_2()))) +#else +#define TARGET_512 #endif using namespace bestla; // NOLINT @@ -88,8 +95,8 @@ struct mha_problem_t { template inline X_T poly_scale_2nd_ps(const Z_T z, const X_T f, const X_T c0, const X_T c1, const X_T c2); template <> -inline __m512 poly_scale_2nd_ps<__m512, __m512>(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, - const __m512 c2) { +TARGET_512 inline __m512 poly_scale_2nd_ps<__m512, __m512>(const __m512 z, const __m512 f, const __m512 c0, + const __m512 c1, const __m512 c2) { const auto y = _mm512_fmadd_ps(_mm512_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2; const auto exp = _mm512_scalef_ps(y, z); return exp; @@ -105,14 +112,14 @@ inline __m256 poly_scale_2nd_ps<__m256, __m256i>(const __m256i z, const __m256 f const auto y_not_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_not_exp); const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23)); - return _mm256_castsi256_ps(_mm256_or_epi32(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp))); + return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp))); } template inline X_T exp_ps_0_1(const X_T x); template <> -inline __m512 exp_ps_0_1<__m512>(const __m512 x) { +TARGET_512 inline __m512 exp_ps_0_1<__m512>(const __m512 x) { static const auto c0 = _mm512_set1_ps(0.240226507f); static const auto c1 = _mm512_set1_ps(0.452920674f); static const auto c2 = _mm512_set1_ps(0.713483036f); @@ -158,13 +165,13 @@ inline float mha_exp_ref(float x) { } #ifdef NOT_CURRENTLY_USED -inline __m512 exp_2nd_ph(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, const __m512 c2) { +TARGET_512 inline __m512 exp_2nd_ph(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, const __m512 c2) { const auto y = _mm512_fmadd_ph(_mm512_fmadd_ph(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2; const auto exp = _mm512_scalef_ph(y, z); return exp; } -inline __m512 exp_ph_0_1(const __m512 x) { +TARGET_512 inline __m512 exp_ph_0_1(const __m512 x) { static const auto c0 = _mm512_castsi512_ph(_mm512_set1_epi16(fp16(0.240226507f).x)); static const auto c1 = _mm512_castsi512_ph(_mm512_set1_epi16(fp16(0.452920674f).x)); static const auto c2 = _mm512_castsi512_ph(_mm512_set1_epi16(fp16(0.713483036f).x)); @@ -270,8 +277,9 @@ class scale_exp_acc_sum_fp32_t { float alibi_slope; // m-factor in the alibi paper for current head: https://arxiv.org/abs/2108.12409 }; - BTLA_CODE forward(const float* src, const int src_step, const int M_offset, const int N_offset, const int M, - const int N, const Param& p, void* /* tmpcache */, size_t /* cachesize */) const { + TARGET_512 BTLA_CODE forward(const float* src, const int src_step, const int M_offset, const int N_offset, + const int M, const int N, const Param& p, void* /* tmpcache */, + size_t /* cachesize */) const { assert(("alibi not supported!", p.alibi_slope == 0.f)); const auto dst = p.dst + M_offset * p.ld_dst + N_offset; const auto dst_sum = p.dst_sum + M_offset; @@ -970,8 +978,9 @@ class scale_track_max_t { float alibi_slope; // m-factor in the alibi paper for current head: https://arxiv.org/abs/2108.12409 }; - BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset, const int M, - const int N, const Param& p, void* /* tmpcache */, size_t /* cachesize */) const { + TARGET_512 BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset, + const int M, const int N, const Param& p, void* /* tmpcache */, + size_t /* cachesize */) const { assert(("alibi not supported!", p.alibi_slope == 0.f)); const auto dst = p.dst + M_offset * p.ld_dst + N_offset; const auto dst_max = p.dst_max + M_offset; @@ -1047,6 +1056,47 @@ class scale_track_max_t { : forward_(src, src_step, M_offset, N_offset, M, N, p); } +#if CompileAVX512F() + template + TARGET_512 BTLA_CODE forward_512(const SType* src, const int src_step, const int M_offset, const int N_offset, + const int M, const int N, const Param& p) const { + const auto dst = p.dst + M_offset * p.ld_dst + N_offset; + const auto dst_max = p.dst_max + M_offset; + const auto v_scale = _mm512_set1_ps(p.scale); + const auto v_seq15 = _mm512_loadu_ps(seq15); + const auto alibi_slope = _mm512_set1_ps(p.alibi_slope); + const auto alibi_base = _mm512_mul_ps(alibi_slope, _mm512_add_ps(v_seq15, _mm512_set1_ps(N_offset))); + const auto alibi_step = _mm512_set1_ps(p.alibi_slope * 16); + + for (int i = 0; i < M; ++i) { + auto alibi_curr = alibi_base; + const auto N_unmasked = + std::min(N, p.causal_offset < 0 ? INT32_MAX : i + M_offset - N_offset + p.causal_offset + 1); + + const auto v_mask = _cvtu32_mask16((1U << (N_unmasked % 16)) - 1); + int j = 0; + auto v_max = _mm512_set1_ps(-INFINITY); + for (; j < N_unmasked - 15; j += 16) { + const auto xs = _mm512_fmadd_ps(v_scale, _mm512_loadu_ps(src + i * src_step + j), alibi_curr); + v_max = _mm512_max_ps(v_max, xs); + _mm512_storeu_ps(dst + i * p.ld_dst + j, xs); + if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step); + } + if (j < N_unmasked) { + const auto xs = _mm512_fmadd_ps(v_scale, _mm512_maskz_loadu_ps(v_mask, src + i * src_step + j), alibi_curr); + v_max = _mm512_mask_max_ps(v_max, v_mask, v_max, xs); + _mm512_storeu_ps(dst + i * p.ld_dst + j, xs); + if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step); + j += 16; + } + dst_max[i] = std::max(dst_max[i], _mm512_reduce_max_ps(v_max)); + + // if (j < utils::padto(N, 64)) + // memset(dst + i * p.ld_dst + j, 0, sizeof(*dst) * (utils::padto(N, 64) - j)); + } + return BTLA_CODE::Success; + } +#endif template BTLA_CODE forward_(const SType* src, const int src_step, const int M_offset, const int N_offset, const int M, const int N, const Param& p) const { @@ -1055,39 +1105,7 @@ class scale_track_max_t { #if MHA_2ND_EXP #if CompileAVX512F() if constexpr (ISA_T >= BTLA_ISA::AVX512F) { - const auto v_scale = _mm512_set1_ps(p.scale); - const auto v_seq15 = _mm512_loadu_ps(seq15); - const auto alibi_slope = _mm512_set1_ps(p.alibi_slope); - const auto alibi_base = _mm512_mul_ps(alibi_slope, _mm512_add_ps(v_seq15, _mm512_set1_ps(N_offset))); - const auto alibi_step = _mm512_set1_ps(p.alibi_slope * 16); - - for (int i = 0; i < M; ++i) { - auto alibi_curr = alibi_base; - const auto N_unmasked = - std::min(N, p.causal_offset < 0 ? INT32_MAX : i + M_offset - N_offset + p.causal_offset + 1); - - const auto v_mask = _cvtu32_mask16((1U << (N_unmasked % 16)) - 1); - int j = 0; - auto v_max = _mm512_set1_ps(-INFINITY); - for (; j < N_unmasked - 15; j += 16) { - const auto xs = _mm512_fmadd_ps(v_scale, _mm512_loadu_ps(src + i * src_step + j), alibi_curr); - v_max = _mm512_max_ps(v_max, xs); - _mm512_storeu_ps(dst + i * p.ld_dst + j, xs); - if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step); - } - if (j < N_unmasked) { - const auto xs = _mm512_fmadd_ps(v_scale, _mm512_maskz_loadu_ps(v_mask, src + i * src_step + j), alibi_curr); - v_max = _mm512_mask_max_ps(v_max, v_mask, v_max, xs); - _mm512_storeu_ps(dst + i * p.ld_dst + j, xs); - if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step); - j += 16; - } - dst_max[i] = std::max(dst_max[i], _mm512_reduce_max_ps(v_max)); - - // if (j < utils::padto(N, 64)) - // memset(dst + i * p.ld_dst + j, 0, sizeof(*dst) * (utils::padto(N, 64) - j)); - } - return BTLA_CODE::Success; + return forward_512(src, src_step, M_offset, N_offset, M, N, p); } #endif #if CompileAVX2() @@ -1163,8 +1181,9 @@ class scale_track_max_t { float alibi_slope; // m-factor in the alibi paper for current head: https://arxiv.org/abs/2108.12409 }; - BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset, const int M, - const int N, const Param& p, void* /* tmpcache */, size_t /* cachesize */) const { + TARGET_512 BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset, + const int M, const int N, const Param& p, void* /* tmpcache */, + size_t /* cachesize */) const { assert(("alibi not supported!", p.alibi_slope == 0.f)); const auto dst = p.dst + M_offset * p.ld_dst + N_offset; const auto dst_max = p.dst_max + M_offset; @@ -1267,8 +1286,8 @@ class weight_cvt_bf16_ntile48_t { bool is_padded; }; weight_cvt_bf16_ntile48_t() = default; - BTLA_CODE getWeight(BType** dst_ptr, int* dst_step, const Param& p, int k_size, int n_size, int k_offset, - int n_offset, void* /* tmpcache */, size_t /* cachesize */) { + TARGET_512 BTLA_CODE getWeight(BType** dst_ptr, int* dst_step, const Param& p, int k_size, int n_size, int k_offset, + int n_offset, void* /* tmpcache */, size_t /* cachesize */) { assert(p.is_padded); const auto src = const_cast(p.B) + k_offset * 48 + n_offset * p.ldb; const auto dst = *dst_ptr; @@ -1360,8 +1379,8 @@ struct inplace_precompute_max_softmax_t { #if CompileFP16() template struct inplace_precompute_max_softmax_t { - static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, fp16* dst, const float* s_max, - float* expsum, int ld_src, int ld_dst) { + TARGET_512 static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, fp16* dst, + const float* s_max, float* expsum, int ld_src, int ld_dst) { for (int ii = 0; ii < m_size; ++ii) { const auto i_src = src + ii * ld_src; const auto i_dst = dst + ii * ld_dst; @@ -1410,8 +1429,8 @@ struct inplace_precompute_max_softmax_t { #if CompileBF16() template struct inplace_precompute_max_softmax_t { - static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, bf16* dst, const float* s_max, - float* expsum, int ld_src, int ld_dst) { + TARGET_512 static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, bf16* dst, + const float* s_max, float* expsum, int ld_src, int ld_dst) { for (int ii = 0; ii < m_size; ++ii) { const auto i_src = src + ii * ld_src; const auto i_dst = dst + ii * ld_dst; @@ -1469,7 +1488,7 @@ struct inplace_precompute_max_softmax_t { #if CompileAVX512F() template struct inplace_precompute_max_softmax_t= BTLA_ISA::AVX512F, float>, float, ISA_T> { - static void forward( // NOLINT [build/include_what_you_use] + TARGET_512 static void forward( // NOLINT [build/include_what_you_use] int m_size, int n_size, int n_pad_size, bool is_causal, float* src, float* dst, const float* s_max, float* expsum, int ld_src, int ld_dst) { for (int ii = 0; ii < m_size; ++ii) { @@ -1571,8 +1590,8 @@ struct inplace_precompute_max_softmax_t struct inplace_precompute_max_softmax_t { - static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, uint8_t* dst, float* s_max, - float* expsum, int ld_src, int ld_dst) { + TARGET_512 static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, uint8_t* dst, + float* s_max, float* expsum, int ld_src, int ld_dst) { for (int ii = 0; ii < m_size; ++ii) { const auto i_src = src + ii * ld_src; const auto i_dst = dst + ii * ld_dst; @@ -2276,7 +2295,7 @@ bool bestla_reordered_attn_fp32_support(const attn_shape_t* params) { GetCPUDevice(); #if CompileBF16() // TODO(Yi): check K V's layout - return _cd->AMX_BF16(); + if (_cd->AMX_BF16()) return true; #endif return _cd->AVX512F() || _cd->AVX2(); // use avx2 and f16c on avx2 platforms } @@ -2432,7 +2451,7 @@ void bestla_reordered_attn_fp32_update_k_24x1(const bestla_fusion_attn_fp32_upda const auto cache_step_bs = p.heads_kv * cache_step_head_num; const int n_para = p.batch_size * p.heads_kv; -#pragma omp parallel + // #pragma omp parallel for (int i_para = 0; i_para < n_para; ++i_para) { const int ibs = i_para / p.heads_kv; const int ihn = i_para % p.heads_kv; @@ -2702,9 +2721,9 @@ void bestla_fusion_attn_fp32_batch_cpy_v(const bestla_fusion_attn_fp32_batch_cpy : bestla_fusion_attn_fp32_batch_cpy_v_(params); } -#ifdef __GNUC__ -#pragma GCC pop_options -#endif +// #ifdef __GNUC__ +// #pragma GCC pop_options +// #endif #ifdef NS_TESTS #define CheckISA(ISA) \