Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
update for MHA
Browse files Browse the repository at this point in the history
  • Loading branch information
yuchengliu1 committed Feb 29, 2024
1 parent add466e commit 7173011
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ class CpuRuntime {
return instances[thread];
}

// Returns the ratio of P_core_num to E_core_num, evaluated as
// 1.0f * P_core_num / E_core_num on every call. The PE member written by
// setPE() is not read here.
// NOTE(review): if E_core_num can be 0 this divides by zero — confirm callers
// only rely on the result when E cores are present.
// The line appears twice because this diff view lists both the pre-change and
// post-change rendering of the same statement (1 addition & 1 deletion).
inline float getPE() const { return 1.0f * P_core_num / E_core_num; }
inline float getPE() const { return 1.0f * P_core_num / E_core_num; }

// Stores PE_ into the PE member. The parameter is a non-const lvalue
// reference, so callers must pass a named float variable; a literal or
// temporary will not bind to it.
// NOTE(review): getPE() as shown in this view computes its result from the
// core counts and does not read PE — confirm where the stored value is used.
inline void setPE(float& PE_) { PE = PE_; }

Expand Down
2 changes: 1 addition & 1 deletion bestla/bestla/bestla_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ class SchedulerDispatcher {
Ecore_num = cr.E_core_num;
utils::GemmProblem problem_P = problem, problem_E = problem;
const int N = problem.dims[2];
const int N_offset = utils::padto(N - int(N / (1 + cr.getPE())),Scheduler::mStep[1]);
const int N_offset = utils::padto(N - int(N / (1 + cr.getPE())), Scheduler::mStep[1]);
problem_P.dims[2] = N_offset;
Scheduler_P = new Scheduler({th->num_threads() - cr.E_core_num, problem_P, {0, 0}, cr.mL2Cache_P, cr.mL1Cache_P});
problem_E.dims[2] = N - N_offset;
Expand Down
28 changes: 14 additions & 14 deletions neural_speed/core/layers/mha_dense_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -644,15 +644,15 @@ class mha_interface_t {

static_assert(GemmQK::MTILE == GemmPV::MTILE, "2 GEMM should have the same M_TILE.");

BTLA_CODE compute(const attn_fwd_args_t<Q_T, K_T, V_T, DST_T>& p, const parallel::IThreading& th) {
BTLA_CODE compute(const attn_fwd_args_t<Q_T, K_T, V_T, DST_T>& p, parallel::IThreading& th) {
static constexpr auto M_TILE = GemmQK::MTILE;
assert(p.Q_sc == 1 && p.K_sc == 1 && p.V_sc == 1 && p.dst_sc == 1);
assert(p.Q_layout == ATTN_FWD_LAYOUT_PLAIN && p.K_layout == ATTN_FWD_LAYOUT_PLAIN &&
p.V_layout == ATTN_FWD_LAYOUT_PLAIN && p.dst_layout == ATTN_FWD_LAYOUT_PLAIN);
assert(p.step_v_head_size == 1);
assert(p.step_k_head_size == 1 || p.step_k_sl == 1);
const auto num_heads = p.batch_size * p.head_num; // Total number of heads
device::CpuBase cb; // Note: DO NOT use cb.mNumThreads; use th.num_threads() instead
GetCPUDevice();

const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0;
const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0;
Expand Down Expand Up @@ -698,7 +698,7 @@ class mha_interface_t {
const mha_problem_t problem = {p.batch_size, p.head_num, p.heads_kv, p.head_size, p.sl_q, p.sl_kv};
const auto m_tiles = updiv(p.sl_q, M_TILE);
const auto num_tasks = num_heads * m_tiles;
const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}});
const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}, {0, 0}});

