From 851af6a68c9a29bbcdfcba891d432ff5115483bf Mon Sep 17 00:00:00 2001 From: "Liu, Yucheng" Date: Wed, 10 Jan 2024 17:12:44 +0800 Subject: [PATCH 01/45] update --- bestla/bestla/bestla_parallel.h | 43 +++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index c780aa4cf..f6d69e1d7 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -587,8 +587,8 @@ using thread_func = std::function; class IThreading { public: explicit IThreading(int nthreads) : mThreadNum(nthreads) {} - virtual void parallel_for(const thread_func& func) const = 0; - virtual inline void sync() const { assert(0); }; + virtual void parallel_for(const thread_func& func) = 0; + virtual inline void sync() const = 0; virtual int num_threads() const { return mThreadNum; }; virtual void set_threads(int nthreads) = 0; @@ -599,8 +599,7 @@ class IThreading { class OMPThreading : public IThreading { public: explicit OMPThreading(int nthreads) : IThreading(nthreads) { omp_set_num_threads(nthreads); } - void parallel_for(const thread_func& func) const override { - if (mThreadNum > 1) { + void parallel_for(const thread_func& func) override { #pragma omp parallel { int tidx = omp_get_thread_num(); @@ -623,16 +622,37 @@ class OMPThreading : public IThreading { class StdThreading : public IThreading { public: - explicit StdThreading(int nthreads) : IThreading(nthreads) {} - void parallel_for(const thread_func& func) const override { + explicit StdThreading(int nthreads) : IThreading(nthreads) { + thdset.resize(mThreadNum - 1); + locks.resize(mThreadNum - 1); + + for (size_t i = 0; i < mThreadNum - 1; i++) { + locks[i] = false; + thdset[i] = std::thread( + [&](int tidx) { + while (true) { + if (locks[tidx]) { + (*func_)(tidx + 1); + locks[tidx] = false; + } + } + }, + int(i)); + } + } + void parallel_for(const thread_func& func) override { if (mThreadNum > 1) { - std::vector thdset(mThreadNum - 1); + func_ = &func; for (size_t i = 0; i < mThreadNum - 1; i++) { - thdset[i] = std::thread([&](int tidx) { func(tidx); }, int(i + 1)); + locks[i] = true; } func(0); - for (size_t i = 0; i < mThreadNum - 1; i++) { - thdset[i].join(); + while (1) { + bool is_join = true; + for (size_t i = 0; is_join && i < mThreadNum - 1; i++) { + is_join &= locks[i]; + } + if (is_join) break; } } else { func(0); @@ -644,6 +664,9 @@ class StdThreading : public IThreading { inline void sync() const override { assert(0); } private: + std::vector thdset; + std::vector locks; + const thread_func* func_; }; class SingleThread : public StdThreading { From c2390c53025ae0bd6736c15548c72c08a9097248 Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Wed, 10 Jan 2024 23:01:12 +0800 Subject: [PATCH 02/45] update --- bestla/bestla/bestla_parallel.h | 70 +++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index f6d69e1d7..4b2cf87cc 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
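Patch 01 above swaps the old spawn-and-join implementation of StdThreading::parallel_for for persistent workers that spin on per-thread flags: the caller publishes the work function, flips every flag, runs chunk 0 itself, then waits for the flags to clear. The sketch below is a minimal, self-contained version of that dispatch scheme. It is not BesTLA code: the class name SpinPool is made up, and it uses std::atomic for the shared flags, whereas the patch at this stage still uses plain (non-atomic) vector entries and only hardens the synchronization in later commits of this series.

```cpp
#include <atomic>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

// Persistent workers that spin on a per-thread "work ready" flag.
// The caller acts as worker 0, then busy-waits until every helper
// has cleared its flag, mirroring the parallel_for in the patch above.
class SpinPool {
 public:
  explicit SpinPool(int nthreads) : flags_(nthreads > 1 ? nthreads - 1 : 0) {
    for (size_t i = 0; i < flags_.size(); ++i) {
      flags_[i].store(false);
      workers_.emplace_back([this, i] {
        while (!stop_.load(std::memory_order_acquire)) {
          if (flags_[i].load(std::memory_order_acquire)) {
            (*func_)(static_cast<int>(i) + 1);  // helper i runs chunk i+1
            flags_[i].store(false, std::memory_order_release);
          }
        }
      });
    }
  }
  ~SpinPool() {
    stop_.store(true, std::memory_order_release);
    for (auto& t : workers_) t.join();
  }
  void parallel_for(const std::function<void(int)>& func) {
    func_ = &func;
    for (auto& f : flags_) f.store(true, std::memory_order_release);  // dispatch
    func(0);                                                          // caller is worker 0
    for (auto& f : flags_)                                            // wait for helpers
      while (f.load(std::memory_order_acquire)) { /* spin */ }
  }

 private:
  std::vector<std::atomic<bool>> flags_;
  std::vector<std::thread> workers_;
  std::atomic<bool> stop_{false};
  const std::function<void(int)>* func_ = nullptr;
};

int main() {
  SpinPool pool(4);
  pool.parallel_for([](int tidx) { std::printf("hello from thread %d\n", tidx); });
}
```

The trade-off is the usual one for spin dispatch: near-zero wake-up latency for the GEMM workers at the cost of burning cycles while idle, which is what the _mm_pause() and stop logic added in the next two patches address.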
#pragma once +#include #include #include #include @@ -622,24 +623,7 @@ class OMPThreading : public IThreading { class StdThreading : public IThreading { public: - explicit StdThreading(int nthreads) : IThreading(nthreads) { - thdset.resize(mThreadNum - 1); - locks.resize(mThreadNum - 1); - - for (size_t i = 0; i < mThreadNum - 1; i++) { - locks[i] = false; - thdset[i] = std::thread( - [&](int tidx) { - while (true) { - if (locks[tidx]) { - (*func_)(tidx + 1); - locks[tidx] = false; - } - } - }, - int(i)); - } - } + explicit StdThreading(int nthreads) : IThreading(nthreads) { create_threads(); } void parallel_for(const thread_func& func) override { if (mThreadNum > 1) { func_ = &func; @@ -647,26 +631,60 @@ class StdThreading : public IThreading { locks[i] = true; } func(0); - while (1) { - bool is_join = true; - for (size_t i = 0; is_join && i < mThreadNum - 1; i++) { - is_join &= locks[i]; + while (true) { + bool is_lock = false; + for (size_t i = 0; is_lock && i < mThreadNum - 1; i++) { + is_lock |= locks[i]; } - if (is_join) break; + if (!is_lock) break; } } else { func(0); } } - void set_threads(int nthreads) override { mThreadNum = nthreads; } + void set_threads(int nthreads) override { + stop_threads(); + mThreadNum = nthreads; + create_threads(); + } inline void sync() const override { assert(0); } + ~StdThreading() { stop_threads(); } + private: + void stop_threads() { + for (int i = 0; i < mThreadNum - 1; i++) stop[i] = true; + for (int i = 0; i < mThreadNum - 1; i++) thdset[i].join(); + } + void create_threads() { + printf("1111111\n"); + thdset.clear(); + thdset.resize(mThreadNum - 1); + locks.resize(mThreadNum - 1); + stop.resize(mThreadNum - 1); + + for (size_t i = 0; i < mThreadNum - 1; i++) { + stop[i] = false; + locks[i] = false; + thdset[i] = std::thread( + [&](int tidx) { + while (!stop[tidx]) { + _mm_pause(); + if (locks[tidx]) { + (*func_)(tidx + 1); + locks[tidx] = false; + } + } + }, + int(i)); + } + } + std::vector thdset; - std::vector locks; - const thread_func* func_; + std::vector locks, stop; + const thread_func* func_ = nullptr; }; class SingleThread : public StdThreading { From ba4f1074c406bb6cbfc64ea47ec97e7c349a7b74 Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Mon, 15 Jan 2024 14:16:10 +0800 Subject: [PATCH 03/45] update for spinlock --- bestla/bestla/bestla_parallel.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index 4b2cf87cc..f6463b2db 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -633,7 +633,7 @@ class StdThreading : public IThreading { func(0); while (true) { bool is_lock = false; - for (size_t i = 0; is_lock && i < mThreadNum - 1; i++) { + for (size_t i = 0; !is_lock && i < mThreadNum - 1; i++) { is_lock |= locks[i]; } if (!is_lock) break; @@ -659,7 +659,6 @@ class StdThreading : public IThreading { for (int i = 0; i < mThreadNum - 1; i++) thdset[i].join(); } void create_threads() { - printf("1111111\n"); thdset.clear(); thdset.resize(mThreadNum - 1); locks.resize(mThreadNum - 1); @@ -671,10 +670,11 @@ class StdThreading : public IThreading { thdset[i] = std::thread( [&](int tidx) { while (!stop[tidx]) { - _mm_pause(); if (locks[tidx]) { (*func_)(tidx + 1); locks[tidx] = false; + } else { + _mm_pause(); } } }, From 8abef2f9ee7291f49ce91648c38096f41b5e2e24 Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Sun, 4 Feb 2024 20:37:35 +0800 Subject: [PATCH 04/45] update base thread pool --- bestla/CMakeLists.txt | 4 +-- 
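Before patch 04's larger rework, patches 02 and 03 refine the spin loop itself: workers get an explicit stop flag, and the _mm_pause() hint (the x86 PAUSE instruction, exposed through immintrin.h) ends up in the branch taken when no work is pending rather than on every iteration. A generic spin-wait helper in that style, shown only as an illustration of the idiom rather than as BesTLA code:

```cpp
#include <atomic>
#include <immintrin.h>  // _mm_pause on x86; use std::this_thread::yield() elsewhere

// Spin until `ready` becomes true, issuing the PAUSE hint only while idle,
// as patch 03 does by moving _mm_pause() into the "no work yet" branch.
inline void spin_wait(const std::atomic<bool>& ready) {
  while (!ready.load(std::memory_order_acquire)) {
    _mm_pause();
  }
}
```

PAUSE tells the core it is inside a spin-wait, which saves power and frees execution resources for a sibling hyperthread; keeping it out of the path that finally observes the work avoids adding latency to the dispatch itself.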
bestla/bestla/bestla_parallel.h | 47 ++++++++++++++------------ bestla/bestla/ut/bestla_epilogue.cpp | 2 ++ bestla/bestla/ut/bestla_gemm.cpp | 2 ++ bestla/bestla/ut/bestla_parallel.cpp | 3 ++ bestla/bestla/ut/bestla_prologue_a.cpp | 2 ++ bestla/bestla/ut/bestla_prologue_b.cpp | 2 ++ bestla/bestla/ut/bestla_ut.cpp | 10 ++++++ bestla/bestla/ut/bestla_ut.h | 6 ++-- bestla/bestla/ut/bestla_wrapper.cpp | 22 ++++++------ 10 files changed, 63 insertions(+), 37 deletions(-) diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 2b17a8603..59c5954ec 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -18,8 +18,8 @@ option(BTLA_UT_KERNEL_JIT "Enable unit test for jit kernels" OFF) option(BTLA_UT_KERNEL_INTRIN "Enable unit test for intrinsic kernels" OFF) option(BTLA_UT_KERNEL_WRAPPER "Enable unit test for runtime ISA kernels" OFF) option(BTLA_UT_NOASAN "Disable sanitize" OFF) -option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF) -option(BTLA_UT_OPENMP "Use OpenMP" ON) +option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" ON) +option(BTLA_UT_OPENMP "Use OpenMP" OFF) add_library(${PROJECT_NAME} INTERFACE) add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index f6463b2db..abe78a461 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -626,17 +626,16 @@ class StdThreading : public IThreading { explicit StdThreading(int nthreads) : IThreading(nthreads) { create_threads(); } void parallel_for(const thread_func& func) override { if (mThreadNum > 1) { - func_ = &func; + running.store(mThreadNum - 1); for (size_t i = 0; i < mThreadNum - 1; i++) { - locks[i] = true; + func_[i] = &func; } func(0); while (true) { - bool is_lock = false; - for (size_t i = 0; !is_lock && i < mThreadNum - 1; i++) { - is_lock |= locks[i]; - } - if (!is_lock) break; + if (running.load() == 0) + break; + else + _mm_pause(); } } else { func(0); @@ -644,9 +643,11 @@ class StdThreading : public IThreading { } void set_threads(int nthreads) override { - stop_threads(); - mThreadNum = nthreads; - create_threads(); + if (nthreads != mThreadNum) { + stop_threads(); + mThreadNum = nthreads; + create_threads(); + } } inline void sync() const override { assert(0); } @@ -655,24 +656,25 @@ class StdThreading : public IThreading { private: void stop_threads() { - for (int i = 0; i < mThreadNum - 1; i++) stop[i] = true; + stop = true; for (int i = 0; i < mThreadNum - 1; i++) thdset[i].join(); + thdset.clear(); + // printf("stop %d\n", mThreadNum); } void create_threads() { - thdset.clear(); + // printf("create %d\n", mThreadNum); thdset.resize(mThreadNum - 1); - locks.resize(mThreadNum - 1); - stop.resize(mThreadNum - 1); + stop = false; for (size_t i = 0; i < mThreadNum - 1; i++) { - stop[i] = false; - locks[i] = false; thdset[i] = std::thread( [&](int tidx) { - while (!stop[tidx]) { - if (locks[tidx]) { - (*func_)(tidx + 1); - locks[tidx] = false; + while (true) { + if (stop.load() == true) break; + if (func_[tidx] != nullptr) { + (*func_[tidx])(tidx + 1); + func_[tidx] = nullptr; + running.fetch_sub(1); } else { _mm_pause(); } @@ -683,8 +685,9 @@ class StdThreading : public IThreading { } std::vector thdset; - std::vector locks, stop; - const thread_func* func_ = nullptr; + std::atomic_bool stop; + std::atomic_int running; + const thread_func* func_[100]; }; class SingleThread : public StdThreading { diff --git 
a/bestla/bestla/ut/bestla_epilogue.cpp b/bestla/bestla/ut/bestla_epilogue.cpp index 3d83035c3..d1d293b26 100644 --- a/bestla/bestla/ut/bestla_epilogue.cpp +++ b/bestla/bestla/ut/bestla_epilogue.cpp @@ -1,6 +1,7 @@ #include "bestla_epilogue.h" #include "bestla_ut.h" +#ifdef BTLA_UT_EPILOGUE namespace bestla { using namespace utils; namespace ut { @@ -139,3 +140,4 @@ static UT_AlphaBetaProcessFp32 sUT_AlphaBetaProcessFp32; #endif } // namespace ut } // namespace bestla +#endif diff --git a/bestla/bestla/ut/bestla_gemm.cpp b/bestla/bestla/ut/bestla_gemm.cpp index 43d86eab0..5e00a8911 100644 --- a/bestla/bestla/ut/bestla_gemm.cpp +++ b/bestla/bestla/ut/bestla_gemm.cpp @@ -2,6 +2,7 @@ #include "bestla_utils.h" #include "bestla_ut.h" +#ifdef BTLA_UT_GEMM namespace bestla { using namespace utils; @@ -1115,3 +1116,4 @@ static UT_GEMM_AMXINT8 sUT_GEMM_AMXINT8; #endif } // namespace ut } // namespace bestla +#endif \ No newline at end of file diff --git a/bestla/bestla/ut/bestla_parallel.cpp b/bestla/bestla/ut/bestla_parallel.cpp index 81e4eb899..f3b8b1669 100644 --- a/bestla/bestla/ut/bestla_parallel.cpp +++ b/bestla/bestla/ut/bestla_parallel.cpp @@ -4,6 +4,8 @@ #include "bestla_gemm.h" #include "bestla_ut.h" #include "bestla_prologue_a.h" + +#ifdef BTLA_UT_PARALLEL namespace bestla { using namespace utils; namespace ut { @@ -206,3 +208,4 @@ static UT_SchedulerGemmKBlockNew sUT_SchedulerGemmKBlockNew; #endif } // namespace ut } // namespace bestla +#endif \ No newline at end of file diff --git a/bestla/bestla/ut/bestla_prologue_a.cpp b/bestla/bestla/ut/bestla_prologue_a.cpp index 7cb4b6379..c0ae19c4c 100644 --- a/bestla/bestla/ut/bestla_prologue_a.cpp +++ b/bestla/bestla/ut/bestla_prologue_a.cpp @@ -2,6 +2,7 @@ #include "bestla_ut.h" #include "kernel_avx512f.h" +#ifdef BTLA_UT_PROLOGUE_A namespace bestla { using namespace utils; namespace ut { @@ -292,3 +293,4 @@ static UT_ShuffleActivationKblock sUT_ShuffleActivationKblock; #endif } // namespace ut } // namespace bestla +#endif \ No newline at end of file diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 29e18e35d..34915a44c 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -5,6 +5,7 @@ #include "bestla_wrapper.h" #include "bestla_ut.h" +#ifdef BTLA_UT_PROLOGUE_B namespace bestla { using namespace utils; namespace ut { @@ -1889,3 +1890,4 @@ static UT_CompFp16 sUT_CompFp16; #endif } // namespace ut } // namespace bestla +#endif \ No newline at end of file diff --git a/bestla/bestla/ut/bestla_ut.cpp b/bestla/bestla/ut/bestla_ut.cpp index c3d00a944..d55ac5b56 100644 --- a/bestla/bestla/ut/bestla_ut.cpp +++ b/bestla/bestla/ut/bestla_ut.cpp @@ -1,5 +1,15 @@ #include +#include +namespace bestla { +namespace ut { +#ifdef _OPENMP +parallel::OMPThreading DefaultThreading(4); +#else +parallel::StdThreading DefaultThreading(4); +#endif // _OPNEMP +} // namespace ut +} // namespace bestla int main() { printf("BesTLA UT done\n"); return 0; diff --git a/bestla/bestla/ut/bestla_ut.h b/bestla/bestla/ut/bestla_ut.h index 9a7e3eefd..b570253b1 100644 --- a/bestla/bestla/ut/bestla_ut.h +++ b/bestla/bestla/ut/bestla_ut.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include "bestla_utils.h" @@ -25,9 +27,9 @@ using sAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; using sAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; using sAVX2 = gemm::SCoreRowNAvx2<24, 4>; #ifdef _OPENMP -static parallel::OMPThreading DefaultThreading(4); +extern parallel::OMPThreading 
DefaultThreading; #else -static parallel::StdThreading DefaultThreading(4); +extern parallel::StdThreading DefaultThreading; #endif // _OPNEMP constexpr size_t CacheSize = size_t(100) << 10; diff --git a/bestla/bestla/ut/bestla_wrapper.cpp b/bestla/bestla/ut/bestla_wrapper.cpp index 9e97d04c1..8b9e9d415 100644 --- a/bestla/bestla/ut/bestla_wrapper.cpp +++ b/bestla/bestla/ut/bestla_wrapper.cpp @@ -7,11 +7,11 @@ class UT_Fp32Fp32 { public: UT_Fp32Fp32() { UT_START(); -#ifdef JBLAS_UT_BENCHMARK +#ifdef BTLA_UT_BENCHMARK benchmark_all(1, 4096, 4096, 32); benchmark_all(1024, 4096, 4096, 32); benchmark_all(2048, 4096, 4096, 32); -#endif // JBLAS_UT_BENCHMARK +#endif // BTLA_UT_BENCHMARK CheckISA(AVX2); ut(1, 1, 1); @@ -116,7 +116,7 @@ class UT_Fp32Fp32 { } } }; -#ifdef JBLAS_UT_WRAPPER +#ifdef BTLA_UT_WRAPPER static UT_Fp32Fp32 sUT_Fp32Fp32; #endif @@ -125,7 +125,7 @@ class UT_U8S8S32 { UT_U8S8S32() { UT_START(); GetCPUDevice(); -#ifdef JBLAS_UT_BENCHMARK +#ifdef BTLA_UT_BENCHMARK benchmark_all(1024, 4096, 4096, 32); benchmark_all(2048, 4096, 4096, 32); #endif @@ -269,7 +269,7 @@ class UT_U8S8S32 { } } }; -#ifdef JBLAS_UT_WRAPPER +#ifdef BTLA_UT_WRAPPER static UT_U8S8S32 sUT_U8S8S32; #endif @@ -278,7 +278,7 @@ class UT_S8S8S32 { UT_S8S8S32() { UT_START(); GetCPUDevice(); -#ifdef JBLAS_UT_BENCHMARK +#ifdef BTLA_UT_BENCHMARK benchmark_all(1024, 4096, 4096, 32); benchmark_all(2048, 4096, 4096, 32); #endif @@ -393,7 +393,7 @@ class UT_S8S8S32 { } } }; -#ifdef JBLAS_UT_WRAPPER +#ifdef BTLA_UT_WRAPPER static UT_S8S8S32 sUT_S8S8S32; #endif @@ -403,7 +403,7 @@ class UT_Bf16Bf16Fp32 { UT_START(); CheckISA(AMX_BF16); request_perm_xtile_data(); -#ifdef JBLAS_UT_BENCHMARK +#ifdef BTLA_UT_BENCHMARK benchmark_all(1024, 4096, 4096, 32); benchmark_all(2048, 4096, 4096, 32); #endif @@ -499,7 +499,7 @@ class UT_Bf16Bf16Fp32 { } } }; -#ifdef JBLAS_UT_WRAPPER +#ifdef BTLA_UT_WRAPPER static UT_Bf16Bf16Fp32 sUT_Bf16Bf16Fp32; #endif @@ -508,7 +508,7 @@ class UT_Fp16Fp16Fp16 { UT_Fp16Fp16Fp16() { UT_START(); CheckISA(AVX512_FP16); -#ifdef JBLAS_UT_BENCHMARK +#ifdef BTLA_UT_BENCHMARK benchmark_all(1024, 4096, 4096, 32); benchmark_all(2048, 4096, 4096, 32); #endif @@ -602,7 +602,7 @@ class UT_Fp16Fp16Fp16 { } } }; -#ifdef JBLAS_UT_WRAPPER +#ifdef BTLA_UT_WRAPPER static UT_Fp16Fp16Fp16 sUT_Fp16Fp16Fp16; #endif } // namespace ut From 5cdd61dc612eaac78eea5bc9a68004903787d872 Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Sun, 4 Feb 2024 22:15:47 +0800 Subject: [PATCH 05/45] bond core --- bestla/bestla/bestla_parallel.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index abe78a461..7e104a3ae 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -601,6 +601,7 @@ class OMPThreading : public IThreading { public: explicit OMPThreading(int nthreads) : IThreading(nthreads) { omp_set_num_threads(nthreads); } void parallel_for(const thread_func& func) override { + if (mThreadNum > 1) { #pragma omp parallel { int tidx = omp_get_thread_num(); @@ -665,10 +666,11 @@ class StdThreading : public IThreading { // printf("create %d\n", mThreadNum); thdset.resize(mThreadNum - 1); stop = false; - + bestla::device::CpuDevice::core_bond(0); for (size_t i = 0; i < mThreadNum - 1; i++) { thdset[i] = std::thread( [&](int tidx) { + bestla::device::CpuDevice::core_bond(tidx+1); while (true) { if (stop.load() == true) break; if (func_[tidx] != nullptr) { From b572c67c8190d083f30986ae21dce424cf017786 Mon Sep 17 00:00:00 2001 From: 
yuchengliu1 Date: Tue, 20 Feb 2024 16:32:29 +0800 Subject: [PATCH 06/45] using singleton instead of extern --- bestla/bestla/bestla_parallel.h | 10 +- bestla/bestla/ut/bestla_parallel.cpp | 4 +- bestla/bestla/ut/bestla_prologue_a.cpp | 8 +- bestla/bestla/ut/bestla_prologue_b.cpp | 132 ++++++++++++------------- bestla/bestla/ut/bestla_ut.cpp | 9 -- bestla/bestla/ut/bestla_ut.h | 25 +++-- bestla/bestla/ut/bestla_wrapper.cpp | 50 +++++----- bestla/bestla/ut/kernel_jit.cpp | 6 +- 8 files changed, 125 insertions(+), 119 deletions(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index 7e104a3ae..9b424f5c3 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -599,7 +599,10 @@ class IThreading { #if BTLA_OPENMP class OMPThreading : public IThreading { public: - explicit OMPThreading(int nthreads) : IThreading(nthreads) { omp_set_num_threads(nthreads); } + explicit OMPThreading(int nthreads) : IThreading(nthreads) { + printf("Using OMP\n"); + omp_set_num_threads(nthreads); + } void parallel_for(const thread_func& func) override { if (mThreadNum > 1) { #pragma omp parallel @@ -624,7 +627,10 @@ class OMPThreading : public IThreading { class StdThreading : public IThreading { public: - explicit StdThreading(int nthreads) : IThreading(nthreads) { create_threads(); } + explicit StdThreading(int nthreads) : IThreading(nthreads) { + printf("Using Std\n"); + create_threads(); + } void parallel_for(const thread_func& func) override { if (mThreadNum > 1) { running.store(mThreadNum - 1); diff --git a/bestla/bestla/ut/bestla_parallel.cpp b/bestla/bestla/ut/bestla_parallel.cpp index f3b8b1669..13a863946 100644 --- a/bestla/bestla/ut/bestla_parallel.cpp +++ b/bestla/bestla/ut/bestla_parallel.cpp @@ -27,7 +27,7 @@ class UT_OMPThreading { kernel::wrapper::Transpose2D::template forward(src.data(), ref.data(), row, col, col, row); parallel::Scheduler2D _para({threads, row, col, 1, 1}); - DefaultThreading.parallel_for([&](int tidx) { + UT_Threading::get()->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); if (thdp.valid) { @@ -61,7 +61,7 @@ class UT_StdThreading { kernel::wrapper::Transpose2D::template forward(src.data(), ref.data(), row, col, col, row); parallel::Scheduler2D _para({threads, row, col, 1, 1}); - DefaultThreading.parallel_for([&](int tidx) { + UT_Threading::get()->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); if (thdp.valid) { diff --git a/bestla/bestla/ut/bestla_prologue_a.cpp b/bestla/bestla/ut/bestla_prologue_a.cpp index c0ae19c4c..5f7795479 100644 --- a/bestla/bestla/ut/bestla_prologue_a.cpp +++ b/bestla/bestla/ut/bestla_prologue_a.cpp @@ -131,7 +131,7 @@ class UT_ActivationU8KBlockQuantize { auto quanAct = actA.createStorage(m, k, kblock, hasreduce); avector bufA(quanAct.mSize); quanAct.assign(bufA.data()); - actA.quantize({raw.data(), lda, &quanAct}, m, k, &DefaultThreading); + actA.quantize({raw.data(), lda, &quanAct}, m, k, UT_Threading::get()); ut::buffer_error(q.data(), quanAct.template APtr(), q.size(), uint8_t(1)); ut::buffer_error(zp.data(), quanAct.template ZPtr(), zp.size(), uint8_t(1)); @@ -186,7 +186,7 @@ class UT_ActivationS8KBlockQuantize { auto quanAct = actA.createStorage(m, k, kblock, hasreduce); avector bufA(quanAct.mSize); quanAct.assign(bufA.data()); - actA.quantize({raw.data(), k, &quanAct}, m, k, &DefaultThreading); + actA.quantize({raw.data(), k, &quanAct}, m, k, UT_Threading::get()); ut::buffer_error(q.data(), quanAct.template 
APtr(), q.size(), int8_t(1)); if (hasreduce) { avector redref(reduce.size(), 0.f), redqref(reduce.size(), 0.f); @@ -235,7 +235,7 @@ class UT_ShuffleActivationKblock { auto reordA = kernel.createReorderStorage(m, k, 32); avector bufA(reordA.mSize); reordA.assign(bufA.data()); - kernel.preprocess({src.data(), k, nullptr, indices.data(), &reordA}, m, k, 32, &DefaultThreading); + kernel.preprocess({src.data(), k, nullptr, indices.data(), &reordA}, m, k, 32, UT_Threading::get()); kernel.getActivation(&dstptr, &dststride, {src.data(), k, nullptr, indices.data(), &reordA}, m, kpad, 0, 0, cache, CacheSize); @@ -272,7 +272,7 @@ class UT_ShuffleActivationKblock { avector bufA(quanAct.mSize + reordAct.mSize); quanAct.assign(bufA.data()); reordAct.assign(bufA.data() + quanAct.mSize); - actA.quantize({raw_cp.data(), k, &quanAct, indices.data(), &reordAct}, m, k, &DefaultThreading); + actA.quantize({raw_cp.data(), k, &quanAct, indices.data(), &reordAct}, m, k, UT_Threading::get()); ut::buffer_error(quanAct.template APtr(), q.data(), q.size(), int8_t(1)); if (hasreduce) { avector redref(reduce.size(), 0.f), redqref(reduce.size(), 0.f); diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 34915a44c..9091c6205 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -69,11 +69,11 @@ class UT_BlockQunatize_INT8 { auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S8, bestla_dtype, bestla_dtype, asym); avector buffer(ptr.mSize); ptr.assign(buffer.data()); - kernel.packWeight(n, k, dequanRef.data(), ldb, &ptr, &DefaultThreading); + kernel.packWeight(n, k, dequanRef.data(), ldb, &ptr, UT_Threading::get()); avector dequant(n * k); - kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); + kernel.unpackWeight(n, k, &ptr, dequant.data(), n, UT_Threading::get()); avector ws8(n * k); - kernel.unpackWeight(n, k, &ptr, ws8.data(), n, &DefaultThreading); + kernel.unpackWeight(n, k, &ptr, ws8.data(), n, UT_Threading::get()); ut::buffer_error(quanW.data(), ws8.data(), ws8.size(), (int8_t)1); ut::buffer_error(dequanRef.data(), dequant.data(), dequanRef.size(), 0.01f); } @@ -119,13 +119,13 @@ class UT_BlockQunatize_INT8 { auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S8, bestla_dtype, bestla_dtype, asym); avector buffer(ptr.mSize); ptr.assign(buffer.data()); - kernel.packTransposeWeight(n, k, dequanT.data(), k, &ptr, &DefaultThreading); + kernel.packTransposeWeight(n, k, dequanT.data(), k, &ptr, UT_Threading::get()); avector dequant(n * k), tardequanT(k * n); - kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); - kernel.unpackTransposeWeight(n, k, &ptr, tardequanT.data(), k, &DefaultThreading); + kernel.unpackWeight(n, k, &ptr, dequant.data(), n, UT_Threading::get()); + kernel.unpackTransposeWeight(n, k, &ptr, tardequanT.data(), k, UT_Threading::get()); ut::buffer_error(dequanT.data(), tardequanT.data(), tardequanT.size(), 0.01f); avector ws8(n * k); - kernel.unpackWeight(n, k, &ptr, ws8.data(), n, &DefaultThreading); + kernel.unpackWeight(n, k, &ptr, ws8.data(), n, UT_Threading::get()); ut::buffer_error(quanW.data(), ws8.data(), ws8.size(), (int8_t)1); ut::buffer_error(dequanRef.data(), dequant.data(), dequanRef.size(), 0.01f); } @@ -160,12 +160,12 @@ class UT_BlockQunatize_F8 { avector ref_buffer(ptr.mSize); ptr.assign(buffer.data()); ref_ptr.assign(ref_buffer.data()); - kernel.packWeight(n, k, raw.data(), ldb, &ptr, &DefaultThreading); - ref_ker.packWeight(n, k, 
raw.data(), ldb, &ref_ptr, &DefaultThreading); + kernel.packWeight(n, k, raw.data(), ldb, &ptr, UT_Threading::get()); + ref_ker.packWeight(n, k, raw.data(), ldb, &ref_ptr, UT_Threading::get()); avector dequant(n * k, 0); avector ref_dequant(n * k, 0); - kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); - ref_ker.unpackWeight(n, k, &ref_ptr, ref_dequant.data(), n, &DefaultThreading); + kernel.unpackWeight(n, k, &ptr, dequant.data(), n, UT_Threading::get()); + ref_ker.unpackWeight(n, k, &ref_ptr, ref_dequant.data(), n, UT_Threading::get()); ut::buffer_error(ref_dequant.data(), dequant.data(), dequant.size(), 0.01f); } }; @@ -351,10 +351,10 @@ class UT_TransposeBlockQuantize_F4 { avector buf(packedW.mSize), buf1(packedW1.mSize); packedW.assign(buf.data()); packedW1.assign(buf1.data()); - kernel.packTransposeWeight(n, k, dequanRef.data(), k, &packedW, &DefaultThreading); - kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &packedW1, &DefaultThreading); + kernel.packTransposeWeight(n, k, dequanRef.data(), k, &packedW, UT_Threading::get()); + kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &packedW1, UT_Threading::get()); avector dequant(n * k); - kernel.unpackTransposeWeight(n, k, &packedW1, dequant.data(), k, &DefaultThreading); + kernel.unpackTransposeWeight(n, k, &packedW1, dequant.data(), k, UT_Threading::get()); if (SCA_T != BTLA_DTYPE::DQ8_BNB) { ut::buffer_error(packedW.SPtr(), packedW1.SPtr(), packedW1.CSize()); ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size()); @@ -417,11 +417,11 @@ class UT_BlockQuantize_INT4 { auto packedW = kernel.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, asym); avector buffer(packedW.mSize); packedW.assign(buffer.data()); - kernel.packWeight(n, k, dequant.data(), ldb, &packedW, &DefaultThreading); + kernel.packWeight(n, k, dequant.data(), ldb, &packedW, UT_Threading::get()); avector unpackf32(dequant.size()); avector unpack512f32(dequant.size()); - kernel.unpackWeight(n, k, &packedW, unpackf32.data(), n, &DefaultThreading); - kernel512.unpackWeight(n, k, &packedW, unpack512f32.data(), n, &DefaultThreading); + kernel.unpackWeight(n, k, &packedW, unpackf32.data(), n, UT_Threading::get()); + kernel512.unpackWeight(n, k, &packedW, unpack512f32.data(), n, UT_Threading::get()); ut::buffer_error(unpackf32.data(), unpack512f32.data(), unpackf32.size(), 0.01f); } void ut_512vnni(int n, int k, int blocksize, BTLA_DTYPE qtype, bool asym = false) { @@ -459,9 +459,9 @@ class UT_BlockQuantize_INT4 { auto packedW = kernel.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, asym); avector buffer(packedW.mSize); packedW.assign(buffer.data()); - kernel.packWeight(n, k, dequant.data(), ldb, &packedW, &DefaultThreading); + kernel.packWeight(n, k, dequant.data(), ldb, &packedW, UT_Threading::get()); avector unpackf32(dequant.size()); - kernel.unpackWeight(n, k, &packedW, unpackf32.data(), n, &DefaultThreading); + kernel.unpackWeight(n, k, &packedW, unpackf32.data(), n, UT_Threading::get()); int lsb = 16; float err_thres = lsb * 0.01f; // lsb*max_scale ut::buffer_error(dequant.data(), unpackf32.data(), dequant.size(), err_thres); @@ -549,7 +549,7 @@ class UT_ShuffleIndices { } avector buf0(packedW.mSize), buf1(packedW.mSize); packedW.assign(buf0.data()); - ProWei.setShuffleIndices(groupindices.data(), &packedW, &DefaultThreading); + ProWei.setShuffleIndices(groupindices.data(), &packedW, UT_Threading::get()); buffer_error(reflut.data(), packedW.ShfIndice(), reflut.size()); 
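One detail from the earlier "bond core" patch (05/45) is worth calling out here: each pool thread now pins itself with bestla::device::CpuDevice::core_bond(tidx + 1) while the caller takes core 0, presumably to keep every worker and its data on a fixed core. The implementation of core_bond is not shown in this series; on Linux a minimal pinning helper might look like the sketch below (pthread_setaffinity_np is a glibc extension, Windows would use SetThreadAffinityMask, and the helper name pin_to_core is purely illustrative).

```cpp
#ifndef _GNU_SOURCE
#define _GNU_SOURCE  // needed for CPU_ZERO/CPU_SET and pthread_setaffinity_np on glibc
#endif
#include <pthread.h>
#include <sched.h>
#include <cstdio>

// Pin the calling thread to one logical CPU (Linux/glibc only).
// Illustrative stand-in: the real CpuDevice::core_bond() is not part of
// this patch series and may map thread indices to cores differently.
static bool pin_to_core(int core_id) {
  cpu_set_t set;
  CPU_ZERO(&set);
  CPU_SET(core_id, &set);
  int rc = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &set);
  if (rc != 0) std::fprintf(stderr, "pin_to_core(%d) failed: %d\n", core_id, rc);
  return rc == 0;
}
```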
storage::gemm::StorageWeightKBlockNInteger tmp(GemmCore::ID); @@ -585,7 +585,7 @@ class UT_ShuffleIndices { rordA.assign(bufA.data()); typename Launcher::Param args{ gp, {aarray.data(), k, nullptr, wptr_->ShfIndice(), &rordA}, {wptr_}, {output.data(), n}}; - parallel::GemmRunWithA>(kernel, args, &DefaultThreading); + parallel::GemmRunWithA>(kernel, args, UT_Threading::get()); } else { using Launcher = @@ -603,7 +603,7 @@ class UT_ShuffleIndices { redA.template RPtr(), redA.lda}; typename Launcher::Param args{ gp, {aarray.data(), k, &redA, wptr_->ShfIndice(), &rordA}, {wptr_}, blkargs, {output.data(), n}}; - parallel::GemmRunWithA>(kernel, args, &DefaultThreading); + parallel::GemmRunWithA>(kernel, args, UT_Threading::get()); } ut::buffer_error(output.data(), oarray.data(), output.size()); @@ -741,9 +741,9 @@ class UT_CompFp32 { avector matBf32(k * n), matAf32(m * k), matC(m * n), refC(m * n), refCupk(m * n); fill_buffer_randn(matBf32.data(), matBf32.size(), -0.5f, 0.5f); fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refCupk.data(), k, n, n); utils::GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{gp, @@ -751,7 +751,7 @@ class UT_CompFp32 { {&packedw}, {packedw.template SPtr(), packedw.SDtype(), packedw.CStep()}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); auto err = INT8_ERR; auto dbits = bestla_dtype_bits(qtype); auto type = bestla_dtype_type(qtype); @@ -788,9 +788,9 @@ class UT_CompFp32 { avector matBf32(k * n), matAf32(m * k), matC(m * n), refC(m * n), refCupk(m * n); fill_buffer_randn(matBf32.data(), matBf32.size(), -0.5f, 0.5f); fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refCupk.data(), k, n, n); GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{gp, @@ -798,7 +798,7 @@ class UT_CompFp32 { {&packedw}, {packedw.template SPtr(), packedw.SDtype(), packedw.CStep()}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); auto err = FP4_ERR; if (qtype == BTLA_DTYPE::F8_E5M2 || qtype == BTLA_DTYPE::F8_E4M3) err = F8_ERR; @@ -864,7 +864,7 @@ class UTBenchmark_CompFp32 { using Launcher = wrapper::gemm::LauncherBase; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; using WType = typename Wei::StorageWeight; @@ -883,7 +883,7 @@ class UTBenchmark_CompFp32 { packBs[i] = tmpB; 
packBs[i].assign(bufB.data() + i * tmpB.mSize); } - kernel.mProB.packWeight(n, k, B, n, &packBs[0], &DefaultThreading); + kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get()); for (size_t i = 1; i < batch; i++) { memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); @@ -896,7 +896,7 @@ class UTBenchmark_CompFp32 { log.start(); GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {&packBs[i]}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; double band = double(memsize) / log.avg_val / 1e6; @@ -921,7 +921,7 @@ class UTBenchmark_CompFp32 { prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::AccumulatorWriteBackFp32>; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; using WType = typename Wei::StorageWeight; @@ -943,7 +943,7 @@ class UTBenchmark_CompFp32 { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); } - kernel.mProB.packWeight(n, k, B, n, &packBs[0], &DefaultThreading); + kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get()); for (size_t i = 1; i < batch; i++) { memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); @@ -960,8 +960,8 @@ class UTBenchmark_CompFp32 { {&packBs[i]}, // {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, {C + i * m * n, n}}; - // parallel::GemmRun(kernel, args, &DefaultThreading); - parallel::GemmRunWithA(kernel, args, &DefaultThreading); + // parallel::GemmRun(kernel, args, UT_Threading::get()); + parallel::GemmRunWithA(kernel, args, UT_Threading::get()); } if (log.stop()) { double t = log.avg_val / batch; @@ -1139,16 +1139,16 @@ class UT_CompInt8 { reduceAf32[i * kblks + j / blocksize] += matAf32[i * k + j]; } } - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refCupk.data(), k, n, n); auto quanA = launcher.mProA.createStorage(m, k, blocksize, isAsym); utils::avector bufferA(quanA.mSize); quanA.assign(bufferA.data()); GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{gp, {matAf32.data(), k, &quanA}, {&packedw}, {matC.data(), n}}; - parallel::GemmRunWithA(launcher, args, &DefaultThreading); + parallel::GemmRunWithA(launcher, args, UT_Threading::get()); auto err = INT8_ERR; auto dbits = bestla_dtype_bits(qtype); auto type = bestla_dtype_type(qtype); @@ -1204,9 +1204,9 @@ class UT_CompInt8 { reduceAf32[i * kblks + j / blocksize] += matAf32[i * k + j]; } } - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), 
k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refCupk.data(), k, n, n); GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{ @@ -1217,7 +1217,7 @@ class UT_CompInt8 { packedw.template RPtr(), packedw.RDtype(), isAsym ? packedw.template ZPtr() : nullptr, isAsym ? reduceAf32.data() : nullptr, blocksize}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); auto err = INT8_ERR; auto dbits = bestla_dtype_bits(qtype); auto type = bestla_dtype_type(qtype); @@ -1272,9 +1272,9 @@ class UT_CompInt8 { reduceAf32[i * kblks + j / blocksize] += matAf32[i * k + j]; } } - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refCupk.data(), k, n, n); GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{ @@ -1285,7 +1285,7 @@ class UT_CompInt8 { quanA.CStep(), quanA.template ZPtr(), packedw.template RPtr(), packedw.RDtype(), packedw.template ZPtr(), quanA.template RPtr(), blocksize}, {matC.data(), n}}; - parallel::GemmRunWithA(launcher, args, &DefaultThreading); + parallel::GemmRunWithA(launcher, args, UT_Threading::get()); auto err = INT8_ERR; auto dbits = bestla_dtype_bits(qtype); auto type = bestla_dtype_type(qtype); @@ -1334,9 +1334,9 @@ class UT_CompInt8 { reduceAf32[i * kblks + j / blocksize] += matAf32[i * k + j]; } } - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refCupk.data(), k, n, n); GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{ @@ -1346,7 +1346,7 @@ class UT_CompInt8 { {packedw.template SPtr(), packedw.SDtype(), packedw.CStep(), scaleAf32.data(), kblks, nullptr, nullptr, bestla_dtype, packedw.template ZPtr(), reduceAf32.data(), blocksize}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); auto err = INT8_ERR; auto dbits = bestla_dtype_bits(qtype); auto type = bestla_dtype_type(qtype); @@ -1446,9 +1446,9 @@ class UT_CompBf16 { for (size_t i = 0; i < matBf32.size(); i++) { matBf32[i] = matBbf16[i]; } - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_bf16bf16fp32(m, n, k, matAbf16.data(), matBbf16.data(), refC.data(), k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); for 
(size_t i = 0; i < matBf32.size(); i++) { matBbf16[i] = static_cast(matBf32[i]); } @@ -1459,7 +1459,7 @@ class UT_CompBf16 { {&packedw}, {packedw.template SPtr(), packedw.SDtype(), packedw.CStep()}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); auto err = get_ut_err(qtype); buffer_error(refC.data(), matC.data(), refC.size(), err); buffer_error(refCupk.data(), matC.data(), refCupk.size(), 0.05f); @@ -1520,7 +1520,7 @@ class UTBenchmark_CompBf16 { using Launcher = wrapper::gemm::LauncherBase; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; using WType = typename Wei::StorageWeight; @@ -1537,7 +1537,7 @@ class UTBenchmark_CompBf16 { for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); - kernel.mProB.packWeight(n, k, B + i * n * k, n, &packBs[i], &DefaultThreading); + kernel.mProB.packWeight(n, k, B + i * n * k, n, &packBs[i], UT_Threading::get()); } auto psize = (size_t)m * n * k * 2; auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); @@ -1547,7 +1547,7 @@ class UTBenchmark_CompBf16 { log.start(); GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {&packBs[i]}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; double band = double(memsize) / log.avg_val / 1e6; @@ -1657,7 +1657,7 @@ class UT_ORT_NBits { } } rA.assign(tmpA.data()); - launcher.mProA.reduce({matAf32.data(), k, &rA}, m, k, blocksize, &DefaultThreading); // for reduce UT + launcher.mProA.reduce({matAf32.data(), k, &rA}, m, k, blocksize, UT_Threading::get()); // for reduce UT buffer_error(reduceA.data(), rA.template RPtr(), reduceA.size(), FP32_ERR); memset(tmpA.data(), 0, tmpA.size()); // clear } @@ -1675,11 +1675,11 @@ class UT_ORT_NBits { } } launcher.mProB.packNbitsWeightQ4(n, k, isasym, (uint8_t*)matBs4.data(), k, scalesB.data(), (uint8_t*)zpBs4.data(), - &packedw, &DefaultThreading); - launcher.mProB.reduceWeight(&packedw, &DefaultThreading); + &packedw, UT_Threading::get()); + launcher.mProB.reduceWeight(&packedw, UT_Threading::get()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); avector revB(matBf32.size()); - launcher.mProB.unpackWeight(n, k, &packedw, revB.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, revB.data(), n, UT_Threading::get()); buffer_error(matBf32.data(), revB.data(), revB.size(), FP32_ERR); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), revB.data(), refCupk.data(), k, n, n); GemmProblem gp(1, m, n, k, blocksize); @@ -1691,9 +1691,9 @@ class UT_ORT_NBits { isasym ? 
packedw.template ZPtr() : nullptr, rA.template RPtr(), rA.lda}, {matC.data(), n}}; if (isasym) { - parallel::GemmRunWithA(launcher, args, &DefaultThreading); + parallel::GemmRunWithA(launcher, args, UT_Threading::get()); } else { - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); } auto err = INT4_ERR; buffer_error(refC.data(), matC.data(), refC.size(), err); @@ -1738,7 +1738,7 @@ class UT_ORT_NBits { } } rA.assign(tmpA.data()); - launcher.mProA.reduce({matAf32.data(), k, &rA}, m, k, blocksize, &DefaultThreading); // for reduce UT + launcher.mProA.reduce({matAf32.data(), k, &rA}, m, k, blocksize, UT_Threading::get()); // for reduce UT buffer_error(reduceA.data(), rA.template RPtr(), reduceA.size(), FP32_ERR); memset(tmpA.data(), 0, tmpA.size()); // clear } @@ -1748,7 +1748,7 @@ class UT_ORT_NBits { } } - launcher.mProB.packQWeight(n, k, qdata.data(), n, sdata.data(), zdata.data(), &packedw, &DefaultThreading); + launcher.mProB.packQWeight(n, k, qdata.data(), n, sdata.data(), zdata.data(), &packedw, UT_Threading::get()); auto bfile = readFile2Buffer("bestla_w3.weight.bin"); WType packedfile(0); @@ -1758,7 +1758,7 @@ class UT_ORT_NBits { buffer_error(packedw.ZPtr(), packedfile.ZPtr(), packedw.CSize()); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); avector revB(matBf32.size()); - launcher.mProB.unpackWeight(n, k, &packedw, revB.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, revB.data(), n, UT_Threading::get()); buffer_error(matBf32.data(), revB.data(), revB.size(), FP32_ERR); gemmref_fp32fp32fp32(m, n, k, matAf32.data(), revB.data(), refCupk.data(), k, n, n); GemmProblem gp(1, m, n, k, blocksize); @@ -1770,9 +1770,9 @@ class UT_ORT_NBits { isasym ? 
packedw.template ZPtr() : nullptr, rA.template RPtr(), rA.lda}, {matC.data(), n}}; if (isasym) { - parallel::GemmRunWithA(launcher, args, &DefaultThreading); + parallel::GemmRunWithA(launcher, args, UT_Threading::get()); } else { - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); } auto err = INT4_ERR; buffer_error(refC.data(), matC.data(), refC.size(), err); @@ -1856,9 +1856,9 @@ class UT_CompFp16 { for (size_t i = 0; i < matBf32.size(); i++) { matBf32[i] = matBbf16[i]; } - launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, &DefaultThreading); + launcher.mProB.packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); gemmref_bf16bf16fp32(m, n, k, matAbf16.data(), matBbf16.data(), refC.data(), k, n, n); - launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, &DefaultThreading); + launcher.mProB.unpackWeight(n, k, &packedw, matBf32.data(), n, UT_Threading::get()); for (size_t i = 0; i < matBf32.size(); i++) { matBbf16[i] = static_cast(matBf32[i]); } diff --git a/bestla/bestla/ut/bestla_ut.cpp b/bestla/bestla/ut/bestla_ut.cpp index d55ac5b56..a2ac93714 100644 --- a/bestla/bestla/ut/bestla_ut.cpp +++ b/bestla/bestla/ut/bestla_ut.cpp @@ -1,15 +1,6 @@ #include #include -namespace bestla { -namespace ut { -#ifdef _OPENMP -parallel::OMPThreading DefaultThreading(4); -#else -parallel::StdThreading DefaultThreading(4); -#endif // _OPNEMP -} // namespace ut -} // namespace bestla int main() { printf("BesTLA UT done\n"); return 0; diff --git a/bestla/bestla/ut/bestla_ut.h b/bestla/bestla/ut/bestla_ut.h index b570253b1..cd787548d 100644 --- a/bestla/bestla/ut/bestla_ut.h +++ b/bestla/bestla/ut/bestla_ut.h @@ -26,11 +26,20 @@ using sAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; using sAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; using sAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; using sAVX2 = gemm::SCoreRowNAvx2<24, 4>; -#ifdef _OPENMP -extern parallel::OMPThreading DefaultThreading; + +class UT_Threading { + public: + static bestla::parallel::IThreading* get() { +#if BTLA_UT_OPENMP + static bestla::parallel::OMPThreading DefaultThreading(4); #else -extern parallel::StdThreading DefaultThreading; + static bestla::parallel::StdThreading DefaultThreading(4); #endif // _OPNEMP + return &DefaultThreading; + } + + static void set_threads(int n_thread) { get()->set_threads(n_thread); } +}; constexpr size_t CacheSize = size_t(100) << 10; static int8_t cache[CacheSize]; @@ -129,11 +138,11 @@ utils::aligned_vector<_T> readFile2Buffer(const char* filepath) { return buf; } -#define UT_START() \ - { \ - GetCPUDevice(); \ - ut::DefaultThreading.set_threads(_cd->getThreads()); \ - printf("Test Class: %s\n", __FUNCTION__); \ +#define UT_START() \ + { \ + GetCPUDevice(); \ + ut::UT_Threading::set_threads(_cd->getThreads()); \ + printf("Test Class: %s\n", __FUNCTION__); \ } template static double buffer_error(_T* ref, _T* tar, size_t size, _T thres = _T(0)) { diff --git a/bestla/bestla/ut/bestla_wrapper.cpp b/bestla/bestla/ut/bestla_wrapper.cpp index 8b9e9d415..2676d3e59 100644 --- a/bestla/bestla/ut/bestla_wrapper.cpp +++ b/bestla/bestla/ut/bestla_wrapper.cpp @@ -47,10 +47,10 @@ class UT_Fp32Fp32 { auto packw = launcher.mProB.createStorage(n, k); avector buffer(packw.mSize); packw.assign(buffer.data()); - launcher.mProB.packWeight(n, k, {matB.data(), n, &packw}, &DefaultThreading); + launcher.mProB.packWeight(n, k, {matB.data(), n, &packw}, UT_Threading::get()); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param 
args{gp, {matA.data(), k}, {matB.data(), n, &packw}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); ut::buffer_error(ref.data(), matC.data(), ref.size(), 0.001f); } @@ -65,7 +65,7 @@ class UT_Fp32Fp32 { wrapper::gemm::LauncherBase; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; auto tmpB = kernel.mProB.createStorage(n, k); @@ -74,7 +74,7 @@ class UT_Fp32Fp32 { for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); - kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); } auto psize = (size_t)m * n * k * 2; tm.start(); @@ -83,7 +83,7 @@ class UT_Fp32Fp32 { log.start(); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; printf("%s %s Flops:%.3f PerCoreFlops:%.3f\n ", corestr, log.get_log_str(), flops, flops / threads); @@ -190,14 +190,14 @@ class UT_U8S8S32 { auto packw = launcher.mProB.createStorage(n, k); avector buffer(packw.mSize); packw.assign(buffer.data()); - launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, &DefaultThreading); + launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, UT_Threading::get()); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param args{ gp, {matAu8.data(), k}, {matBs8.data(), n, &packw}, {matC.data(), n, 1, scaleAf32.data(), scaleBf32.data(), zpAu8.data(), reduceB.data()}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); ut::buffer_error(refC.data(), matC.data(), refC.size(), 0.001f); } @@ -212,7 +212,7 @@ class UT_U8S8S32 { wrapper::gemm::LauncherBase; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; auto tmpB = kernel.mProB.createStorage(n, k); @@ -221,7 +221,7 @@ class UT_U8S8S32 { for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); - kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); } auto psize = (size_t)m * n * k * 2; tm.start(); @@ -230,7 +230,7 @@ class UT_U8S8S32 { log.start(); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, @@ -324,11 +324,11 @@ class UT_S8S8S32 { auto packw = launcher.mProB.createStorage(n, k); avector buffer(packw.mSize); packw.assign(buffer.data()); - launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, &DefaultThreading); + launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, UT_Threading::get()); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param args{ gp, {matAu8.data(), 
k}, {matBs8.data(), n, &packw}, {matC.data(), n, 1, scaleAf32.data(), scaleBf32.data()}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); ut::buffer_error(refC.data(), matC.data(), refC.size(), 0.001f); } @@ -343,7 +343,7 @@ class UT_S8S8S32 { wrapper::gemm::LauncherBase; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; auto tmpB = kernel.mProB.createStorage(n, k); @@ -352,7 +352,7 @@ class UT_S8S8S32 { for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); - kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); } auto psize = (size_t)m * n * k * 2; tm.start(); @@ -361,7 +361,7 @@ class UT_S8S8S32 { log.start(); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, @@ -430,11 +430,11 @@ class UT_Bf16Bf16Fp32 { fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f)); fill_buffer_randn(matBbf16.data(), matBbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f)); avector matC(m * n), refC(m * n); - launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, &DefaultThreading); + launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, UT_Threading::get()); gemmref_bf16bf16fp32(m, n, k, matAbf16.data(), matBbf16.data(), refC.data(), k, n, n); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {matAbf16.data(), k}, {matBbf16.data(), n, &packw}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); buffer_error(refC.data(), matC.data(), refC.size(), 0.05f); } @@ -449,7 +449,7 @@ class UT_Bf16Bf16Fp32 { wrapper::gemm::LauncherBase; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; auto tmpB = kernel.mProB.createStorage(n, k); @@ -458,7 +458,7 @@ class UT_Bf16Bf16Fp32 { for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); - kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); } auto psize = (size_t)m * n * k * 2; tm.start(); @@ -467,7 +467,7 @@ class UT_Bf16Bf16Fp32 { log.start(); utils::GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, @@ -534,11 +534,11 @@ class UT_Fp16Fp16Fp16 { avector matAbf16(m * k), matBbf16(k * n), matC(m * n), refC(m * n); fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::fp16(-0.5f), utils::fp16(0.5f)); 
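The bestla_ut.h hunk in this patch replaces the extern DefaultThreading object with UT_Threading::get(), which returns a pointer to a function-local static instance, the classic Meyers singleton: C++11 guarantees that static is constructed exactly once even if the first calls race, and the #if BTLA_UT_OPENMP switch picks OMPThreading or StdThreading at compile time. A stripped-down sketch of the same idea, using made-up stand-in types rather than the BesTLA classes:

```cpp
#include <cstdio>

struct IThreading {
  virtual ~IThreading() = default;
  virtual void set_threads(int n) = 0;
  virtual int num_threads() const = 0;
};

struct DummyThreading : IThreading {  // stand-in for parallel::StdThreading / OMPThreading
  explicit DummyThreading(int n) : n_(n) {}
  void set_threads(int n) override { n_ = n; }
  int num_threads() const override { return n_; }
  int n_;
};

// Meyers singleton: constructed on first call, thread-safe since C++11,
// and every caller shares the same instance (UT_Threading::get() selects
// the OMP or Std implementation at this point).
IThreading* default_threading() {
  static DummyThreading instance(4);
  return &instance;
}

int main() {
  default_threading()->set_threads(8);
  std::printf("threads = %d\n", default_threading()->num_threads());
}
```

Compared with the previous extern global defined in bestla_ut.cpp, this keeps the accessor entirely in the header without violating the one-definition rule and defers construction of the thread pool until a test actually asks for it.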
fill_buffer_randn(matBbf16.data(), matBbf16.size(), utils::fp16(-0.5f), utils::fp16(0.5f)); - launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, &DefaultThreading); + launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, UT_Threading::get()); gemmref_fp16fp16fp16(m, n, k, matAbf16.data(), matBbf16.data(), refC.data(), k, n, n); GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {matAbf16.data(), k}, {matBbf16.data(), n, &packw}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); buffer_error(refC.data(), matC.data(), refC.size(), utils::fp16(0.0002f * k)); } @@ -553,7 +553,7 @@ class UT_Fp16Fp16Fp16 { wrapper::gemm::LauncherBase; Launcher kernel; - DefaultThreading.set_threads(threads); + UT_Threading::set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; auto tmpB = kernel.mProB.createStorage(n, k); @@ -562,7 +562,7 @@ class UT_Fp16Fp16Fp16 { for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); - kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); } auto psize = (size_t)m * n * k * 2; tm.start(); @@ -571,7 +571,7 @@ class UT_Fp16Fp16Fp16 { log.start(); GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, diff --git a/bestla/bestla/ut/kernel_jit.cpp b/bestla/bestla/ut/kernel_jit.cpp index ce1198c99..7d91297ef 100644 --- a/bestla/bestla/ut/kernel_jit.cpp +++ b/bestla/bestla/ut/kernel_jit.cpp @@ -27,9 +27,9 @@ class UT_Memcpy2D_AVX512F { kernel::jit::JitMemcpy2DAvx512f::forward(src.data(), dst.data(), row, col, srcstep, dststep); } tm.start(); - parallel::Scheduler2D para({DefaultThreading.num_threads(), row, col, 4, 64}); + parallel::Scheduler2D para({UT_Threading::get()->num_threads(), row, col, 4, 64}); for (size_t i = 0; i < TestLoop; i++) { - DefaultThreading.parallel_for([&](int tidx) { + UT_Threading::get()->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; para.getIndex(thdp); if (thdp.valid) { @@ -47,7 +47,7 @@ class UT_Memcpy2D_AVX512F { tm.start(); for (size_t i = 0; i < TestLoop; i++) { - DefaultThreading.parallel_for([&](int tidx) { + UT_Threading::get()->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; para.getIndex(thdp); if (thdp.valid) { From 120123512661f5e8ffc07f9b5ad6bf72c243da6b Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Wed, 21 Feb 2024 14:59:10 +0800 Subject: [PATCH 07/45] add sync --- bestla/bestla/bestla_parallel.h | 20 +++++++++++++++----- bestla/bestla/ut/bestla_prologue_b.cpp | 18 +++++++++--------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index 9b424f5c3..47c3af827 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -589,7 +589,7 @@ class IThreading { public: explicit IThreading(int nthreads) : mThreadNum(nthreads) {} virtual void parallel_for(const thread_func& func) = 0; - virtual inline void sync() const = 0; + virtual inline void sync(int idx=0) = 
0; virtual int num_threads() const { return mThreadNum; }; virtual void set_threads(int nthreads) = 0; @@ -618,7 +618,7 @@ class OMPThreading : public IThreading { mThreadNum = nthreads; omp_set_num_threads(nthreads); } - virtual inline void sync() const override { + virtual inline void sync(int idx=0) override { #pragma omp barrier (void)(0); // make msvc happy with c++20 } @@ -633,7 +633,8 @@ class StdThreading : public IThreading { } void parallel_for(const thread_func& func) override { if (mThreadNum > 1) { - running.store(mThreadNum - 1); + running.store(mThreadNum-1); + for(int i=0;i<10;i++) flag[i].store(mThreadNum); for (size_t i = 0; i < mThreadNum - 1; i++) { func_[i] = &func; } @@ -657,7 +658,15 @@ class StdThreading : public IThreading { } } - inline void sync() const override { assert(0); } + inline void sync(int idx=0) override { + flag[idx].fetch_sub(1); + while (true) { + if (flag[idx].load() == 0) + break; + else + _mm_pause(); + } + } ~StdThreading() { stop_threads(); } @@ -695,6 +704,7 @@ class StdThreading : public IThreading { std::vector thdset; std::atomic_bool stop; std::atomic_int running; + std::atomic_int flag[10]; const thread_func* func_[100]; }; @@ -704,7 +714,7 @@ class SingleThread : public StdThreading { void set_threads(int nthreads) override { (void)(nthreads); } - inline void sync() const override {} + inline void sync(int idx) override {} }; template diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 9091c6205..41f2e765f 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -186,7 +186,7 @@ class UT_S3_WOQ { template void ut(int m, int n, int k, int blocksize, int enable_thr) { - DefaultThreading.set_threads(enable_thr); + UT_Threading::set_threads(enable_thr); printf("%s:%d %d %d %d\n", __FUNCTION__, m, n, k, blocksize); int ldb = n; @@ -207,8 +207,8 @@ class UT_S3_WOQ { avector buffer_ref(ptr_ref.mSize); ptr.assign(buffer.data()); ptr_ref.assign(buffer_ref.data()); - kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr, &DefaultThreading); - kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr_ref, &DefaultThreading); + kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr, UT_Threading::get()); + kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr_ref, UT_Threading::get()); using Launcher = wrapper::gemm::LauncherKBlock(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); typename Launcher::Param args_ref{gp, {matAf32.data(), k}, {&ptr_ref}, {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, {refC.data(), n}}; - parallel::GemmRun(launcher, args_ref, &DefaultThreading); + parallel::GemmRun(launcher, args_ref, UT_Threading::get()); } else if constexpr (ISA == BTLA_ISA::AMX_BF16) { avector matAbf16(m * k); fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f)); GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{ gp, {matAbf16.data(), k}, {&ptr}, {ptr.template SPtr(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, UT_Threading::get()); typename Launcher::Param args_ref{gp, {matAbf16.data(), k}, {&ptr_ref}, {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, {refC.data(), n}}; - parallel::GemmRun(launcher, 
args_ref, &DefaultThreading); + parallel::GemmRun(launcher, args_ref, UT_Threading::get()); } else { using Launcher2 = wrapper::gemm::LauncherIntKBlock(launcher, args, &DefaultThreading); + parallel::GemmRunWithA(launcher, args, UT_Threading::get()); typename Launcher2::Param args_ref{gp, {matAf32.data(), k, &quanA_ref}, {&ptr_ref}, {refC.data(), n}}; - parallel::GemmRunWithA(launcher, args_ref, &DefaultThreading); + parallel::GemmRunWithA(launcher, args_ref, UT_Threading::get()); } buffer_error(matC.data(), refC.data(), matC.size(), 0.001f); } From 79b7cbccfae405198ec62b1c45bfe963dcaec50a Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Wed, 21 Feb 2024 16:22:07 +0800 Subject: [PATCH 08/45] integrate to core --- neural_speed/core/layers/ip_fusion_ffn.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/neural_speed/core/layers/ip_fusion_ffn.cpp b/neural_speed/core/layers/ip_fusion_ffn.cpp index ec950ee42..b4b1a1d89 100644 --- a/neural_speed/core/layers/ip_fusion_ffn.cpp +++ b/neural_speed/core/layers/ip_fusion_ffn.cpp @@ -51,19 +51,19 @@ void GemmRunWithA_ffn(Launch_T1* launcher1, Launch_T2* launcher2, const typename if (thdpA1.valid) { launcher1->mProA.run(args1.paramA, thdpA1); } - th->sync(); + th->sync(0); typename Parallel_T::ThreadProblem thdp1{tidx}; para1.getIndex(thdp1); if (thdp1.valid) { launcher1->run(args1, thdp1); } - th->sync(); + th->sync(1); typename AParall2::ThreadProblem thdpA2{tidx}; apara2.getIndex(thdpA2); if (thdpA2.valid) { launcher2->mProA.run(args2.paramA, thdpA2); } - th->sync(); + th->sync(2); typename Parallel_T::ThreadProblem thdp2{tidx}; para2.getIndex(thdp2); if (thdp2.valid) { @@ -368,20 +368,20 @@ void GemmRunWithA_ffn(Launch_T1* launcher1, Launch_T2* launcher2, Launch_T3* lau if (thdpA1.valid) { launcher1->mProA.run(args1.paramA, thdpA1); } - th->sync(); + th->sync(0); typename Parallel_T::ThreadProblem thdp1{tidx}; para1.getIndex(thdp1); if (thdp1.valid) { launcher1->run(args1, thdp1); launcher2->run(args2, thdp1); } - th->sync(); + th->sync(1); typename AParall3::ThreadProblem thdpA3{tidx}; apara3.getIndex(thdpA3); if (thdpA3.valid) { launcher3->mProA.run(args3.paramA, thdpA3); } - th->sync(); + th->sync(2); typename Parallel_T::ThreadProblem thdp3{tidx}; para3.getIndex(thdp3); if (thdp3.valid) { From 86feb68e1d21fbbf11b1123dd0f8104361c2c192 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Fri, 23 Feb 2024 00:59:11 -0800 Subject: [PATCH 09/45] fix bugs --- bestla/bestla/bestla_parallel.h | 13 +++++++++---- neural_speed/models/model_utils/quant_utils.cpp | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index 47c3af827..d15e7be3f 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -708,13 +708,18 @@ class StdThreading : public IThreading { const thread_func* func_[100]; }; -class SingleThread : public StdThreading { +class SingleThread : public IThreading { public: - SingleThread() : StdThreading(1) {} + SingleThread() : IThreading(1) {} - void set_threads(int nthreads) override { (void)(nthreads); } + void set_threads(int nthreads) override { + assert(0); + (void)(nthreads); + } + + inline void parallel_for(const thread_func& func) override { func(0); } - inline void sync(int idx) override {} + inline void sync(int idx = 0) override {} }; template diff --git a/neural_speed/models/model_utils/quant_utils.cpp b/neural_speed/models/model_utils/quant_utils.cpp index 849578400..d089a1f48 100644 --- 
a/neural_speed/models/model_utils/quant_utils.cpp +++ b/neural_speed/models/model_utils/quant_utils.cpp @@ -276,9 +276,9 @@ size_t bestla_quantize(const float* f32ptr, void* dstpr, const quant_params_inte auto ctype = quant2ne_comp_type(params.compute_dtype); auto dstbptr = reinterpret_cast(dstpr); #ifdef __OPENMP - bestla::parallel::OMPThreading threading(nthread); + static bestla::parallel::OMPThreading threading(nthread); #else - bestla::parallel::StdThreading threading(nthread); + static bestla::parallel::StdThreading threading(nthread); #endif BTLA_DTYPE quant_type = BTLA_DTYPE::S4_CLIP; if (params.bits == quant_bits::q3) { From 2ed0175be2d1ba63afc6ca18a2691fa8d3e326fa Mon Sep 17 00:00:00 2001 From: "Luo, Yu" Date: Tue, 27 Feb 2024 17:11:05 +0800 Subject: [PATCH 10/45] add thread config for hybrid cpu --- bestla/bestla/ut/bestla_wrapper.cpp | 73 +++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 19 deletions(-) diff --git a/bestla/bestla/ut/bestla_wrapper.cpp b/bestla/bestla/ut/bestla_wrapper.cpp index 2676d3e59..d50ad5aba 100644 --- a/bestla/bestla/ut/bestla_wrapper.cpp +++ b/bestla/bestla/ut/bestla_wrapper.cpp @@ -86,7 +86,8 @@ class UT_Fp32Fp32 { parallel::GemmRun(kernel, args, UT_Threading::get()); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; - printf("%s %s Flops:%.3f PerCoreFlops:%.3f\n ", corestr, log.get_log_str(), flops, flops / threads); + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n ", threads, corestr, log.get_log_str(), flops, + flops / threads); } } } @@ -108,11 +109,17 @@ class UT_Fp32Fp32 { float testtime = 500.f; GetCPUDevice(); if (_cd->AVX512F()) { - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 56); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getCores()); } if (_cd->AVX2()) { - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 56); + if (_cd->isHybrid()) { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getCores()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getPcoreNum()); + } else { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + } } } }; @@ -256,16 +263,37 @@ class UT_U8S8S32 { GetCPUDevice(); if (_cd->AMX_INT8()) { request_perm_xtile_data(); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); } if (_cd->AVX512_VNNI()) { - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); } if (_cd->AVX_VNNI()) { - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); + if (_cd->isHybrid()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), 
C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getCores()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getPcoreNum()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getCores()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getPcoreNum()); + } else { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + } } } }; @@ -387,9 +415,12 @@ class UT_S8S8S32 { GetCPUDevice(); if (_cd->AMX_INT8()) { request_perm_xtile_data(); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); } } }; @@ -493,9 +524,12 @@ class UT_Bf16Bf16Fp32 { GetCPUDevice(); if (_cd->AMX_BF16()) { request_perm_xtile_data(); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); } } }; @@ -596,9 +630,10 @@ class UT_Fp16Fp16Fp16 { float testtime = 500.f; GetCPUDevice(); if (_cd->AVX512_FP16()) { - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 56); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 56); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 48); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); } } }; From 1bade9296f809c6098b080b51e1381ac9b129a4e Mon Sep 17 00:00:00 2001 From: "Luo, Yu" Date: Wed, 28 Feb 2024 10:55:03 +0800 Subject: [PATCH 11/45] split benchmark and UT --- bestla/CMakeLists.txt | 16 +- bestla/bestla/ut/bestla_benchmark.cpp | 734 +++++++++++++++++++++++++ bestla/bestla/ut/bestla_prologue_b.cpp | 317 +---------- bestla/bestla/ut/bestla_wrapper.cpp | 371 ------------- 4 files changed, 748 insertions(+), 690 deletions(-) create mode 100644 bestla/bestla/ut/bestla_benchmark.cpp diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 59c5954ec..d3dba63ea 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -64,7 +64,8 @@ endif() function(add_ut_flag UT_OPTION) if(${${UT_OPTION}}) - target_compile_definitions(${PROJECT_NAME}_ut PRIVATE ${UT_OPTION}) + # target_compile_definitions(${PROJECT_NAME}_ut PRIVATE ${UT_OPTION}) + add_compile_definitions(${UT_OPTION}) endif() 
endfunction() @@ -96,8 +97,17 @@ if(UT_BUILD) add_ut_flag(BTLA_UT_KERNEL_INTRIN) add_ut_flag(BTLA_UT_KERNEL_JIT) add_ut_flag(BTLA_UT_KERNEL_WRAPPER) - add_ut_flag(BTLA_UT_BENCHMARK) - target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME}) endif(UT_BUILD) +if(BTLA_UT_BENCHMARK) + file(GLOB srcs ${PROJECT_NAME}/ut/bestla_benchmark.cpp) #compile everything even run parts of UTs + file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h) + include_directories(${PROJECT_NAME}) + add_executable(${PROJECT_NAME}_benchmark ${srcs} ${headers} ${ut_headers}) + if(BTLA_UT_OPENMP) + include(FindOpenMP) + target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP) + target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE OpenMP::OpenMP_CXX) + endif() +endif(BTLA_UT_BENCHMARK) diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp new file mode 100644 index 000000000..a0b9cbb6b --- /dev/null +++ b/bestla/bestla/ut/bestla_benchmark.cpp @@ -0,0 +1,734 @@ +#include +#include "bestla_wrapper.h" +#include "bestla_ut.h" + +namespace bestla { +using namespace utils; +namespace ut { +class Benchmark_Fp32Fp32 { + public: + Benchmark_Fp32Fp32() { + UT_START(); + benchmark_all(1, 4096, 4096, 32); + benchmark_all(1024, 4096, 4096, 32); + benchmark_all(2048, 4096, 4096, 32); + } + + using AType = float; + using BType = float; + using CType = float; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + std::vector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + if (log.stop()) { + double flops = double(psize) / log.avg_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n ", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + } + } + } + + void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { + printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); + avector A(m * k * batch); + avector B(k * n * batch); + avector C(m * n * batch, 0), RefC(m * n * batch, 0); + fill_buffer_randn(A.data(), m * k, -0.5f, 0.5f); + fill_buffer_randn(B.data(), n * k, -0.5f, 0.5f); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger<100>; + + float testtime = 500.f; + GetCPUDevice(); + if (_cd->AVX512F()) { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getCores()); + } + if (_cd->AVX2()) { + if (_cd->isHybrid()) { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, 
_cd->getThreads()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getCores()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getPcoreNum()); + } else { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + } + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_Fp32Fp32 sBenchmark_Fp32Fp32; +#endif + +class Benchmark_U8S8S32 { + public: + Benchmark_U8S8S32() { + UT_START(); + benchmark_all(1024, 4096, 4096, 32); + benchmark_all(2048, 4096, 4096, 32); + } + + using AType = uint8_t; + using BType = int8_t; + using CType = int; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + std::vector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + if (log.stop()) { + double flops = double(psize) / log.avg_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + } + } + } + + void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { + printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); + avector A(m * k * batch); + avector B(k * n * batch); + avector C(m * n * batch), RefC(m * n * batch); + fill_buffer_randn(A.data(), m * k, AType(0), AType(255)); + fill_buffer_randn(B.data(), k * n, BType(-127), BType(127)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger<100>; + float testtime = 500.f; + GetCPUDevice(); + if (_cd->AMX_INT8()) { + request_perm_xtile_data(); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + } + if (_cd->AVX512_VNNI()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + } + if (_cd->AVX_VNNI()) { + if (_cd->isHybrid()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getCores()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getPcoreNum()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getCores()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), 
testtime, + _cd->getPcoreNum()); + } else { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + } + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_U8S8S32 sBenchmark_U8S8S32; +#endif + +class Benchmark_S8S8S32 { + public: + Benchmark_S8S8S32() { + UT_START(); + benchmark_all(1024, 4096, 4096, 32); + benchmark_all(2048, 4096, 4096, 32); + } + + using AType = int8_t; + using BType = int8_t; + using CType = int; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + std::vector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + if (log.stop()) { + double flops = double(psize) / log.avg_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + } + } + } + + void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { + printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); + avector A(m * k * batch); + avector B(k * n * batch); + avector C(m * n * batch), RefC(m * n * batch); + fill_buffer_randn(A.data(), m * k, AType(0), AType(255)); + fill_buffer_randn(B.data(), k * n, BType(-127), BType(127)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(AType)); + } + using LOG = timer_statistics_logger<100>; + float testtime = 500.f; + GetCPUDevice(); + if (_cd->AMX_INT8()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_S8S8S32 sBenchmark_S8S8S32; +#endif + +class Benchmark_Bf16Bf16Fp32 { + public: + Benchmark_Bf16Bf16Fp32() { + UT_START(); + benchmark_all(1024, 4096, 4096, 32); + benchmark_all(2048, 4096, 4096, 32); + } + + using AType = utils::bf16; + using BType = utils::bf16; + using CType = float; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + std::vector bufB(tmpB.mSize * 
batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + if (log.stop()) { + double flops = double(psize) / log.avg_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + } + } + } + + void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { + printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); + avector A(m * k * batch); + avector B(k * n * batch); + avector C(m * n * batch), RefC(m * n * batch); + fill_buffer_randn(A.data(), k * m, AType(-0.5f), AType(0.5f)); + fill_buffer_randn(B.data(), k * n, BType(-0.5f), BType(0.5f)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger<100>; + float testtime = 500.f; + GetCPUDevice(); + if (_cd->AMX_BF16()) { + request_perm_xtile_data(); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_Bf16Bf16Fp32 sBenchmark_Bf16Bf16Fp32; +#endif + +class Benchmark_Fp16Fp16Fp16 { + public: + Benchmark_Fp16Fp16Fp16() { + UT_START(); + benchmark_all(1024, 4096, 4096, 32); + benchmark_all(2048, 4096, 4096, 32); + } + + using AType = utils::fp16; + using BType = utils::fp16; + using CType = utils::fp16; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + std::vector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + if (log.stop()) { + double flops = double(psize) / log.avg_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + } + } + } + + void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { + printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); + avector A(m * k * batch); + avector B(k * n * batch); + avector C(m * n * batch), RefC(m * n * 
batch); + fill_buffer_randn(A.data(), k * m, AType(-0.5f), AType(0.5f)); + fill_buffer_randn(B.data(), k * n, AType(-0.5f), AType(0.5f)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger<100>; + float testtime = 500.f; + GetCPUDevice(); + if (_cd->AVX512_FP16()) { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + _cd->getThreads()); + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, _cd->getThreads()); + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_Fp16Fp16Fp16 sBenchmark_Fp16Fp16Fp16; +#endif + +class UTBenchmark_CompFp32 { + public: + UTBenchmark_CompFp32() { + UT_START(); + CheckISA(AVX512F); + ut_s4(); + /* ut_s8(); + ut_f4();*/ + } + + void ut_s4() { + // benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + // benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); + // benchmark_all(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); + // benchmark_all(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); + // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2, 4096, 4096, 32, + // BTLA_DTYPE::S4_FULLRANGE); + } + + // void ut_s8() { + // ut(2, 4096, 4096, 32, BTLA_DTYPE::S8); + // ut(2, 4096, 4096, 128, BTLA_DTYPE::S8); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::S8); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::S8); + // } + + // void ut_f4() { + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); + // } + + template class Wei, typename Scale_T> + void benchmark(int m, int n, int k, int blocksize, int batch, float* A, float* B, float* C, float timems, int threads, + BTLA_DTYPE qtype) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + using WType = typename Wei::StorageWeight; + WType tmpB(0); + if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNInteger>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); + + } else if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNFloat>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); + } + std::vector packBs(batch, 0); + std::vector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + } + kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get()); + for (size_t i = 1; i < batch; i++) { + memcpy(packBs[i].template 
WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); + } + auto psize = (size_t)m * n * k * 2; + auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {&packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + if (log.stop()) { + double flops = double(psize) / log.avg_val / 1e6; + double band = double(memsize) / log.avg_val / 1e6; + printf("Threads %d %s %s Flops:%.3fG PerCoreFlops:%.3fG MemoryBandwidth:%.3fGB/s\n", threads, corestr, + log.get_log_str(), flops, flops / threads, band); + } + } + } + } + + template class Wei, typename Scale_T> + void benchmark_mem(int m, int n, int k, int blocksize, int batch, float* A, float* B, float* C, float timems, + int threads, BTLA_DTYPE qtype) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerKBlock; + // using Launcher = + // wrapper::gemm::LauncherKBlock; + using Launcher = + wrapper::gemm::LauncherIntKBlock; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + using WType = typename Wei::StorageWeight; + WType tmpB(0); + if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNInteger>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); + + } else if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNFloat>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); + } + std::vector packBs(batch, 0); + std::vector bufB(tmpB.mSize * batch); + auto quanA = kernel.mProA.createStorage(m, k, blocksize, false); + utils::avector bufferA(quanA.mSize); + quanA.assign(bufferA.data()); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + } + kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get()); + for (size_t i = 1; i < batch; i++) { + memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); + } + auto psize = (size_t)m * n * k * 2; + auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); + tm.start(); + while (tm.stop() < timems) { + log.start(); + for (size_t i = 0; i < batch; i++) { + GemmProblem gp(1, m, n, k, blocksize); + typename Launcher::Param args{gp, + {A + i * m * k, k, &quanA}, + {&packBs[i]}, + // {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, + {C + i * m * n, n}}; + // parallel::GemmRun(kernel, args, UT_Threading::get()); + parallel::GemmRunWithA(kernel, args, UT_Threading::get()); + } + if (log.stop()) { + double t = log.avg_val / batch; + double flops = double(psize) / t / 1e6; + double band = double(memsize) / t / 1e6; + printf("Threads %d %s Flops:%.3fG PerCoreFlops:%.3fG MemoryBandwidth:%.3fGB/s\n", threads, corestr, flops, + flops / threads, band); + } + } + } + + template