diff --git a/bestla/bestla/bestla_device.h b/bestla/bestla/bestla_device.h index bcb720c01..ca2660c3c 100644 --- a/bestla/bestla/bestla_device.h +++ b/bestla/bestla/bestla_device.h @@ -458,7 +458,7 @@ class CpuRuntime { return instances[thread]; } - inline float getPE() const { return 1.0f * P_core_num / E_core_num; } + inline float getPE() const { return 1.0f * P_core_num / E_core_num; } inline void setPE(float& PE_) { PE = PE_; } diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index d040fcfce..eb51e50bb 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -777,7 +777,7 @@ class SchedulerDispatcher { Ecore_num = cr.E_core_num; utils::GemmProblem problem_P = problem, problem_E = problem; const int N = problem.dims[2]; - const int N_offset = utils::padto(N - int(N / (1 + cr.getPE())),Scheduler::mStep[1]); + const int N_offset = utils::padto(N - int(N / (1 + cr.getPE())), Scheduler::mStep[1]); problem_P.dims[2] = N_offset; Scheduler_P = new Scheduler({th->num_threads() - cr.E_core_num, problem_P, {0, 0}, cr.mL2Cache_P, cr.mL1Cache_P}); problem_E.dims[2] = N - N_offset; diff --git a/neural_speed/core/layers/mha_dense_wrapper.h b/neural_speed/core/layers/mha_dense_wrapper.h index be586a305..97e3d7377 100644 --- a/neural_speed/core/layers/mha_dense_wrapper.h +++ b/neural_speed/core/layers/mha_dense_wrapper.h @@ -644,7 +644,7 @@ class mha_interface_t { static_assert(GemmQK::MTILE == GemmPV::MTILE, "2 GEMM should have the same M_TILE."); - BTLA_CODE compute(const attn_fwd_args_t& p, const parallel::IThreading& th) { + BTLA_CODE compute(const attn_fwd_args_t& p, parallel::IThreading& th) { static constexpr auto M_TILE = GemmQK::MTILE; assert(p.Q_sc == 1 && p.K_sc == 1 && p.V_sc == 1 && p.dst_sc == 1); assert(p.Q_layout == ATTN_FWD_LAYOUT_PLAIN && p.K_layout == ATTN_FWD_LAYOUT_PLAIN && @@ -652,7 +652,7 @@ class mha_interface_t { assert(p.step_v_head_size == 1); assert(p.step_k_head_size == 1 || p.step_k_sl == 1); const auto num_heads = p.batch_size * p.head_num; // Total number of heads - device::CpuBase cb; // Note: DO NOT use cb.mNumThreads; use th.num_threads() instead + GetCPUDevice(); const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0; const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0; @@ -698,7 +698,7 @@ class mha_interface_t { const mha_problem_t problem = {p.batch_size, p.head_num, p.heads_kv, p.head_size, p.sl_q, p.sl_kv}; const auto m_tiles = updiv(p.sl_q, M_TILE); const auto num_tasks = num_heads * m_tiles; - const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}}); + const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}, {0, 0}}); th.parallel_for([&](int tid) { { // reorder K & V @@ -760,8 +760,8 @@ class mha_interface_t { typename parallel::gemm::ThreadProblemBase tpQK{ /* ThreadProblem2D */ {tid, {}, {i_m, 0}, {m_size, unmasked_size_pad_qk}, true}, /* .block = */ {M_TILE, GemmQK::NTILE, p.head_size}, - /* .stacksize = */ cb.mL2Cache, - /* .tmpcachesize = */ cb.mL2Cache, + /* .stacksize = */ _cd->getL2CacheSize(), + /* .tmpcachesize = */ _cd->getL2CacheSize(), }; const auto bf16_tmp = reinterpret_cast(tmp); l_expsum.run( // QxK => S ==exp==> P @@ -791,8 +791,8 @@ class mha_interface_t { typename parallel::gemm::ThreadProblemBase tpPV{ /* ThreadProblem2D */ {tid, {}, {0, 0}, {m_size, p.head_size}, true}, /* .block = */ {M_TILE, GemmPV::NTILE, unmasked_size_pad_qk}, - /* .stacksize = */ cb.mL2Cache, - /* .tmpcachesize = */ cb.mL2Cache, + /* .stacksize = */ _cd->getL2CacheSize(), + /* .tmpcachesize = */ _cd->getL2CacheSize(), }; l_scale.run( // PxV => O PVArgs{ @@ -1574,7 +1574,7 @@ class mha_stable_interface_t { static_assert(GemmQK::MTILE == GemmPV::MTILE, "2 GEMM should have the same M_TILE."); static constexpr auto M_TILE = GemmQK::MTILE; - BTLA_CODE compute(const attn_fwd_args_t& p, const parallel::IThreading& th) { + BTLA_CODE compute(const attn_fwd_args_t& p, parallel::IThreading& th) { assert((std::is_same::value || p.Q_sc == 1)); assert((std::is_same::value || p.K_sc == 1)); assert((std::is_same::value || p.V_sc == 1)); @@ -1603,7 +1603,7 @@ class mha_stable_interface_t { assert((p.K_layout != ATTN_FWD_LAYOUT_PLAIN || p.step_v_head_size == 1)); assert((p.V_layout != ATTN_FWD_LAYOUT_PLAIN || p.step_k_sl == 1)); const auto num_heads = p.batch_size * p.head_num; // Total number of heads - device::CpuBase cb; // Note: DO NOT use cb.mNumThreads; use th.num_threads() instead + GetCPUDevice(); const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0; const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0; const bool prefer_fp32 = (p.attn_flags & NE_ATTN_FLAG_PREFER_FP32) != 0; @@ -1637,7 +1637,7 @@ class mha_stable_interface_t { const auto num_tasks = num_heads * m_tiles; using Scheduler2D = bestla::parallel::Scheduler2D; - const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}}); // main parallel scheduler + const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}, {0, 0}}); // main parallel scheduler th.parallel_for([&](int tid) { const int tmp_s_size = M_TILE * padto(padto(p.sl_kv, GemmQK::NTILE), GemmPV::KTILE); @@ -1694,8 +1694,8 @@ class mha_stable_interface_t { typename parallel::gemm::ThreadProblemBase tpQK{ /* ThreadProblem2D */ {tid, {}, {i_m, 0}, {m_size, unmasked_size_pad_qk}, true}, /* .block = */ {M_TILE, GemmQK::NTILE, p.head_size}, - /* .stacksize = */ cb.mL2Cache, - /* .tmpcachesize = */ cb.mL2Cache, + /* .stacksize = */ _cd->getL2CacheSize(), + /* .tmpcachesize = */ _cd->getL2CacheSize(), }; l_qk.run( // QxK => S ==exp==> P QKArgs{ @@ -1749,8 +1749,8 @@ class mha_stable_interface_t { typename parallel::gemm::ThreadProblemBase tpPV{ /* ThreadProblem2D */ {tid, {}, {0, 0}, {m_size, p.head_size}, true}, /* .block = */ {M_TILE, GemmPV::NTILE, unmasked_size_pad_pv}, - /* .stacksize = */ cb.mL2Cache, - /* .tmpcachesize = */ cb.mL2Cache, + /* .stacksize = */ _cd->getL2CacheSize(), + /* .tmpcachesize = */ _cd->getL2CacheSize(), }; l_pv.run( // PxV => O PVArgs{