th.parallel_for([&](int tid) {
{ // reorder K & V
Expand Down Expand Up @@ -760,8 +760,8 @@ class mha_interface_t {
typename parallel::gemm::ThreadProblemBase tpQK{
/* ThreadProblem2D */ {tid, {}, {i_m, 0}, {m_size, unmasked_size_pad_qk}, true},
/* .block = */ {M_TILE, GemmQK::NTILE, p.head_size},
/* .stacksize = */ cb.mL2Cache,
/* .tmpcachesize = */ cb.mL2Cache,
/* .stacksize = */ _cd->getL2CacheSize(),
/* .tmpcachesize = */ _cd->getL2CacheSize(),
};
const auto bf16_tmp = reinterpret_cast<bf16*>(tmp);
l_expsum.run( // QxK => S ==exp==> P
Expand Down Expand Up @@ -791,8 +791,8 @@ class mha_interface_t {
typename parallel::gemm::ThreadProblemBase tpPV{
/* ThreadProblem2D */ {tid, {}, {0, 0}, {m_size, p.head_size}, true},
/* .block = */ {M_TILE, GemmPV::NTILE, unmasked_size_pad_qk},
/* .stacksize = */ cb.mL2Cache,
/* .tmpcachesize = */ cb.mL2Cache,
/* .stacksize = */ _cd->getL2CacheSize(),
/* .tmpcachesize = */ _cd->getL2CacheSize(),
};
l_scale.run( // PxV => O
PVArgs{
Expand Down Expand Up @@ -1574,7 +1574,7 @@ class mha_stable_interface_t {
static_assert(GemmQK::MTILE == GemmPV::MTILE, "2 GEMM should have the same M_TILE.");
static constexpr auto M_TILE = GemmQK::MTILE;

BTLA_CODE compute(const attn_fwd_args_t<Q_T, K_T, V_T, DST_T>& p, const parallel::IThreading& th) {
BTLA_CODE compute(const attn_fwd_args_t<Q_T, K_T, V_T, DST_T>& p, parallel::IThreading& th) {
assert((std::is_same<Q_T, int8_t>::value || p.Q_sc == 1));
assert((std::is_same<K_T, int8_t>::value || p.K_sc == 1));
assert((std::is_same<V_T, int8_t>::value || p.V_sc == 1));
Expand Down Expand Up @@ -1603,7 +1603,7 @@ class mha_stable_interface_t {
assert((p.K_layout != ATTN_FWD_LAYOUT_PLAIN || p.step_v_head_size == 1));
assert((p.V_layout != ATTN_FWD_LAYOUT_PLAIN || p.step_k_sl == 1));
const auto num_heads = p.batch_size * p.head_num; // Total number of heads
device::CpuBase cb; // Note: DO NOT use cb.mNumThreads; use th.num_threads() instead
GetCPUDevice();
const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0;
const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0;
const bool prefer_fp32 = (p.attn_flags & NE_ATTN_FLAG_PREFER_FP32) != 0;
Expand Down Expand Up @@ -1637,7 +1637,7 @@ class mha_stable_interface_t {
const auto num_tasks = num_heads * m_tiles;

using Scheduler2D = bestla::parallel::Scheduler2D;
const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}}); // main parallel scheduler
const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}, {0, 0}}); // main parallel scheduler

th.parallel_for([&](int tid) {
const int tmp_s_size = M_TILE * padto(padto(p.sl_kv, GemmQK::NTILE), GemmPV::KTILE);
Expand Down Expand Up @@ -1694,8 +1694,8 @@ class mha_stable_interface_t {
typename parallel::gemm::ThreadProblemBase tpQK{
/* ThreadProblem2D */ {tid, {}, {i_m, 0}, {m_size, unmasked_size_pad_qk}, true},
/* .block = */ {M_TILE, GemmQK::NTILE, p.head_size},
/* .stacksize = */ cb.mL2Cache,
/* .tmpcachesize = */ cb.mL2Cache,
/* .stacksize = */ _cd->getL2CacheSize(),
/* .tmpcachesize = */ _cd->getL2CacheSize(),
};
l_qk.run( // QxK => S ==exp==> P
QKArgs{
Expand Down Expand Up @@ -1749,8 +1749,8 @@ class mha_stable_interface_t {
typename parallel::gemm::ThreadProblemBase tpPV{
/* ThreadProblem2D */ {tid, {}, {0, 0}, {m_size, p.head_size}, true},
/* .block = */ {M_TILE, GemmPV::NTILE, unmasked_size_pad_pv},
/* .stacksize = */ cb.mL2Cache,
/* .tmpcachesize = */ cb.mL2Cache,
/* .stacksize = */ _cd->getL2CacheSize(),
/* .tmpcachesize = */ _cd->getL2CacheSize(),
};
l_pv.run( // PxV => O
PVArgs{
Expand Down

0 comments on commit 7173011

Please sign in to comment.