This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

sync
luoyu-intel committed Jul 12, 2024
1 parent 97c8190 commit fb7fc64
Showing 14 changed files with 556 additions and 556 deletions.
11 changes: 6 additions & 5 deletions bestla/bestla/kernel_avx2.h
@@ -21,11 +21,11 @@ namespace bestla {
namespace kernel {
namespace avx2 {
#if CompileAVX2()
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx2", "fma", "f16c")
#elif defined(ICX)
//#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function)
#endif

static inline void zero_reg() { _mm256_zeroupper(); }
@@ -5373,10 +5373,11 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_fp32(int m_size, int

return BTLA_CODE::Success;
}
#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif

#endif
} // namespace avx2
} // namespace kernel
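Note on the guard ordering above (the same reordering is applied in the AVX512 headers below): icx/icpx are clang-based and, like Clang, normally define __GNUC__ as well, so __INTEL_LLVM_COMPILER has to be tested first or the compiler would fall into the GCC branch. A minimal, self-contained sketch of the pattern; the reduction body is purely illustrative and not a BesTLA kernel, and it assumes a GCC-compatible or Intel LLVM compiler on x86-64:

#include <immintrin.h>

#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx2", "fma", "f16c")
#endif

// Toy AVX2 horizontal sum, only here to show a function compiled under the pushed target.
static inline float reduce_add_avx2(const float* src, int n) {
  __m256 acc = _mm256_setzero_ps();
  int i = 0;
  for (; i + 8 <= n; i += 8) acc = _mm256_add_ps(acc, _mm256_loadu_ps(src + i));
  float buf[8];
  _mm256_storeu_ps(buf, acc);
  float sum = 0.f;
  for (int j = 0; j < 8; ++j) sum += buf[j];
  for (; i < n; ++i) sum += src[i];  // scalar tail
  return sum;
}

#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif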
11 changes: 7 additions & 4 deletions bestla/bestla/kernel_avx512_bf16.h
@@ -19,12 +19,13 @@ namespace kernel {
namespace avx512f {
namespace avx512_bf16 {
#if CompileBF16()
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512vl,avx512bw"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512bf16", "avx512vl", "avx512bw")
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx512bf16,avx512vl,avx512bw"))), apply_to = function)
#endif

static inline __m256i zmm_cvt_fp32_bf16(__m512 vfp32) { return (__m256i)_mm512_cvtneps_pbh(vfp32); }

static inline __m512 load_bf16_fp32(const utils::bf16* srcptr) {
@@ -175,7 +176,9 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_bf16(int m_size, int
}
return BTLA_CODE::Success;
}
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif
#endif
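The helpers above round 16 fp32 lanes to bf16 with _mm512_cvtneps_pbh and reinterpret the __m256bh result as __m256i. A standalone round-trip sketch under the same assumptions (GNU-style vector casts, AVX512BF16 hardware, and the matching -mavx512bf16 build flags or a target region like the one above); the widening path here is the generic shift-into-the-high-half trick, not necessarily how the header's own load_bf16_fp32 is implemented:

#include <immintrin.h>
#include <cstdio>

// fp32 -> bf16 for 16 lanes, mirroring zmm_cvt_fp32_bf16 above.
static inline __m256i cvt_fp32_bf16(__m512 v) { return (__m256i)_mm512_cvtneps_pbh(v); }

// bf16 -> fp32: widen each 16-bit value into the upper half of a 32-bit lane.
static inline __m512 cvt_bf16_fp32(__m256i v) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(v), 16));
}

int main() {
  float src[16], dst[16];
  for (int i = 0; i < 16; ++i) src[i] = 1.0f + i * 0.001f;
  __m256i bf16 = cvt_fp32_bf16(_mm512_loadu_ps(src));
  _mm512_storeu_ps(dst, cvt_bf16_fp32(bf16));
  for (int i = 0; i < 16; ++i) std::printf("%2d: %.6f -> %.6f\n", i, src[i], dst[i]);  // shows bf16 rounding
  return 0;
}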
10 changes: 6 additions & 4 deletions bestla/bestla/kernel_avx512_fp16.h
@@ -20,11 +20,11 @@ namespace kernel {
namespace avx512f {
namespace avx512_fp16 {
#if CompileFP16()
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512bw,avx512fp16"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bf16", "avx512vl", "avx512bw", "avx512fp16")
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512bw,avx512fp16"))), apply_to = function)
#endif

inline __m512 zmm_cvt_fp16_fp32(__m256i vfp16) { return _mm512_cvtxph_ps((__m256h)vfp16); }
@@ -465,7 +465,9 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_fp16(int m_size, int
}
return BTLA_CODE::Success;
}
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif
#endif
13 changes: 7 additions & 6 deletions bestla/bestla/kernel_avx512_vnni.h
@@ -18,12 +18,12 @@ namespace bestla {
namespace kernel {
namespace avx512f {
#if CompileAVX512VNNI()
#ifdef __GNUC__
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq", "avx512vnni")
#elif defined(ICX)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bw,avx512vl,avx512dq,avx512vnni"))), \
apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq", "avx512vnni")
#endif

namespace vnni {
@@ -1517,9 +1517,10 @@ static inline BTLA_CODE gemv_7bit_s8s8_fp32(const utils::GemvParamA& A, const ut

} // namespace vnni

#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#else
#endif
#endif
} // namespace avx512f
11 changes: 6 additions & 5 deletions bestla/bestla/kernel_avx512f.h
@@ -26,11 +26,11 @@ namespace bestla {
namespace kernel {
namespace avx512f {
#if CompileAVX512F()
#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,avx512dq"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq")
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,avx512dq"))), apply_to = function)
#endif

inline __m512 zmm_cvt_fp16_fp32(__m256i vfp16) { return _mm512_cvtph_ps(vfp16); }
@@ -6512,9 +6512,10 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_u8(int m_size, int n
}
return BTLA_CODE::Success;
}
#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#else
#endif
#endif
} // namespace avx512f
10 changes: 6 additions & 4 deletions bestla/bestla/kernel_wrapper.h
@@ -1925,6 +1925,7 @@ class ScaleTrackMax {
const int M_offset, const int N_offset, const int M, const int N, float scale,
int causal_offset, float alibi_slope, float tanh_scale, void* tmpcache,
size_t cachesize) {
#if CompileAVX2()
if (alibi_slope == 0 && tanh_scale == 0)
return avx2::scale_track_max_fp32_fp32<false, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M,
N, scale, causal_offset, alibi_slope, tanh_scale, tmpcache,
@@ -1937,14 +1938,15 @@
return avx2::scale_track_max_fp32_fp32<true, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M, N,
scale, causal_offset, alibi_slope, tanh_scale, tmpcache,
cachesize);
else
return BTLA_CODE::NotSupport;
#endif
return BTLA_CODE::NotSupport;
}

static BTLA_CODE forward_avx512(const SType* src, const int src_step, DType* dst, DType* dst_max, int ld_dst,
const int M_offset, const int N_offset, const int M, const int N, float scale,
int causal_offset, float alibi_slope, float tanh_scale, void* tmpcache,
size_t cachesize) {
#if CompileAVX512F()
if (alibi_slope == 0 && tanh_scale == 0)
return avx512f::scale_track_max_fp32_fp32<false, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset,
M, N, scale, causal_offset, alibi_slope, tanh_scale,
@@ -1957,8 +1959,8 @@
return avx512f::scale_track_max_fp32_fp32<true, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M,
N, scale, causal_offset, alibi_slope, tanh_scale, tmpcache,
cachesize);
else
return BTLA_CODE::NotSupport;
#endif
return BTLA_CODE::NotSupport;
}
};
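The wrapper change above moves the BTLA_CODE::NotSupport return out of the ISA branches and wraps the AVX2/AVX512 dispatch in CompileAVX2()/CompileAVX512F(), so a build without that ISA still compiles the function and simply reports NotSupport at run time. A reduced, self-contained sketch of the same shape; the enum, the __AVX2__ test, and the branch bodies are stand-ins for illustration, not the BesTLA API:

#include <cstdio>

enum class Code { Success, NotSupport };  // stand-in for BTLA_CODE

static Code forward_avx2_sketch(float alibi_slope, float tanh_scale) {
#if defined(__AVX2__)  // stand-in for the CompileAVX2() guard
  if (alibi_slope == 0 && tanh_scale == 0) return Code::Success;  // would call the <false, false> kernel
  if (alibi_slope != 0 && tanh_scale == 0) return Code::Success;  // would call the <true, false> kernel
#endif
  (void)alibi_slope;
  (void)tanh_scale;
  // Reached when AVX2 support is compiled out, or when no specialization matched.
  return Code::NotSupport;
}

int main() {
  std::printf("%s\n", forward_avx2_sketch(0.f, 0.f) == Code::Success ? "Success" : "NotSupport");
  return 0;
}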

135 changes: 67 additions & 68 deletions bestla/bestla/sycl/sycl_prologue_b.h
@@ -376,76 +376,76 @@ class WeightS4Trans {
int constexpr TileK = 32;
int constexpr GroupK = SgSize * TileK;
auto ev = q->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<CType, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
for (int i = 0; i < k; i += GroupK * Unroll) {
cgh.parallel_for(
sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<CType, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
CType tmpAcc = 0.f;
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
CType tmpAcc = 0.f;
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc += CType(aptr[sg_id * TileK + ikk]) *
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) *
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
});
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc +=
CType(aptr[sg_id * TileK + ikk]) * static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc +=
CType(aptr[sg_id * TileK + ikk + 1]) * static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
});
});
return ev;
} else {
@@ -458,8 +458,7 @@
auto ev = q->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
[=](sycl::nd_item<1> it) [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
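The two kernels shown above keep [[intel::kernel_args_restrict]] and [[intel::reqd_sub_group_size(SgSize)]] but no longer request a fixed work-group size, and each sub-group still reduces its per-lane partial dot products (sum += sg.shuffle(tmpAcc, i)) before lane 0 stores the result. A reduced, runnable sketch of that reduction idiom; it assumes a oneAPI DPC++ toolchain, the sizes and names are illustrative, and it uses the standard sycl::reduce_over_group in place of the explicit shuffle loop:

#include <sycl/sycl.hpp>
#include <cstdio>
#include <vector>

int main() {
  constexpr int SgSize = 16;
  constexpr int N = 1024;  // kept a multiple of SgSize for simplicity
  std::vector<float> a(N, 1.f), b(N, 2.f);
  float out = 0.f;
  sycl::queue q;
  {
    sycl::buffer<float, 1> ba(a.data(), sycl::range<1>(N));
    sycl::buffer<float, 1> bb(b.data(), sycl::range<1>(N));
    sycl::buffer<float, 1> bo(&out, sycl::range<1>(1));
    q.submit([&](sycl::handler& cgh) {
      sycl::accessor A{ba, cgh, sycl::read_only};
      sycl::accessor B{bb, cgh, sycl::read_only};
      sycl::accessor O{bo, cgh, sycl::write_only};
      cgh.parallel_for(
          sycl::nd_range<1>(sycl::range<1>(SgSize), sycl::range<1>(SgSize)),
          [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(SgSize)]] {
            auto sg = it.get_sub_group();
            int lane = sg.get_local_id()[0];
            float acc = 0.f;
            // Each lane strides over the input, like the TileK striding above.
            for (int i = lane; i < N; i += SgSize) acc += A[i] * B[i];
            // Sub-group reduction; the kernel above instead sums sg.shuffle(tmpAcc, i)
            // over all lanes, which yields the same total.
            float sum = sycl::reduce_over_group(sg, acc, sycl::plus<float>());
            if (lane == 0) O[0] = sum;
          });
    });
  }  // buffer destruction waits for the kernel and copies the result back to out
  std::printf("dot = %.1f (expected %.1f)\n", out, 2.f * N);
  return 0;
}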
