This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit 5f20ca8: TARGET_512

DDEle committed Feb 22, 2024
1 parent 227c806 commit 5f20ca8
Showing 1 changed file with 82 additions and 63 deletions.
145 changes: 82 additions & 63 deletions neural_speed/core/layers/mha_dense.cpp
@@ -43,14 +43,21 @@
constexpr bool MHA_PREFER_AVX512FP16 = true;

#ifdef __GNUC__
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bw", "avx512vl")
#define ADD_TARGET(T, ...) __VA_ARGS__, T
#define TARGETS_512_0() "avx512f", "avx512bw", "avx512vl"
#if CompileBF16()
#pragma GCC target("avx512bf16")
#define TARGETS_512_1() ADD_TARGET("avx512bf16", TARGETS_512_0())
#else
#define TARGETS_512_1() TARGETS_512_0()
#endif
#if CompileFP16()
#pragma GCC target("avx512fp16")
#define TARGETS_512_2() ADD_TARGET("avx512fp16", TARGETS_512_1())
#else
#define TARGETS_512_2() TARGETS_512_1()
#endif
#define TARGET_512 __attribute__((target(TARGETS_512_2())))
#else
#define TARGET_512
#endif

using namespace bestla; // NOLINT
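
For reference, a hand-expanded sketch of what the new macros are intended to produce under GCC when both CompileBF16() and CompileFP16() evaluate to 1 (written out here by hand, not compiler output):

  TARGETS_512_0()  ->  "avx512f", "avx512bw", "avx512vl"
  TARGETS_512_1()  ->  "avx512f", "avx512bw", "avx512vl", "avx512bf16"
  TARGETS_512_2()  ->  "avx512f", "avx512bw", "avx512vl", "avx512bf16", "avx512fp16"
  TARGET_512       ->  __attribute__((target("avx512f", "avx512bw", "avx512vl", "avx512bf16", "avx512fp16")))

With this, each function annotated with TARGET_512 gets the AVX-512 feature set enabled individually, replacing the previous file-wide #pragma GCC push_options / target / pop_options scheme.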
@@ -88,8 +95,8 @@ struct mha_problem_t {
template <typename X_T, typename Z_T = X_T>
inline X_T poly_scale_2nd_ps(const Z_T z, const X_T f, const X_T c0, const X_T c1, const X_T c2);
template <>
inline __m512 poly_scale_2nd_ps<__m512, __m512>(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1,
const __m512 c2) {
TARGET_512 inline __m512 poly_scale_2nd_ps<__m512, __m512>(const __m512 z, const __m512 f, const __m512 c0,
const __m512 c1, const __m512 c2) {
const auto y = _mm512_fmadd_ps(_mm512_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2;
const auto exp = _mm512_scalef_ps(y, z);
return exp;
@@ -105,14 +112,14 @@ inline __m256 poly_scale_2nd_ps<__m256, __m256i>(const __m256i z, const __m256 f
const auto y_not_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_not_exp);

const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23));
return _mm256_castsi256_ps(_mm256_or_epi32(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp)));
return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp)));
}

template <typename X_T>
inline X_T exp_ps_0_1(const X_T x);

template <>
inline __m512 exp_ps_0_1<__m512>(const __m512 x) {
TARGET_512 inline __m512 exp_ps_0_1<__m512>(const __m512 x) {
static const auto c0 = _mm512_set1_ps(0.240226507f);
static const auto c1 = _mm512_set1_ps(0.452920674f);
static const auto c2 = _mm512_set1_ps(0.713483036f);
@@ -158,13 +165,13 @@ inline float mha_exp_ref(float x) {
}

#ifdef NOT_CURRENTLY_USED
inline __m512 exp_2nd_ph(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, const __m512 c2) {
TARGET_512 inline __m512 exp_2nd_ph(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, const __m512 c2) {
const auto y = _mm512_fmadd_ph(_mm512_fmadd_ph(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2;
const auto exp = _mm512_scalef_ph(y, z);
return exp;
}

inline __m512 exp_ph_0_1(const __m512 x) {
TARGET_512 inline __m512 exp_ph_0_1(const __m512 x) {
static const auto c0 = _mm512_castsi512_ph(_mm512_set1_epi16(fp16(0.240226507f).x));
static const auto c1 = _mm512_castsi512_ph(_mm512_set1_epi16(fp16(0.452920674f).x));
static const auto c2 = _mm512_castsi512_ph(_mm512_set1_epi16(fp16(0.713483036f).x));
@@ -270,8 +277,9 @@ class scale_exp_acc_sum_fp32_t {
float alibi_slope; // m-factor in the alibi paper for current head: https://arxiv.org/abs/2108.12409
};

BTLA_CODE forward(const float* src, const int src_step, const int M_offset, const int N_offset, const int M,
const int N, const Param& p, void* /* tmpcache */, size_t /* cachesize */) const {
TARGET_512 BTLA_CODE forward(const float* src, const int src_step, const int M_offset, const int N_offset,
const int M, const int N, const Param& p, void* /* tmpcache */,
size_t /* cachesize */) const {
assert(("alibi not supported!", p.alibi_slope == 0.f));
const auto dst = p.dst + M_offset * p.ld_dst + N_offset;
const auto dst_sum = p.dst_sum + M_offset;
@@ -970,8 +978,9 @@ class scale_track_max_t<ISA_T, fp16, float> {
float alibi_slope; // m-factor in the alibi paper for current head: https://arxiv.org/abs/2108.12409
};

BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset, const int M,
const int N, const Param& p, void* /* tmpcache */, size_t /* cachesize */) const {
TARGET_512 BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset,
const int M, const int N, const Param& p, void* /* tmpcache */,
size_t /* cachesize */) const {
assert(("alibi not supported!", p.alibi_slope == 0.f));
const auto dst = p.dst + M_offset * p.ld_dst + N_offset;
const auto dst_max = p.dst_max + M_offset;
@@ -1047,6 +1056,47 @@ class scale_track_max_t<ISA_T, float, float> {
: forward_<true>(src, src_step, M_offset, N_offset, M, N, p);
}

#if CompileAVX512F()
template <bool HAS_ALIBI>
TARGET_512 BTLA_CODE forward_512(const SType* src, const int src_step, const int M_offset, const int N_offset,
const int M, const int N, const Param& p) const {
const auto dst = p.dst + M_offset * p.ld_dst + N_offset;
const auto dst_max = p.dst_max + M_offset;
const auto v_scale = _mm512_set1_ps(p.scale);
const auto v_seq15 = _mm512_loadu_ps(seq15);
const auto alibi_slope = _mm512_set1_ps(p.alibi_slope);
const auto alibi_base = _mm512_mul_ps(alibi_slope, _mm512_add_ps(v_seq15, _mm512_set1_ps(N_offset)));
const auto alibi_step = _mm512_set1_ps(p.alibi_slope * 16);

for (int i = 0; i < M; ++i) {
auto alibi_curr = alibi_base;
const auto N_unmasked =
std::min(N, p.causal_offset < 0 ? INT32_MAX : i + M_offset - N_offset + p.causal_offset + 1);

const auto v_mask = _cvtu32_mask16((1U << (N_unmasked % 16)) - 1);
int j = 0;
auto v_max = _mm512_set1_ps(-INFINITY);
for (; j < N_unmasked - 15; j += 16) {
const auto xs = _mm512_fmadd_ps(v_scale, _mm512_loadu_ps(src + i * src_step + j), alibi_curr);
v_max = _mm512_max_ps(v_max, xs);
_mm512_storeu_ps(dst + i * p.ld_dst + j, xs);
if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step);
}
if (j < N_unmasked) {
const auto xs = _mm512_fmadd_ps(v_scale, _mm512_maskz_loadu_ps(v_mask, src + i * src_step + j), alibi_curr);
v_max = _mm512_mask_max_ps(v_max, v_mask, v_max, xs);
_mm512_storeu_ps(dst + i * p.ld_dst + j, xs);
if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step);
j += 16;
}
dst_max[i] = std::max(dst_max[i], _mm512_reduce_max_ps(v_max));

// if (j < utils::padto(N, 64))
// memset(dst + i * p.ld_dst + j, 0, sizeof(*dst) * (utils::padto(N, 64) - j));
}
return BTLA_CODE::Success;
}
#endif
template <bool HAS_ALIBI>
BTLA_CODE forward_(const SType* src, const int src_step, const int M_offset, const int N_offset, const int M,
const int N, const Param& p) const {
@@ -1055,39 +1105,7 @@ class scale_track_max_t<ISA_T, float, float> {
#if MHA_2ND_EXP
#if CompileAVX512F()
if constexpr (ISA_T >= BTLA_ISA::AVX512F) {
const auto v_scale = _mm512_set1_ps(p.scale);
const auto v_seq15 = _mm512_loadu_ps(seq15);
const auto alibi_slope = _mm512_set1_ps(p.alibi_slope);
const auto alibi_base = _mm512_mul_ps(alibi_slope, _mm512_add_ps(v_seq15, _mm512_set1_ps(N_offset)));
const auto alibi_step = _mm512_set1_ps(p.alibi_slope * 16);

for (int i = 0; i < M; ++i) {
auto alibi_curr = alibi_base;
const auto N_unmasked =
std::min(N, p.causal_offset < 0 ? INT32_MAX : i + M_offset - N_offset + p.causal_offset + 1);

const auto v_mask = _cvtu32_mask16((1U << (N_unmasked % 16)) - 1);
int j = 0;
auto v_max = _mm512_set1_ps(-INFINITY);
for (; j < N_unmasked - 15; j += 16) {
const auto xs = _mm512_fmadd_ps(v_scale, _mm512_loadu_ps(src + i * src_step + j), alibi_curr);
v_max = _mm512_max_ps(v_max, xs);
_mm512_storeu_ps(dst + i * p.ld_dst + j, xs);
if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step);
}
if (j < N_unmasked) {
const auto xs = _mm512_fmadd_ps(v_scale, _mm512_maskz_loadu_ps(v_mask, src + i * src_step + j), alibi_curr);
v_max = _mm512_mask_max_ps(v_max, v_mask, v_max, xs);
_mm512_storeu_ps(dst + i * p.ld_dst + j, xs);
if constexpr (HAS_ALIBI) alibi_curr = _mm512_add_ps(alibi_curr, alibi_step);
j += 16;
}
dst_max[i] = std::max(dst_max[i], _mm512_reduce_max_ps(v_max));

// if (j < utils::padto(N, 64))
// memset(dst + i * p.ld_dst + j, 0, sizeof(*dst) * (utils::padto(N, 64) - j));
}
return BTLA_CODE::Success;
return forward_512<HAS_ALIBI>(src, src_step, M_offset, N_offset, M, N, p);
}
#endif
#if CompileAVX2()
@@ -1163,8 +1181,9 @@ class scale_track_max_t<ISA_T, int32_t, float> {
float alibi_slope; // m-factor in the alibi paper for current head: https://arxiv.org/abs/2108.12409
};

BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset, const int M,
const int N, const Param& p, void* /* tmpcache */, size_t /* cachesize */) const {
TARGET_512 BTLA_CODE forward(const SType* src, const int src_step, const int M_offset, const int N_offset,
const int M, const int N, const Param& p, void* /* tmpcache */,
size_t /* cachesize */) const {
assert(("alibi not supported!", p.alibi_slope == 0.f));
const auto dst = p.dst + M_offset * p.ld_dst + N_offset;
const auto dst_max = p.dst_max + M_offset;
@@ -1267,8 +1286,8 @@ class weight_cvt_bf16_ntile48_t {
bool is_padded;
};
weight_cvt_bf16_ntile48_t() = default;
BTLA_CODE getWeight(BType** dst_ptr, int* dst_step, const Param& p, int k_size, int n_size, int k_offset,
int n_offset, void* /* tmpcache */, size_t /* cachesize */) {
TARGET_512 BTLA_CODE getWeight(BType** dst_ptr, int* dst_step, const Param& p, int k_size, int n_size, int k_offset,
int n_offset, void* /* tmpcache */, size_t /* cachesize */) {
assert(p.is_padded);
const auto src = const_cast<SType*>(p.B) + k_offset * 48 + n_offset * p.ldb;
const auto dst = *dst_ptr;
@@ -1360,8 +1379,8 @@ struct inplace_precompute_max_softmax_t {
#if CompileFP16()
template <BTLA_ISA ISA_T>
struct inplace_precompute_max_softmax_t<float, fp16, ISA_T> {
static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, fp16* dst, const float* s_max,
float* expsum, int ld_src, int ld_dst) {
TARGET_512 static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, fp16* dst,
const float* s_max, float* expsum, int ld_src, int ld_dst) {
for (int ii = 0; ii < m_size; ++ii) {
const auto i_src = src + ii * ld_src;
const auto i_dst = dst + ii * ld_dst;
@@ -1410,8 +1429,8 @@ struct inplace_precompute_max_softmax_t<float, fp16, ISA_T> {
#if CompileBF16()
template <BTLA_ISA ISA_T>
struct inplace_precompute_max_softmax_t<float, bf16, ISA_T> {
static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, bf16* dst, const float* s_max,
float* expsum, int ld_src, int ld_dst) {
TARGET_512 static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, bf16* dst,
const float* s_max, float* expsum, int ld_src, int ld_dst) {
for (int ii = 0; ii < m_size; ++ii) {
const auto i_src = src + ii * ld_src;
const auto i_dst = dst + ii * ld_dst;
@@ -1469,7 +1488,7 @@ struct inplace_precompute_max_softmax_t<float, bf16, ISA_T> {
#if CompileAVX512F()
template <BTLA_ISA ISA_T>
struct inplace_precompute_max_softmax_t<std::enable_if_t<ISA_T >= BTLA_ISA::AVX512F, float>, float, ISA_T> {
static void forward( // NOLINT [build/include_what_you_use]
TARGET_512 static void forward( // NOLINT [build/include_what_you_use]
int m_size, int n_size, int n_pad_size, bool is_causal, float* src, float* dst, const float* s_max, float* expsum,
int ld_src, int ld_dst) {
for (int ii = 0; ii < m_size; ++ii) {
@@ -1571,8 +1590,8 @@ struct inplace_precompute_max_softmax_t<std::enable_if_t<(ISA_T < BTLA_ISA::AVX5
#endif
template <BTLA_ISA ISA_T>
struct inplace_precompute_max_softmax_t<float, uint8_t, ISA_T> {
static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, uint8_t* dst, float* s_max,
float* expsum, int ld_src, int ld_dst) {
TARGET_512 static void forward(int m_size, int n_size, int n_pad_size, bool is_causal, float* src, uint8_t* dst,
float* s_max, float* expsum, int ld_src, int ld_dst) {
for (int ii = 0; ii < m_size; ++ii) {
const auto i_src = src + ii * ld_src;
const auto i_dst = dst + ii * ld_dst;
@@ -2276,7 +2295,7 @@ bool bestla_reordered_attn_fp32_support(const attn_shape_t* params) {
GetCPUDevice();
#if CompileBF16()
// TODO(Yi): check K V's layout
return _cd->AMX_BF16();
if (_cd->AMX_BF16()) return true;
#endif
return _cd->AVX512F() || _cd->AVX2(); // use avx2 and f16c on avx2 platforms
}
@@ -2432,7 +2451,7 @@ void bestla_reordered_attn_fp32_update_k_24x1(const bestla_fusion_attn_fp32_upda
const auto cache_step_bs = p.heads_kv * cache_step_head_num;

const int n_para = p.batch_size * p.heads_kv;
#pragma omp parallel
// #pragma omp parallel
for (int i_para = 0; i_para < n_para; ++i_para) {
const int ibs = i_para / p.heads_kv;
const int ihn = i_para % p.heads_kv;
@@ -2702,9 +2721,9 @@ void bestla_fusion_attn_fp32_batch_cpy_v(const bestla_fusion_attn_fp32_batch_cpy
: bestla_fusion_attn_fp32_batch_cpy_v_<true>(params);
}

#ifdef __GNUC__
#pragma GCC pop_options
#endif
// #ifdef __GNUC__
// #pragma GCC pop_options
// #endif
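
As a minimal standalone illustration of the per-function targeting pattern used above (a hypothetical file, not part of neural_speed; names are made up), something along these lines builds with GCC or Clang even when the translation unit itself is compiled without -mavx512f:

// sketch.cpp - hypothetical example of __attribute__((target(...)))
#include <immintrin.h>

#if defined(__GNUC__)
#define TARGET_512 __attribute__((target("avx512f", "avx512bw", "avx512vl")))
#else
#define TARGET_512
#endif

// Only this function is compiled with AVX-512 enabled; the rest of the
// translation unit keeps its baseline ISA, so no pragma push/pop is needed.
TARGET_512 float sum16(const float* p) {
  return _mm512_reduce_add_ps(_mm512_loadu_ps(p));
}

The caller remains responsible for checking CPU support at runtime (as bestla_reordered_attn_fp32_support does via GetCPUDevice()) before invoking such a function on a given machine.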

#ifdef NS_TESTS
#define CheckISA(ISA) \