Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
[BesTLA] Simplify the templates (#274)
Browse files Browse the repository at this point in the history
* remove ISA from prologue_a

* fix assert condition

* add AUTOCALL

* remove ISA from prologue_b

* remove gemmcore instance. remove runtime ISA

* compile with GCC

* apply refactor to all kernels

* remove warning

* compile with gcc, add linux UT preset

* clang-format

* fix class name of amx

* fix UT case

* fix clang-tidy

* support DQ for NFloat

* fix warning

* clang-format
  • Loading branch information
luoyu-intel authored Jun 1, 2024
1 parent 315df3a commit e0e65bd
Show file tree
Hide file tree
Showing 26 changed files with 1,510 additions and 1,472 deletions.
12 changes: 12 additions & 0 deletions CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@
"NS_USE_OMP": "OFF"
}
},
{
"name": "linux-release-ut-thread",
"displayName": "Linux Release Thread Pool for UTs",
"description": "Release",
"inherits": "linux-debug",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release",
"NS_USE_OMP": "OFF",
"BTLA_UT_ALL": "ON",
"BTLA_UT_BENCHMARK": "ON"
}
},
{
"name": "windows-base",
"description": "Target Windows with the Visual Studio development environment.",
Expand Down
20 changes: 10 additions & 10 deletions bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ class CpuDevice {
inline bool AVX512_BF16() { return mHasAVX512_BF16; }
inline bool AVX512_FP16() { return mHasAVX512_FP16; }
inline float* const getPE() { return PE; }
inline size_t getPcoreNum() { return P_core.size(); }
inline size_t getEcoreNum() { return E_core.size(); }
inline size_t getSMTcoreNum() { return SMT_core.size(); }
inline int getPcoreNum() { return static_cast<int>(P_core.size()); }
inline int getEcoreNum() { return static_cast<int>(E_core.size()); }
inline int getSMTcoreNum() { return static_cast<int>(SMT_core.size()); }
inline int* getPCores() { return P_core.data(); }
inline int* getECores() { return E_core.data(); }
inline int* getSMTCores() { return SMT_core.data(); }
Expand Down Expand Up @@ -467,15 +467,15 @@ class CpuDevice {
bool isClient() { return mClient; }

protected:
uint32_t L2Cache, L1Cache, L3Cache;
uint32_t L2Cache = 0, L1Cache = 0, L3Cache = 0;
bool mHybrid = false, mClient = false;
bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512BW,
mHasAVX512_BF16, mHasAVX512_FP16;
int numcores;
int numthreads;
// ISA capability flags returned by the feature getters (e.g. AVX512_BF16(), AVX512_FP16()).
// Every flag carries an explicit `false` default so that none of them holds an indeterminate
// value before it is assigned.
bool mHasAVX2 = false, mHasAVX_VNNI = false, mHasAVX = false, mHasAVX512_VNNI = false, mHasAMX_INT8 = false,
mHasAMX_BF16 = false, mHasAVX512F = false, mHasAVX512BW = false, mHasAVX512_BF16 = false, mHasAVX512_FP16 = false;
int numcores = 0;
int numthreads = 0;
std::vector<int> P_core, E_core, SMT_core;
uint32_t E_L2Cache, E_L1Cache;
float PE[int(BTLA_ISA::ISA_COUNT)];
uint32_t E_L2Cache = 0, E_L1Cache = 0;
// Per-ISA array exposed through getPE().
// NOTE(review): `= {1.f}` sets only PE[0] to 1.f; any remaining elements are zero-initialized.
// Confirm that is intended rather than filling every entry with 1.f.
float PE[int(BTLA_ISA::ISA_COUNT)] = {1.f};
};

#define GetCPUDevice() auto _cd = bestla::device::CpuDevice::getInstance();
Expand Down
239 changes: 138 additions & 101 deletions bestla/bestla/bestla_gemm.h

Large diffs are not rendered by default.

31 changes: 19 additions & 12 deletions bestla/bestla/bestla_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class IThreading {
virtual std::pair<float, float> get_PEtime() const { return {0.0f, 0.0f}; };

protected:
int mThreadNum;
const bool isSupportPE;
int mThreadNum = 0;
const bool isSupportPE = false;
};

#if BTLA_OPENMP
Expand Down Expand Up @@ -107,7 +107,14 @@ class OMPThreading : public IThreading {
class StdThreading : public IThreading {
public:
using Timer_T = utils::timer<utils::microseconds>;
explicit StdThreading() : IThreading(true) { cr = nullptr; }
// Default constructor: passes `true` to the IThreading base and puts every member it touches into a
// known state -- no runtime bound (cr), zeroed func_ and flag storage, stop set to true, and both
// timing fields set to -1.f.
// NOTE(review): parallel_for calls flag[i].store(...), so flag looks like an array of atomics; zeroing
// those with memset may trigger -Wclass-memaccess on GCC -- confirm and consider per-element store(0).
explicit StdThreading() : IThreading(true) {
cr = nullptr;
memset(func_, 0, sizeof(func_));
memset(flag, 0, sizeof(flag));
stop = true;
time_per_p = -1.f;
time_per_e = -1.f;
}

void parallel_for(const thread_func& func) override {
time_per_p = 0;
Expand All @@ -117,7 +124,7 @@ class StdThreading : public IThreading {
running.store(mThreadNum - 1);
for (int i = 0; i < 10; i++) flag[i].store(mThreadNum);
if (cr->mHybrid) {
int time_p = 0, time_e = 0;
int64_t time_p = 0, time_e = 0;

for (size_t i = 0; i < mThreadNum - 1; i++) func_[i] = &func;
thread_time[0] = 0;
Expand All @@ -135,8 +142,8 @@ class StdThreading : public IThreading {
time_e += thread_time[i];
else
time_p += thread_time[i];
time_per_p = (time_p) / (1.0 * (mThreadNum - cr->E_core_num));
time_per_e = (time_e) / (1.0 * cr->E_core_num);
time_per_p = (time_p) / (1.0f * (mThreadNum - cr->E_core_num));
time_per_e = (time_e) / (1.0f * cr->E_core_num);
// printf("%d %d %f %f\n", time_p, time_e, time_per_p, time_per_e);
} else {
for (size_t i = 0; i < mThreadNum - 1; i++) {
Expand Down Expand Up @@ -810,7 +817,7 @@ class SchedulerDispatcher<Scheduler2D> {
} // namespace gemm

template <class Parallel_T, class Launch_T>
void GemmRun(Launch_T& launcher, const typename Launch_T::Param& args, parallel::IThreading* th) {
void GemmRun(const typename Launch_T::Param& args, parallel::IThreading* th) {
gemm::SchedulerDispatcher<Parallel_T> para(th, args.problem);
static bool flag = false;
if (flag) {
Expand All @@ -822,16 +829,16 @@ void GemmRun(Launch_T& launcher, const typename Launch_T::Param& args, parallel:
typename Parallel_T::ThreadProblem thdp{tidx};
para.getIndex(thdp);
if (thdp.valid) {
launcher.run(args, thdp);
Launch_T::run(args, thdp);
}
});
}

template <class Parallel_T, class Launch_T>
void GemmRunWithA(Launch_T& launcher, const typename Launch_T::Param& args, parallel::IThreading* th) {
void GemmRunWithA(const typename Launch_T::Param& args, parallel::IThreading* th) {
gemm::SchedulerDispatcher<Parallel_T> para(th, args.problem);
using AParall = typename Launch_T::PrologueA::Parallel;
AParall apara = launcher.mProA.createParallel(th->num_threads(), args.problem);
AParall apara = Launch_T::PrologueA::createParallel(th->num_threads(), args.problem);
static bool flag = false;
if (flag) {
printf("%s\n", __FUNCTION__);
Expand All @@ -842,13 +849,13 @@ void GemmRunWithA(Launch_T& launcher, const typename Launch_T::Param& args, para
typename AParall::ThreadProblem thdpA{tidx};
apara.getIndex(thdpA);
if (thdpA.valid) {
launcher.mProA.run(args.paramA, thdpA);
Launch_T::PrologueA::run(args.paramA, thdpA);
}
th->sync(tidx);
typename Parallel_T::ThreadProblem thdp{tidx};
para.getIndex(thdp);
if (thdp.valid) {
launcher.run(args, thdp);
Launch_T::run(args, thdp);
}
});
}
Expand Down
Loading

0 comments on commit e0e65bd

Please sign in to comment.