diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..7d2913bd3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+################################################################################
+# This .gitignore file was automatically created by Microsoft(R) Visual Studio.
+################################################################################
+
+/.vs
+/out
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 463bb2117..69366f3a9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -70,6 +70,7 @@ if (NE_GPU)
 endif()
 option(NE_BUILD_TESTS "neural_engine: build tests" ${NE_STANDALONE})
+option(NE_BTLA_UT "enable BesTLA's unit tests" OFF)
 option(NE_BUILD_EXAMPLES "neural_engine: build examples" ${NE_STANDALONE})
 if(NE_BUILD_TESTS)
     add_compile_definitions(NE_BUILD_TESTS)
@@ -139,6 +140,11 @@ if (NE_PYTHON_API)
     add_subdirectory(third_party/pybind11)
 endif()
-add_subdirectory(bestla jblas)
+if (NE_BTLA_UT)
+    set(BTLA_UT_ALL ON)
+endif()
+include(FindOpenMP)
+
+add_subdirectory(bestla)
 add_subdirectory(neural_speed)
diff --git a/CMakePresets.json b/CMakePresets.json
index a3c8cdf25..6cca625b1 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -57,6 +57,16 @@
       "description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)",
       "inherits": "x64-debug",
       "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" }
+    },
+    {
+      "name": "x64-bestla-UT",
+      "displayName": "x64 BesTLA unit test",
+      "description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)",
+      "inherits": "x64-debug",
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release",
+        "NE_BTLA_UT": "ON"
+      }
     }
   ]
 }
diff --git a/README.md b/README.md
index 4683601e5..1be563b17 100644
--- a/README.md
+++ b/README.md
@@ -385,7 +385,7 @@ Argument description of inference.py:
 | --keep | Number of tokens to keep from the initial prompt: Int (default: 0, -1 = all) |
 | --shift-roped-k | Use [ring-buffer](./docs/infinite_inference.md#shift-rope-k-and-ring-buffer) and thus avoid re-computation after reaching ctx_size (default: False) |
 | --glm_tokenizer | The path of the chatglm tokenizer: String (default: THUDM/chatglm-6b) |
-| --memory-f32<br>--memory-f16<br>--memory-auto | Data type of kv memory (default: auto);<br>if set to auto, the runtime will try the jblas flash attn managed format (currently requires GCC11+ & AMX) and fall back to fp16 if that fails |
+| --memory-f32<br>--memory-f16<br>--memory-auto | Data type of kv memory (default: auto);<br>if set to auto, the runtime will try the bestla flash attn managed format (currently requires GCC11+ & AMX) and fall back to fp16 if that fails |
 
 ### 3. Tensor Parallelism cross nodes/sockets
diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt
index d05d1b299..c07ac66dc 100644
--- a/bestla/CMakeLists.txt
+++ b/bestla/CMakeLists.txt
@@ -1,40 +1,40 @@
-cmake_minimum_required(VERSION 3.5)
+cmake_minimum_required(VERSION 3.12)
 
-project(jblas LANGUAGES CXX VERSION 0.1.0)
+project(bestla LANGUAGES CXX VERSION 0.1.0)
 file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
 file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)
 
-option(JBLAS_UT_ALL "Enable all unit tests" OFF)
-option(JBLAS_UT_DEBUG "Enable debug unit tests" ON)
-option(JBLAS_UT_EPILOGUE "Enable unit test for epilogue" OFF)
-option(JBLAS_UT_PROLOGUE_A "Enable unit test for activation prologue" OFF)
-option(JBLAS_UT_PROLOGUE_B "Enable unit test for weight prologue" OFF)
-option(JBLAS_UT_GEMM "Enable unit test for micro gemm kernels" OFF)
-option(JBLAS_UT_WRAPPER "Enable unit test for parallel gemms" OFF)
-option(JBLAS_UT_PARALLEL "Enable unit test for parallel set" OFF)
-option(JBLAS_UT_KERNEL_JIT "Enable unit test for jit kernels" OFF)
-option(JBLAS_UT_KERNEL_INTRIN "Enable unit test for intrinsic kernels" OFF)
-option(JBLAS_UT_KERNEL_WRAPPER "Enable unit test for runtime ISA kernels" OFF)
-option(JBLAS_UT_NOASAN "Disable sanitize" OFF)
-option(JBLAS_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
-option(JBLAS_UT_OPENMP "Use OpenMP" ON)
-
-if(JBLAS_UT_ALL)
-set(JBLAS_UT_EPILOGUE ON)
-set(JBLAS_UT_PROLOGUE_A ON)
-set(JBLAS_UT_PROLOGUE_B ON)
-set(JBLAS_UT_GEMM ON)
-set(JBLAS_UT_WRAPPER ON)
-set(JBLAS_UT_PARALLEL ON)
-set(JBLAS_UT_KERNEL_JIT ON)
-set(JBLAS_UT_KERNEL_INTRIN ON)
-set(JBLAS_UT_KERNEL_WRAPPER ON)
-endif(JBLAS_UT_ALL)
+option(BTLA_UT_ALL "Enable all unit tests" OFF)
+option(BTLA_UT_DEBUG "Enable debug unit tests" ON)
+option(BTLA_UT_EPILOGUE "Enable unit test for epilogue" OFF)
+option(BTLA_UT_PROLOGUE_A "Enable unit test for activation prologue" OFF)
+option(BTLA_UT_PROLOGUE_B "Enable unit test for weight prologue" OFF)
+option(BTLA_UT_GEMM "Enable unit test for micro gemm kernels" OFF)
+option(BTLA_UT_WRAPPER "Enable unit test for parallel gemms" OFF)
+option(BTLA_UT_PARALLEL "Enable unit test for parallel set" OFF)
+option(BTLA_UT_KERNEL_JIT "Enable unit test for jit kernels" OFF)
+option(BTLA_UT_KERNEL_INTRIN "Enable unit test for intrinsic kernels" OFF)
+option(BTLA_UT_KERNEL_WRAPPER "Enable unit test for runtime ISA kernels" OFF)
+option(BTLA_UT_NOASAN "Disable sanitize" OFF)
+option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
+option(BTLA_UT_OPENMP "Use OpenMP" ON)
+
+if(BTLA_UT_ALL)
+set(BTLA_UT_EPILOGUE ON)
+set(BTLA_UT_PROLOGUE_A ON)
+set(BTLA_UT_PROLOGUE_B ON)
+set(BTLA_UT_GEMM ON)
+set(BTLA_UT_WRAPPER ON)
+set(BTLA_UT_PARALLEL ON)
+set(BTLA_UT_KERNEL_JIT ON)
+set(BTLA_UT_KERNEL_INTRIN ON)
+set(BTLA_UT_KERNEL_WRAPPER ON)
+endif(BTLA_UT_ALL)
 
 set(UT_BUILD FALSE)
-if(JBLAS_UT_DEBUG OR JBLAS_UT_PROLOGUE_A OR JBLAS_UT_PROLOGUE_B OR JBLAS_UT_EPILOGUE OR JBLAS_UT_GEMM
-OR JBLAS_UT_WRAPPER OR JBLAS_UT_PARALLEL OR JBLAS_UT_KERNEL_JIT OR JBLAS_UT_KERNEL_INTRIN
-OR JBLAS_UT_KERNEL_WRAPPER)
+if(BTLA_UT_DEBUG OR BTLA_UT_PROLOGUE_A OR BTLA_UT_PROLOGUE_B OR BTLA_UT_EPILOGUE OR BTLA_UT_GEMM
+OR BTLA_UT_WRAPPER OR BTLA_UT_PARALLEL OR BTLA_UT_KERNEL_JIT OR BTLA_UT_KERNEL_INTRIN
+OR BTLA_UT_KERNEL_WRAPPER)
 set(UT_BUILD TRUE)
 endif()
@@ -91,10 +91,7 @@ if(WIN32)
 target_link_options(${PROJECT_NAME} INTERFACE /STACK:5242880) #Stack requires up to L2 cache size
 endif(WIN32)
 
-if(JBLAS_UT_OPENMP)
-include(FindOpenMP)
-target_link_libraries(${PROJECT_NAME} INTERFACE OpenMP::OpenMP_CXX OpenMP::OpenMP_C)
-endif()
+
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -109,27 +106,31 @@ target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17)
 if(UT_BUILD)
 file(GLOB srcs ${PROJECT_NAME}/ut/*.cc ${PROJECT_NAME}/ut/*.cpp) #compile everything even when only part of the UTs run
 file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h)
+  include_directories(${PROJECT_NAME})
 add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${ut_headers})
-
+  if(BTLA_UT_OPENMP)
+    include(FindOpenMP)
+    target_link_libraries(${PROJECT_NAME}_ut PRIVATE OpenMP::OpenMP_CXX OpenMP::OpenMP_C)
+  endif()
 if(NOT WIN32)
-    if(NOT JBLAS_UT_NOASAN)
-    target_compile_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address)
-    target_link_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address)
+    if(NOT BTLA_UT_NOASAN)
+      target_compile_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address)
+      target_link_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address)
     endif()
     target_link_options(${PROJECT_NAME}_ut PRIVATE -lpthread)
 endif()
 
-  add_ut_flag(JBLAS_UT_DEBUG)
-  add_ut_flag(JBLAS_UT_EPILOGUE)
-  add_ut_flag(JBLAS_UT_PROLOGUE_A)
-  add_ut_flag(JBLAS_UT_PROLOGUE_B)
-  add_ut_flag(JBLAS_UT_GEMM)
-  add_ut_flag(JBLAS_UT_PARALLEL)
-  add_ut_flag(JBLAS_UT_WRAPPER)
-  add_ut_flag(JBLAS_UT_KERNEL_INTRIN)
-  add_ut_flag(JBLAS_UT_KERNEL_JIT)
-  add_ut_flag(JBLAS_UT_KERNEL_WRAPPER)
-  add_ut_flag(JBLAS_UT_BENCHMARK)
+  add_ut_flag(BTLA_UT_DEBUG)
+  add_ut_flag(BTLA_UT_EPILOGUE)
+  add_ut_flag(BTLA_UT_PROLOGUE_A)
+  add_ut_flag(BTLA_UT_PROLOGUE_B)
+  add_ut_flag(BTLA_UT_GEMM)
+  add_ut_flag(BTLA_UT_PARALLEL)
+  add_ut_flag(BTLA_UT_WRAPPER)
+  add_ut_flag(BTLA_UT_KERNEL_INTRIN)
+  add_ut_flag(BTLA_UT_KERNEL_JIT)
+  add_ut_flag(BTLA_UT_KERNEL_WRAPPER)
+  add_ut_flag(BTLA_UT_BENCHMARK)
 
 target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME})
 endif(UT_BUILD)
diff --git a/bestla/README.md b/bestla/README.md
index 08a29a9c9..8b46f5a9b 100644
--- a/bestla/README.md
+++ b/bestla/README.md
@@ -52,6 +52,6 @@ Compile:
 Usage:
 ```cmake
-add_subdirectory(jblas)
-target_link_libraries("${YOUR_PROJECT}" jblas::jblas)
+add_subdirectory(bestla)
+target_link_libraries("${YOUR_PROJECT}" bestla::bestla)
 ```
diff --git a/bestla/jblas/jit_blas.h b/bestla/bestla/bestla.h
similarity index 74%
rename from bestla/jblas/jit_blas.h
rename to bestla/bestla/bestla.h
index 8446698e3..890704814 100644
--- a/bestla/jblas/jit_blas.h
+++ b/bestla/bestla/bestla.h
@@ -13,26 +13,26 @@
 // limitations under the License.
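The hunks below convert the unscoped `JBLAS_*` C enums into scoped `enum class` types (`BTLA_CODE`, `BTLA_ISA`, `BTLA_DTYPE`, ...), so every call site must switch from bare enumerators such as `JblasSuccess` to qualified names. A minimal sketch of the resulting call-site style, using only enumerator values visible in this diff (the `describe` helper and the include path are illustrative assumptions, not part of the patch):

```cpp
#include "bestla/bestla.h"  // post-rename header defining BTLA_CODE (assumes -Ibestla)

// Hypothetical helper: with enum class, status values must be qualified
// (BTLA_CODE::Success instead of the old unscoped JblasSuccess) and they
// no longer convert implicitly to int.
static const char* describe(BTLA_CODE code) {
  switch (code) {
    case BTLA_CODE::Success:      return "success";
    case BTLA_CODE::InvalidParam: return "invalid parameter";
    case BTLA_CODE::InvalidISA:   return "invalid ISA";
    case BTLA_CODE::RuntimeError: return "runtime error";
    case BTLA_CODE::NotSupport:   return "not supported";
  }
  return "unknown";
}
```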
#pragma once #include -enum JBLAS_CODE { - JblasSuccess = 0, - JblasInvalidParam = 1, - JblasInvalidISA = 2, - JblasRuntimeError = 4, - JblasNotSupport = 8, +enum class BTLA_CODE { + Success = 0, + InvalidParam = 1, + InvalidISA = 2, + RuntimeError = 4, + NotSupport = 8, }; -enum JBLAS_ISA : uint8_t { - JblasNoSIMD = 0, - JblasAVX, - JblasAVX2, - JblasAVX_VNNI, - JblasAVX512F, - JblasAVX512_VNNI, - JblasAMX_BF16, - JblasAMX_INT8, - JblasAVX512_FP16, - JblasAVX512_BF16, +enum class BTLA_ISA : uint8_t { + NoSIMD = 0, + AVX, + AVX2, + AVX_VNNI, + AVX512F, + AVX512_VNNI, + AMX_BF16, + AMX_INT8, + AVX512_FP16, + AVX512_BF16, }; -enum class JBLAS_DTYPE : uint32_t { +enum class BTLA_DTYPE : uint32_t { EleBitsMask = 0xff, EleBitsShift = 0, EleBitsUndef = 0, @@ -70,15 +70,9 @@ enum class JBLAS_DTYPE : uint32_t { U32 = EleBits32 | TypeInt | SubType1, }; -enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 }; -enum JBLAS_TRANSPOSE { - JblasNoTrans = 111, - JblasTrans = 112, - JblasConjTrans = 113, -}; -enum JBLAS_ELTWISEOP { GELU, SWISH, TANH, EXP, LOW_PRECISION_EXP, RELU, LINEAR }; +enum class BTLA_ELTWISEOP { GELU, SWISH, TANH, EXP, LOW_PRECISION_EXP, RELU, LINEAR }; -enum class JBLAS_PROLOGUEB_IDS : uint32_t { +enum class BTLA_PROLOGUEB_IDS : uint32_t { Undef = (uint32_t)-1, Begin = 0, NormalBegin = Begin, diff --git a/bestla/jblas/jit_blas_device.h b/bestla/bestla/bestla_device.h similarity index 95% rename from bestla/jblas/jit_blas_device.h rename to bestla/bestla/bestla_device.h index 4e54b63d3..ebfc7d8de 100644 --- a/bestla/jblas/jit_blas_device.h +++ b/bestla/bestla/bestla_device.h @@ -15,7 +15,7 @@ #include #include #include -#include "jit_blas.h" +#include "bestla.h" #include "xbyak/xbyak_util.h" #ifdef _WIN32 #include @@ -23,7 +23,7 @@ #include #endif -namespace jblas { +namespace bestla { namespace device { @@ -195,16 +195,16 @@ class SapphireRapids { static constexpr bool AMX_COMPLEX = 0; }; -template +template class isa_base { public: - static bool constexpr avx = ISA_T >= JblasAVX; - static bool constexpr avx2 = ISA_T >= JblasAVX2; - static bool constexpr avx512f = ISA_T >= JblasAVX512F; - static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; - static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; - static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; - static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; + static bool constexpr avx = ISA_T >= BTLA_ISA::AVX; + static bool constexpr avx2 = ISA_T >= BTLA_ISA::AVX2; + static bool constexpr avx512f = ISA_T >= BTLA_ISA::AVX512F; + static bool constexpr avx512_vnni = ISA_T >= BTLA_ISA::AVX512_VNNI; + static bool constexpr avx512_fp16 = ISA_T >= BTLA_ISA::AVX512_FP16; + static bool constexpr amx_bf16 = ISA_T >= BTLA_ISA::AMX_BF16; + static bool constexpr amx_int8 = ISA_T >= BTLA_ISA::AMX_INT8; }; class CpuDevice { @@ -422,7 +422,7 @@ class CpuDevice { float P_power = 4.8, E_power = 2.3; }; -#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); +#define GetCPUDevice() auto _cd = bestla::device::CpuDevice::getInstance(); class CpuBase { public: @@ -436,4 +436,4 @@ class CpuBase { int mNumThreads; }; } // namespace device -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_epilogue.h b/bestla/bestla/bestla_epilogue.h similarity index 76% rename from bestla/jblas/jit_blas_epilogue.h rename to bestla/bestla/bestla_epilogue.h index 6c7e0ee4e..de25f330b 100644 --- a/bestla/jblas/jit_blas_epilogue.h +++ b/bestla/bestla/bestla_epilogue.h @@ -14,12 +14,12 @@ #pragma 
once #include -#include "jit_base.h" -#include "jit_blas.h" -#include "jit_blas_utils.h" +#include "bestla.h" +#include "bestla_jit.h" +#include "bestla_utils.h" #include "kernel_wrapper.h" -namespace jblas { +namespace bestla { namespace epilogue { namespace gemm { @@ -30,15 +30,15 @@ struct ParamAccumulatorWriteBack { void* elt_const_v; }; -template +template class AccumulatorWriteBack { public: using SType = _SRC_T; using DType = _DST_T; using Param = ParamAccumulatorWriteBack; - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, _param.ldc, @@ -46,12 +46,12 @@ class AccumulatorWriteBack { } }; -template +template class CustomAccumulatorWriteBackWithEltop { public: using Param = ParamAccumulatorWriteBack<_DST_T>; - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { @@ -62,26 +62,27 @@ class CustomAccumulatorWriteBackWithEltop { } } }; -template +template using AccumulatorWriteBackFp32 = AccumulatorWriteBack; -template +template using AccumulatorWriteBackInt32 = AccumulatorWriteBack; -template +template using AccumulatorWriteBackBf16 = AccumulatorWriteBack; -template +template using AccumulatorWriteBackFp16 = AccumulatorWriteBack; -template +template using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack; -template +template using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; -template +template using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; +template +using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; -template -using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; +template +using AccumulatorWriteBackWithSwishFp32 = + CustomAccumulatorWriteBackWithEltop; template struct ParamAlphaBetaProcess { @@ -89,13 +90,13 @@ struct ParamAlphaBetaProcess { int ldc, ldd; float alpha, beta; }; -template +template class AlphaBetaProcessFp32 { public: using Param = ParamAlphaBetaProcess; - JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto DOffset = M_offset * _param.ldd + N_offset; auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; @@ -107,43 +108,43 @@ class AlphaBetaProcessFp32 { struct ParamCompFp32BlockEpilogue 
{ void* scales; - JBLAS_DTYPE scaledtype; + BTLA_DTYPE scaledtype; int ldsb; int8_t* zps = nullptr; float* reduce = nullptr; int ldra; }; -template +template class CompFp32BlockEpilogue { public: using Param = ParamCompFp32BlockEpilogue; - JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - auto ret = JblasNotSupport; - if (_param.scaledtype == JBLAS_DTYPE::F32) { + BTLA_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + auto ret = BTLA_CODE::NotSupport; + if (_param.scaledtype == BTLA_DTYPE::F32) { ret = kernel::wrapper::CompFp32BlockScale::template forward( reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, cachestep, M, N); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); if (_param.zps != nullptr) { ret = kernel::wrapper::RemoveZeroPointBias::forward_wei( dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, _param.reduce + M_offset * _param.ldra + K_offset); } - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); return ret; - } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { + } else if (_param.scaledtype == BTLA_DTYPE::BF16) { ret = kernel::wrapper::CompFp32BlockScale::template forward( reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, cachestep, M, N); if (_param.zps != nullptr) { assert(0); } - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); return ret; - } else if (_param.scaledtype == JBLAS_DTYPE::F8_E8M0) { + } else if (_param.scaledtype == BTLA_DTYPE::F8_E8M0) { ret = kernel::wrapper::CompFp32BlockScale::template forward( reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, cachestep, M, N); @@ -151,7 +152,7 @@ class CompFp32BlockEpilogue { assert(0); } } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } }; @@ -162,12 +163,12 @@ struct ParamDequantInt32ToFp32 { float* scalesA; float* scalesB; }; -template +template class DequantInt32ToFp32 { public: using Param = ParamDequantInt32ToFp32; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, @@ -178,61 +179,61 @@ class DequantInt32ToFp32 { struct ParamCompInt8BlockEpilogue { void* scalesB; - JBLAS_DTYPE scaleBdtype; + BTLA_DTYPE scaleBdtype; int ldsb; float* scalesA; int ldsa; // optional if A asym uint8_t* zpA = nullptr; void* reduceB = nullptr; - JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; + BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32; // optional if B asym int8_t* zpB = nullptr; float* reduceA = nullptr; int K = 1; }; -template +template class CompInt8BlockEpilogue { public: using Param = 
ParamCompInt8BlockEpilogue; - JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - JBLAS_CODE ret = JblasNotSupport; + BTLA_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; float* scab = nullptr; size_t ScaleBTmpSize = N * sizeof(float); size_t ReduceBTmpSize = N * sizeof(float); assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); - if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { + if (_param.scaleBdtype == BTLA_DTYPE::BF16) { auto scache = reinterpret_cast(tmpcache); ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, false); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); scab = scache; - } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { + } else if (_param.scaleBdtype == BTLA_DTYPE::F32) { scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; } float* redb = nullptr; if (_param.reduceB) { - if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { + if (_param.reduceBdtype == BTLA_DTYPE::BF16) { auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, false); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); redb = rcache; - } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { + } else if (_param.reduceBdtype == BTLA_DTYPE::F32) { redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; } } ret = kernel::wrapper::DequanS32Fp32::template forward( srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, dstptr, cachestep, M, N); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); if (_param.zpA == nullptr) { if (_param.zpB == nullptr) { @@ -273,18 +274,18 @@ struct ParamZpDequantInt32ToFp32 { float* reduceA = nullptr; int K = 1; }; -template +template class ZpDequantInt32ToFp32 { public: using Param = ParamZpDequantInt32ToFp32; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, _param.scalesA + M_offset * _param.ldsa, _param.ldsa, _param.scalesB + N_offset); - if (ret != JblasSuccess) { + if (ret != BTLA_CODE::Success) { return ret; } if (_param.zpA == nullptr && _param.zpB == nullptr) { @@ -314,12 +315,12 @@ struct ParamAlphaBetaProcessS32U8 { float scaleAcc, scaleC; int zpC; }; -template +template 
class AlphaBetaProcessS32U8 { public: using Param = ParamAlphaBetaProcessS32U8; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, @@ -329,4 +330,4 @@ class AlphaBetaProcessS32U8 { } // namespace gemm } // namespace epilogue -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_gemm.h b/bestla/bestla/bestla_gemm.h similarity index 98% rename from bestla/jblas/jit_blas_gemm.h rename to bestla/bestla/bestla_gemm.h index add580da3..6829f1d0c 100644 --- a/bestla/jblas/jit_blas_gemm.h +++ b/bestla/bestla/bestla_gemm.h @@ -14,13 +14,13 @@ #pragma once #include -#include "jit_blas_utils.h" -#include "jit_base.h" +#include "bestla_utils.h" +#include "bestla_jit.h" -namespace jblas { +namespace bestla { namespace gemm { enum class CompType : uint16_t { - // base type, too many bits if reuse JBLAS_DTYPE + // base type, too many bits if reuse BTLA_DTYPE tFP32 = 0, tBF16 = 1, tFP16 = 2, @@ -87,7 +87,7 @@ class CoreAttr { static inline uint64_t get_mask_val(uint64_t raw, uint64_t mask, uint64_t shift) { return (raw & mask) >> shift; } - static constexpr uint64_t make_core_id(int NTile, int PackRow, CompType CompType, JBLAS_ISA ISA) { + static constexpr uint64_t make_core_id(int NTile, int PackRow, CompType CompType, BTLA_ISA ISA) { return (static_cast(NTile) << NTILE_SHIFT) | (static_cast(PackRow) << PACKROW_SHIFT) | (static_cast(CompType) << COMP_SHIFT) | (static_cast(ISA) << ISA_SHIFT); } @@ -114,7 +114,7 @@ class CoreAttr { return size_t(4 / packrow); } - static inline JBLAS_ISA get_ISA(uint64_t id) { return static_cast(get_mask_val(id, ISA_MASK, ISA_SHIFT)); } + static inline BTLA_ISA get_ISA(uint64_t id) { return static_cast(get_mask_val(id, ISA_MASK, ISA_SHIFT)); } static inline CompType get_comp(uint64_t id) { return static_cast(get_mask_val(id, COMP_MASK, COMP_SHIFT)); @@ -124,7 +124,7 @@ class CoreAttr { namespace code { template -class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { +class Avx2N8P1 : protected bestla::xbyak::JitAvx2 { public: static int constexpr RegLen = 8, PackRow = 1; static_assert(_NTILE % RegLen == 0); @@ -133,7 +133,7 @@ class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX2; + static auto constexpr ISA = BTLA_ISA::AVX2; static auto constexpr COMPUTE = CompType::COMP_FP32; typedef float AType; typedef float BType; @@ -344,7 +344,7 @@ class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { }; template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { +class Avx512fN16P1 : protected bestla::xbyak::JitAvx512f { public: static int constexpr RegLen = 16, PackRow = 1; static_assert(_NTILE % RegLen == 0); @@ -353,7 +353,7 @@ class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; static int constexpr KUNROLL = 2; - static auto 
constexpr ISA = JBLAS_ISA::JblasAVX512F; + static auto constexpr ISA = BTLA_ISA::AVX512F; static auto constexpr COMPUTE = CompType::COMP_FP32; typedef float AType; typedef float BType; @@ -564,7 +564,7 @@ class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { }; template -class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { +class Avx512fp16N32P1 : protected bestla::xbyak::JitAvx512_fp16 { public: static int constexpr RegLen = 32, PackRow = 1; static_assert(_NTILE % RegLen == 0); @@ -573,7 +573,7 @@ class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX512_FP16; + static auto constexpr ISA = BTLA_ISA::AVX512_FP16; static auto constexpr COMPUTE = CompType::COMP_FP16_FP16; typedef utils::fp16 AType; typedef utils::fp16 BType; @@ -784,7 +784,7 @@ class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { }; template -class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { +class Avx512bf16N16P2 : protected bestla::xbyak::JitAvx512_bf16 { public: static int constexpr RegLen = 16, PackRow = 2; static_assert(_NTILE % RegLen == 0); @@ -793,7 +793,7 @@ class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX512_BF16; + static auto constexpr ISA = BTLA_ISA::AVX512_BF16; static auto constexpr COMPUTE = CompType::COMP_BF16_FP32; typedef utils::bf16 AType; typedef utils::bf16 BType; @@ -1004,7 +1004,7 @@ class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { }; template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { +class Avx512vnniN16P4 : protected bestla::xbyak::JitAvx512vnni { public: static int constexpr RegLen = 16, PackRow = 4; static_assert(_NTILE % RegLen == 0); @@ -1013,7 +1013,7 @@ class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX512_VNNI; + static auto constexpr ISA = BTLA_ISA::AVX512_VNNI; static auto constexpr COMPUTE = CompType::COMP_INT8_US_INT32; typedef uint8_t AType; typedef int8_t BType; @@ -1223,7 +1223,7 @@ class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { }; template -class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { +class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { public: static int constexpr RegLen = 8, PackRow = 4; static_assert(_NTILE % RegLen == 0); @@ -1232,7 +1232,7 @@ class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX_VNNI; + static auto constexpr ISA = BTLA_ISA::AVX_VNNI; static auto constexpr COMPUTE = CompType::COMP_INT8_US_INT32; typedef uint8_t AType; typedef int8_t BType; @@ -1443,7 +1443,7 @@ class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { }; template -class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { +class Amxbf16N16P2 : protected bestla::xbyak::JitAmxbf16 { public: static int constexpr RegLen = 16, PackRow = 2; static_assert(_NTILE 
% RegLen == 0); @@ -1453,7 +1453,7 @@ class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { static_assert(NRegs * MRegs + 2 <= TileCount); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAMX_BF16; + static auto constexpr ISA = BTLA_ISA::AMX_BF16; static auto constexpr COMPUTE = CompType::COMP_BF16_FP32; typedef utils::bf16 AType; typedef utils::bf16 BType; @@ -1706,7 +1706,7 @@ class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { }; template -class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { +class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 { public: static int constexpr RegLen = 16, PackRow = 4; static_assert(_NTILE % RegLen == 0); @@ -1716,7 +1716,7 @@ class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { static_assert(NRegs * MRegs + 2 <= TileCount); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAMX_INT8; + static auto constexpr ISA = BTLA_ISA::AMX_INT8; static auto constexpr COMPUTE = (std::is_same_v ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 @@ -1978,7 +1978,7 @@ using Amxint8N16P4US = Amxint8N16P4; template using Amxint8N16P4SS = Amxint8N16P4; -class AmxConfigure : protected jblas::xbyak::JitAmxtile { +class AmxConfigure : protected xbyak::JitAmxtile { public: typedef long long (*func_t)(tileconfig_t*); @@ -2003,7 +2003,7 @@ namespace kblock { // optimize for kblock gemm, each block size in k dimension has dequant operation // all accumulators use fp32 dtype. template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { +class Avx512fN16P1 : protected bestla::xbyak::JitAvx512f { public: static int constexpr RegLen = 16, PackRow = 1; static_assert(_NTILE % RegLen == 0); @@ -2012,7 +2012,7 @@ class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX512F; + static auto constexpr ISA = BTLA_ISA::AVX512F; static auto constexpr COMPUTE = CompType::COMP_FP32; typedef float AType; typedef float BType; @@ -2223,7 +2223,7 @@ class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { }; template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { +class Avx512vnniN16P4 : protected bestla::xbyak::JitAvx512vnni { public: static int constexpr RegLen = 16, PackRow = 4; static_assert(_NTILE % RegLen == 0); @@ -2232,7 +2232,7 @@ class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { static_assert(NRegs * MRegs <= RegCount - 1); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX512_VNNI; + static auto constexpr ISA = BTLA_ISA::AVX512_VNNI; static auto constexpr COMPUTE = CompType::COMP_INT8_US_FP32; typedef uint8_t AType; typedef int8_t BType; @@ -2521,7 +2521,7 @@ class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { }; template -class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { +class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { public: static int constexpr RegLen = 8, PackRow = 4; static_assert(_NTILE % RegLen == 0); @@ -2530,7 +2530,7 @@ class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { static_assert(NRegs * MRegs <= RegCount - 3); static int constexpr 
NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAVX_VNNI; + static auto constexpr ISA = BTLA_ISA::AVX_VNNI; static auto constexpr COMPUTE = CompType::COMP_INT8_US_FP32; typedef uint8_t AType; typedef int8_t BType; @@ -2865,7 +2865,7 @@ class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { }; template -class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { +class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 { public: static int constexpr RegLen = 16, PackRow = 4; static_assert(_NTILE % RegLen == 0); @@ -2875,7 +2875,7 @@ class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { static_assert(NRegs * MRegs + 2 <= TileCount); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; static int constexpr KUNROLL = 2; - static auto constexpr ISA = JBLAS_ISA::JblasAMX_INT8; + static auto constexpr ISA = BTLA_ISA::AMX_INT8; static auto constexpr COMPUTE = (std::is_same_v ? std::is_same_v ? CompType::COMP_INT8_SS_FP32 : CompType::COMP_INT8_SU_FP32 : std::is_same_v ? CompType::COMP_INT8_US_FP32 @@ -3541,4 +3541,4 @@ class ICoreRowNAmxint8SSKBlock : public CoreCodeBaseAMX(const Xbyak::Tmm& x1, const Xbya tdpbuud(x1, x2, x3); } } // namespace xbyak -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_parallel.h b/bestla/bestla/bestla_parallel.h similarity index 99% rename from bestla/jblas/jit_blas_parallel.h rename to bestla/bestla/bestla_parallel.h index 5e6b3b650..56078d6d3 100644 --- a/bestla/jblas/jit_blas_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -18,10 +18,10 @@ #ifdef _OPENMP #include #endif -#include "jit_blas_utils.h" -#include "jit_blas_device.h" +#include "bestla_utils.h" +#include "bestla_device.h" -namespace jblas { +namespace bestla { namespace parallel { struct Config2D { int threads; @@ -698,4 +698,4 @@ void GemmRunWithA(Launch_T& launcher, const typename Launch_T::Param& args, para } } // namespace parallel -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_prologue_a.h b/bestla/bestla/bestla_prologue_a.h similarity index 76% rename from bestla/jblas/jit_blas_prologue_a.h rename to bestla/bestla/bestla_prologue_a.h index 8cb2b69bd..677e42203 100644 --- a/bestla/jblas/jit_blas_prologue_a.h +++ b/bestla/bestla/bestla_prologue_a.h @@ -17,15 +17,15 @@ #include #include -#include "jit_blas.h" -#include "jit_blas_device.h" -#include "jit_blas_gemm.h" -#include "jit_blas_parallel.h" -#include "jit_blas_storage.h" -#include "jit_blas_utils.h" +#include "bestla.h" +#include "bestla_device.h" +#include "bestla_gemm.h" +#include "bestla_parallel.h" +#include "bestla_storage.h" +#include "bestla_utils.h" #include "kernel_wrapper.h" -namespace jblas { +namespace bestla { namespace prologue_a { namespace gemm { @@ -34,14 +34,14 @@ struct ParamActivationBase { const AType* A; int lda; }; -template +template class ActivationBase { public: using AType = typename _GemmCore_T::AType; using SRCType = AType; using Param = ParamActivationBase; - JBLAS_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { auto aptr = const_cast(_param.A) + m_offset * _param.lda + k_offset; auto alignedptr = utils::cpu_pointer_align(aptr); bool use_rawptr = k_size % 
_GemmCore_T::KTILE == 0 && m_size >= _GemmCore_T::MTILE; @@ -49,24 +49,24 @@ class ActivationBase { if (use_rawptr) { *dstptr = aptr; *dststep = _param.lda; - return JblasSuccess; + return BTLA_CODE::Success; } else { auto k_pad = utils::padto(k_size, _GemmCore_T::KTILE); *dststep = k_pad; - return kernel::wrapper::Memcpy2D::forward(aptr, *dstptr, m_size, k_size, _param.lda, - k_pad); + return kernel::wrapper::Memcpy2D::forward(aptr, *dstptr, m_size, k_size, + _param.lda, k_pad); } } }; -template +template class ActivationConverter : public ActivationBase<_GemmCore_T, ISA_T> { public: using AType = typename _GemmCore_T::AType; using SRCType = SRC_T; using Param = ParamActivationBase; - JBLAS_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { auto aptr = const_cast(_param.A); auto k_pad = utils::padto(k_size, _GemmCore_T::KTILE); *dststep = k_pad; @@ -88,20 +88,20 @@ class ActivationConverter : public ActivationBase<_GemmCore_T, ISA_T> { } else { assert(0); } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } }; -template +template using ActivationConverterFp32 = ActivationConverter<_GemmCore_T, ISA_T, float>; -template +template using ActivationConverterBf16 = ActivationConverter<_GemmCore_T, ISA_T, utils::bf16>; template struct ParamActivationKBlockQuantize : ParamActivationBase { storage::gemm::StorageQuantActivation* quan; }; -template +template class ActivationKBlockQuantize { public: using AType = typename _GemmCore_T::AType; @@ -109,8 +109,8 @@ class ActivationKBlockQuantize { using QParam = storage::gemm::StorageQuantActivation; using SRCType = SRC_T; using Param = ParamActivationKBlockQuantize; - using Parallel = jblas::parallel::Scheduler2D; - using ThreadProblem = jblas::parallel::ThreadProblem2D; + using Parallel = parallel::Scheduler2D; + using ThreadProblem = parallel::ThreadProblem2D; inline Parallel createParallel(int nthreads, const utils::GemmProblem& prbm) { return Parallel({ @@ -125,8 +125,8 @@ class ActivationKBlockQuantize { QParam tmp; int kpad = utils::padto(k, _GemmCore_T::KTILE); int mpad = utils::padto(m, _GemmCore_T::MTILE); - tmp.resize(mpad, kpad, m, k, kblock == -1 ? kpad : kblock, JBLAS_DTYPE::U8, JBLAS_DTYPE::F32, JBLAS_DTYPE::U8, - JBLAS_DTYPE::F32, std::is_same_v, hasreduce); + tmp.resize(mpad, kpad, m, k, kblock == -1 ? 
kpad : kblock, BTLA_DTYPE::U8, BTLA_DTYPE::F32, BTLA_DTYPE::U8, + BTLA_DTYPE::F32, std::is_same_v, hasreduce); return tmp; } @@ -153,84 +153,84 @@ class ActivationKBlockQuantize { } } - JBLAS_CODE quantize(const Param& _param, int m, int k, jblas::parallel::IThreading* threading) { + BTLA_CODE quantize(const Param& _param, int m, int k, parallel::IThreading* threading) { auto paral = Parallel({threading->num_threads(), m, k, 1, _param.quan->mBlockSize}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; paral.getIndex(thdp); if (thdp.valid) run(_param, thdp); }); - return JblasSuccess; + return BTLA_CODE::Success; } public: // Runtime get by launcher - JBLAS_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { (void)m_size; (void)k_size; auto quan = _param.quan; auto aptr = quan->template APtr(); *dstptr = aptr + m_offset * quan->mKPad + k_offset; *dststep = quan->mKPad; - return JblasSuccess; + return BTLA_CODE::Success; } - JBLAS_CODE getZp(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getZp(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, int k_offset, + void* tmpcache, size_t cachesize) { auto quan = _param.quan; auto aptr = quan->template ZPtr(); if (aptr == nullptr) { // optional *dstptr = nullptr; - return JblasSuccess; + return BTLA_CODE::Success; } int kele = utils::updiv(k_size, quan->mBlockSize); *dststep = kele; kernel::ref::memcpy2d(aptr + m_offset * quan->CStep() + k_offset / quan->mBlockSize, *dstptr, m_size, kele * sizeof(AType), quan->CStep() * sizeof(AType), kele * sizeof(AType)); - return JblasSuccess; + return BTLA_CODE::Success; } - JBLAS_CODE getScale(float** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getScale(float** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { auto quan = _param.quan; auto aptr = quan->template SPtr(); int kele = utils::updiv(k_size, quan->mBlockSize); *dststep = kele; kernel::ref::memcpy2d(aptr + m_offset * quan->CStep() + k_offset / quan->mBlockSize, *dstptr, m_size, kele * sizeof(float), quan->CStep() * sizeof(float), kele * sizeof(float)); - return JblasSuccess; + return BTLA_CODE::Success; } - JBLAS_CODE getReduce(float** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getReduce(float** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { auto quan = _param.quan; auto aptr = quan->template RPtr(); int kele = utils::updiv(k_size, quan->mBlockSize); *dststep = kele; kernel::ref::memcpy2d(aptr + m_offset * quan->CStep() + k_offset / quan->mBlockSize, *dstptr, m_size, kele * sizeof(float), quan->CStep() * sizeof(float), kele * sizeof(float)); - return JblasSuccess; + return BTLA_CODE::Success; } }; -template +template using ActivationF32KBlockQuantize = ActivationKBlockQuantize<_GemmCore_T, ISA_T, float>; -template +template using 
ActivationBf16KBlockQuantize = ActivationKBlockQuantize<_GemmCore_T, ISA_T, utils::bf16>; template struct ParamActivationKBlockBase : ParamActivationBase { storage::gemm::StorageReduce* reduce; }; -template +template class ActivationKBlockBase : public ActivationConverter<_GemmCore_T, ISA_T, SRC_T> { public: using AType = typename _GemmCore_T::AType; using SType = storage::gemm::StorageReduce; using SRCType = SRC_T; using Param = ParamActivationKBlockBase; - using Parallel = jblas::parallel::Scheduler2D; - using ThreadProblem = jblas::parallel::ThreadProblem2D; + using Parallel = parallel::Scheduler2D; + using ThreadProblem = parallel::ThreadProblem2D; inline Parallel createParallel(int nthreads, const utils::GemmProblem& prbm) { return Parallel({ @@ -242,7 +242,7 @@ class ActivationKBlockBase : public ActivationConverter<_GemmCore_T, ISA_T, SRC_ } inline SType createStorage(int m, int k, int kblock) { SType tmp; - tmp.resize(m, k, kblock == -1 ? k : kblock, JBLAS_DTYPE::F32); + tmp.resize(m, k, kblock == -1 ? k : kblock, BTLA_DTYPE::F32); return tmp; } @@ -255,39 +255,39 @@ class ActivationKBlockBase : public ActivationConverter<_GemmCore_T, ISA_T, SRC_ auto thdrptr = stor->template RPtr() + blk_offset; auto ret = kernel::wrapper::ColBlockReduceSum::template forward( srcptr, _param.lda, thdp.size[0], thdp.size[1], stor->kblock, thdrptr, stor->lda); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); } } - JBLAS_CODE reduce(const Param& _param, int m, int k, int kblock, jblas::parallel::IThreading* threading) { + BTLA_CODE reduce(const Param& _param, int m, int k, int kblock, parallel::IThreading* threading) { auto paral = Parallel({threading->num_threads(), m, k, 1, kblock}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; paral.getIndex(thdp); if (thdp.valid) run(_param, thdp); }); - return JblasSuccess; + return BTLA_CODE::Success; } - JBLAS_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { return ActivationConverter<_GemmCore_T, ISA_T, SRC_T>::getActivation( dstptr, dststep, {_param.A, _param.lda}, m_size, k_size, m_offset, k_offset, tmpcache, cachesize); } - JBLAS_CODE getReduce(float** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getReduce(float** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { auto reduce = _param.reduce; auto aptr = reduce->template RPtr(); int kele = utils::updiv(k_size, reduce->kblock); *dststep = kele; kernel::ref::memcpy2d(aptr + m_offset * reduce->lda + k_offset / reduce->kblock, *dstptr, m_size, kele * sizeof(float), reduce->lda * sizeof(float), kele * sizeof(float)); - return JblasSuccess; + return BTLA_CODE::Success; } }; -template +template using ActivationKBlockBaseF32 = ActivationKBlockBase<_GemmCore_T, ISA_T, float>; template @@ -295,7 +295,7 @@ struct ParamShuffleActivationKBlockBase : ParamActivationKBlockBase { int* indices = nullptr; storage::gemm::StorageReorderActivation* reordered = nullptr; }; -template +template class ShuffleActivationKBlockBase : public ActivationKBlockBase<_GemmCore_T, ISA_T, SRC_T> { public: using AType = typename 
_GemmCore_T::AType; @@ -303,19 +303,19 @@ class ShuffleActivationKBlockBase : public ActivationKBlockBase<_GemmCore_T, ISA using RAType = storage::gemm::StorageReorderActivation; using SRCType = SRC_T; using Param = ParamShuffleActivationKBlockBase; - using Parallel = jblas::parallel::Scheduler2D; - using ThreadProblem = jblas::parallel::ThreadProblem2D; + using Parallel = parallel::Scheduler2D; + using ThreadProblem = parallel::ThreadProblem2D; inline RAType createReorderStorage(int m, int k, int kblock) { RAType tmp(_GemmCore_T::ID); int kpad = utils::padto(k, _GemmCore_T::KTILE); int mpad = utils::padto(m, _GemmCore_T::MTILE); - tmp.resize(mpad, kpad, m, k, kblock == -1 ? kpad : kblock, utils::jblas_dtype); + tmp.resize(mpad, kpad, m, k, kblock == -1 ? kpad : kblock, utils::bestla_dtype); return tmp; } inline RedType createReduceStorage(int m, int k, int kblock) { RedType tmp; - tmp.resize(m, k, kblock == -1 ? k : kblock, JBLAS_DTYPE::F32); + tmp.resize(m, k, kblock == -1 ? k : kblock, BTLA_DTYPE::F32); return tmp; } @@ -337,23 +337,23 @@ class ShuffleActivationKBlockBase : public ActivationKBlockBase<_GemmCore_T, ISA auto thdrptr = stor->template RPtr() + blk_offset; auto ret = kernel::wrapper::ColBlockReduceSum::template forward( srcptr, _param.lda, thdp.size[0], thdp.size[1], stor->kblock, thdrptr, stor->lda); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); } } } - JBLAS_CODE preprocess(const Param& _param, int m, int k, int kblock, jblas::parallel::IThreading* threading) { + BTLA_CODE preprocess(const Param& _param, int m, int k, int kblock, parallel::IThreading* threading) { auto paral = Parallel({threading->num_threads(), m, k, 1, kblock}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; paral.getIndex(thdp); run(_param, thdp); }); - return JblasSuccess; + return BTLA_CODE::Success; } - JBLAS_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, - int k_offset, void* tmpcache, size_t cachesize) { + BTLA_CODE getActivation(AType** dstptr, int* dststep, const Param& _param, int m_size, int k_size, int m_offset, + int k_offset, void* tmpcache, size_t cachesize) { if (_param.indices == nullptr) { return ActivationConverter<_GemmCore_T, ISA_T, SRC_T>::getActivation( dstptr, dststep, {_param.A, _param.lda}, m_size, k_size, m_offset, k_offset, tmpcache, cachesize); @@ -365,7 +365,7 @@ class ShuffleActivationKBlockBase : public ActivationKBlockBase<_GemmCore_T, ISA } }; -template +template using ShuffleActivationKBlockBaseF32 = ShuffleActivationKBlockBase<_GemmCore_T, ISA_T, float>; template @@ -373,7 +373,7 @@ struct ParamShuffleActivationKBlockQuantize : ParamActivationKBlockQuantize +template class ShuffleActivationKBlockQuantize : public ActivationKBlockQuantize<_GemmCore_T, ISA_T, SRC_T> { public: using AType = typename _GemmCore_T::AType; @@ -382,15 +382,15 @@ class ShuffleActivationKBlockQuantize : public ActivationKBlockQuantize<_GemmCor using RAType = storage::gemm::StorageReorderActivation; using SRCType = SRC_T; using Param = ParamShuffleActivationKBlockQuantize; - using Parallel = jblas::parallel::Scheduler2D; - using ThreadProblem = jblas::parallel::ThreadProblem2D; + using Parallel = parallel::Scheduler2D; + using ThreadProblem = parallel::ThreadProblem2D; inline QParam createQuantStorage(int m, int k, int kblock, bool hasreduce) { QParam tmp; int kpad = utils::padto(k, _GemmCore_T::KTILE); int mpad = utils::padto(m, _GemmCore_T::MTILE); - tmp.resize(mpad, kpad, m, k, 
kblock == -1 ? kpad : kblock, JBLAS_DTYPE::U8, JBLAS_DTYPE::F32, JBLAS_DTYPE::U8, - JBLAS_DTYPE::F32, std::is_same_v, hasreduce); + tmp.resize(mpad, kpad, m, k, kblock == -1 ? kpad : kblock, BTLA_DTYPE::U8, BTLA_DTYPE::F32, BTLA_DTYPE::U8, + BTLA_DTYPE::F32, std::is_same_v, hasreduce); return tmp; } @@ -398,11 +398,11 @@ class ShuffleActivationKBlockQuantize : public ActivationKBlockQuantize<_GemmCor RAType tmp(_GemmCore_T::ID); int kpad = utils::padto(k, _GemmCore_T::KTILE); int mpad = utils::padto(m, _GemmCore_T::MTILE); - tmp.resize(mpad, kpad, m, k, kblock == -1 ? kpad : kblock, utils::jblas_dtype); + tmp.resize(mpad, kpad, m, k, kblock == -1 ? kpad : kblock, utils::bestla_dtype); return tmp; } - JBLAS_CODE quantize(const Param& _param, int m, int k, jblas::parallel::IThreading* threading) { + BTLA_CODE quantize(const Param& _param, int m, int k, parallel::IThreading* threading) { auto srcptr = const_cast(_param.A); if (_param.indices) { auto shuffle_src = _param.reordered->template APtr(); @@ -416,12 +416,12 @@ class ShuffleActivationKBlockQuantize : public ActivationKBlockQuantize<_GemmCor srcptr = shuffle_src; } ActivationKBlockQuantize<_GemmCore_T, ISA_T, SRC_T>::quantize({srcptr, k, _param.quan}, m, k, threading); - return JblasSuccess; + return BTLA_CODE::Success; } }; -template +template using ShuffleActivationKBlockQuantizeF32 = ShuffleActivationKBlockQuantize<_GemmCore_T, ISA_T, float>; } // namespace gemm } // namespace prologue_a -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_prologue_b.h b/bestla/bestla/bestla_prologue_b.h similarity index 80% rename from bestla/jblas/jit_blas_prologue_b.h rename to bestla/bestla/bestla_prologue_b.h index 95371744f..84313000a 100644 --- a/bestla/jblas/jit_blas_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -13,23 +13,23 @@ // limitations under the License. 
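The prologue code above repeatedly pads shapes to tile multiples with `utils::padto` (e.g. `utils::padto(k_size, _GemmCore_T::KTILE)`) and counts per-block elements with `utils::updiv` (e.g. `utils::updiv(k_size, quan->mBlockSize)`). A minimal stand-alone sketch of the semantics these helpers are assumed to have, reimplemented here for illustration rather than taken from the library:

```cpp
#include <cassert>

// Assumed semantics of the bestla::utils helpers used in the prologues.
static int updiv(int a, int b) { return (a + b - 1) / b; }  // ceil(a / b)
static int padto(int a, int b) { return updiv(a, b) * b; }  // round a up to a multiple of b

int main() {
  assert(updiv(100, 32) == 4);  // K = 100 in blocks of 32 -> 4 scale entries per row
  assert(padto(10, 4) == 12);   // K = 10 padded to KTILE = 4 -> 12
  return 0;
}
```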
#pragma once #include -#include "jblas/jit_blas_utils.h" -#include "jit_blas_storage.h" -#include "jit_blas_device.h" -#include "jit_blas_parallel.h" +#include "bestla_utils.h" +#include "bestla_storage.h" +#include "bestla_device.h" +#include "bestla_parallel.h" #include "kernel_wrapper.h" -namespace jblas { +namespace bestla { namespace prologue_b { namespace gemm { -template +template static inline void transposeWeight(const int Row, const int Col, const WT* src, const int ld_src, WT* dst, const int ld_dst, parallel::IThreading* threading) { - jblas::parallel::Scheduler2D _para; + bestla::parallel::Scheduler2D _para; _para.update({threading->num_threads(), Row, Col, 16, 16}); threading->parallel_for([&](int tidx) { - jblas::parallel::ThreadProblem2D thdp{tidx}; + bestla::parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); if (thdp.valid) { kernel::wrapper::Transpose2D::template forward(src + thdp.loc[0] * ld_src + thdp.loc[1], @@ -45,7 +45,7 @@ struct ParamWeightPack { storage::gemm::StoragePackedWeight* packedW; }; -template +template class WeightPack { public: using WType = typename _GemmCore_T::BType; @@ -56,7 +56,7 @@ class WeightPack { int KPad = utils::padto(k, _GemmCore_T::KTILE); int NPad = utils::padto(n, _GemmCore_T::NTILE); StorageType tmp(_GemmCore_T::ID); - tmp.resize(NPad, KPad, n, k, utils::jblas_dtype); + tmp.resize(NPad, KPad, n, k, utils::bestla_dtype); return tmp; } @@ -88,20 +88,20 @@ class WeightPack { using PaddingInterleaveMNWType = kernel::wrapper::PaddingInterleaveMN<_GemmCore_T::NTILE, _GemmCore_T::PACK_ROW>; auto ret = PaddingInterleaveMNWType::template forward( // src, dst, thdp.size[0], thdp.size[1], rowpadded, colpadded, _param.ldb, packedw->mKPad); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); (void)ret; } - inline JBLAS_CODE getWeight(WType** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param param, void* tmpcache, size_t cachesize) { + inline BTLA_CODE getWeight(WType** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param param, void* tmpcache, size_t cachesize) { auto wptr = param.packedW; auto KPad = wptr->mKPad; auto bptr = wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE; - kernel::wrapper::Memcpy2D::template forward( + kernel::wrapper::Memcpy2D::template forward( bptr, *dstptr, n_size / _GemmCore_T::NTILE, _GemmCore_T::NTILE * k_size, _GemmCore_T::NTILE * KPad, _GemmCore_T::NTILE * k_size); *dststep = k_size; - return JblasSuccess; + return BTLA_CODE::Success; } }; @@ -109,14 +109,14 @@ struct ParamWeightKBlockNInteger { storage::gemm::StorageWeightKBlockNInteger* packedW; }; -template +template class WeightKBlockNInteger { public: using StorageWeight = storage::gemm::StorageWeightKBlockNInteger; using BType = typename _GemmCore_T::BType; using Param = ParamWeightKBlockNInteger; - static StorageWeight createStorage(int n, int k, int blocksize, JBLAS_DTYPE qtype, JBLAS_DTYPE scat, JBLAS_DTYPE redt, + static StorageWeight createStorage(int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE scat, BTLA_DTYPE redt, bool is_asym) { int KPad = utils::padto(k, _GemmCore_T::KTILE); int NPad = utils::padto(n, _GemmCore_T::NTILE); @@ -208,7 +208,7 @@ class WeightKBlockNInteger { int rawnk_scale = utils::updiv(K, stor->mBlockSize); int nk_scale = utils::updiv(stor->mKPad, stor->mBlockSize); parallel::Scheduler2D _para({threading->num_threads(), 1, nk_scale, 1, 1}); - if (stor->SDtype() == JBLAS_DTYPE::F32) { // fp32 to fp32 direct copy 
+ if (stor->SDtype() == BTLA_DTYPE::F32) { // fp32 to fp32 direct copy threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); @@ -229,7 +229,7 @@ class WeightKBlockNInteger { } } }); - } else if (stor->SDtype() == JBLAS_DTYPE::BF16) { + } else if (stor->SDtype() == BTLA_DTYPE::BF16) { threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); @@ -254,7 +254,7 @@ class WeightKBlockNInteger { } } }); - } else if (stor->SDtype() == JBLAS_DTYPE::F8_E8M0) { + } else if (stor->SDtype() == BTLA_DTYPE::F8_E8M0) { threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); @@ -304,7 +304,7 @@ class WeightKBlockNInteger { int rawnk_scale = utils::updiv(K, stor->mBlockSize); int nk_scale = utils::updiv(stor->mKPad, stor->mBlockSize); parallel::Scheduler2D _para({threading->num_threads(), 1, nk_scale, 1, 1}); - if (stor->SDtype() == JBLAS_DTYPE::F32) { // fp32 to fp32 direct copy + if (stor->SDtype() == BTLA_DTYPE::F32) { // fp32 to fp32 direct copy threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); @@ -322,7 +322,7 @@ class WeightKBlockNInteger { } } }); - } else if (stor->SDtype() == JBLAS_DTYPE::BF16) { + } else if (stor->SDtype() == BTLA_DTYPE::BF16) { threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); @@ -340,7 +340,7 @@ class WeightKBlockNInteger { } } }); - } else if (stor->SDtype() == JBLAS_DTYPE::F8_E8M0) { + } else if (stor->SDtype() == BTLA_DTYPE::F8_E8M0) { threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); @@ -382,8 +382,8 @@ class WeightKBlockNInteger { void packQWeight(const int N, const int K, const int8_t* B, const int ldb, const float* scales, const int8_t* zero_points, StorageWeight* stor, parallel::IThreading* threading) { setQuantCorrection(N, K, zero_points, scales, stor, threading); - if (stor->mDType == JBLAS_DTYPE::S8 || stor->mDType == JBLAS_DTYPE::F8_E4M3 || - stor->mDType == JBLAS_DTYPE::F8_E5M2) { + if (stor->mDType == BTLA_DTYPE::S8 || stor->mDType == BTLA_DTYPE::F8_E4M3 || + stor->mDType == BTLA_DTYPE::F8_E5M2) { reorderWeight(N, K, B, ldb, stor->WPtr(), threading); } else { auto reorded = utils::amalloc((size_t)stor->mKPad * stor->mNPad); @@ -453,10 +453,10 @@ class WeightKBlockNInteger { if (stor->HasReduce()) { auto deq = utils::amalloc((size_t)stor->mK * stor->mN); unpackWeight(stor->mN, stor->mK, stor, deq, stor->mN, threading); - if (stor->RDtype() == JBLAS_DTYPE::F32) { + if (stor->RDtype() == BTLA_DTYPE::F32) { reduce(stor->mN, stor->mK, stor->mBlockSize, deq, stor->mN, stor->template RPtr(), stor->CStep(), threading); - } else if (stor->RDtype() == JBLAS_DTYPE::BF16) { + } else if (stor->RDtype() == BTLA_DTYPE::BF16) { reduce(stor->mN, stor->mK, stor->mBlockSize, deq, stor->mN, stor->template RPtr(), stor->CStep(), threading); } else { @@ -498,14 +498,14 @@ class WeightKBlockNInteger { kernel::wrapper::PaddingInterleaveMN<_GemmCore_T::NTILE, _GemmCore_T::PACK_ROW>; auto ret = PaddingInterleaveMNWType::template forward( // src, dst, thdp.size[0], thdp.size[1], rowpadded, colpadded, ldb, KPad); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); (void)ret; } }); } - static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, - JBLAS_DTYPE qtype, parallel::IThreading* threading) { + static void compressWeight(const int N, const int K, const 
int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype, + parallel::IThreading* threading) { parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp({tidx}); @@ -513,7 +513,7 @@ class WeightKBlockNInteger { if (thdp.valid) { auto ret = doCompress(B + thdp.loc[0] * ldb + thdp.loc[1], dstptr + thdp.loc[0] * ldb / 2 + thdp.loc[1] / 2, thdp.size[0], thdp.size[1], ldb, ldb, qtype); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); (void)ret; } }); @@ -534,7 +534,7 @@ class WeightKBlockNInteger { int rowremain = utils::remainsize(thdp.loc[0] + i, K, KBlock); auto ret = RowReduceSum::template forward( // src + i * ldb, ldb, rowremain, thdp.size[1], dst + i / KBlock * ldr); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); (void)ret; } } @@ -542,123 +542,123 @@ class WeightKBlockNInteger { } public: - virtual inline JBLAS_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { + virtual inline BTLA_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { return getFpWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } - virtual inline JBLAS_CODE getWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + virtual inline BTLA_CODE getWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, + int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { return getFpWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } - static inline JBLAS_CODE getWeight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getWeight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; - if (wptr->mDType == JBLAS_DTYPE::S8) { + if (wptr->mDType == BTLA_DTYPE::S8) { return getQ8Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S4_CLIP || wptr->mDType == JBLAS_DTYPE::S4_FULLRANGE) { + } else if (wptr->mDType == BTLA_DTYPE::S4_CLIP || wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else { assert(0); } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } - static inline JBLAS_CODE getKBlockWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getKBlockWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, + int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { return getFpKBlockWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } - static inline JBLAS_CODE getKBlockWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE 
getKBlockWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, + int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { return getFpKBlockWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } - static inline JBLAS_CODE getKBlockWeight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getKBlockWeight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, + int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { return getWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } - static inline JBLAS_CODE getScale(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getScale(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; - if (wptr->SDtype() == JBLAS_DTYPE::F32) { + if (wptr->SDtype() == BTLA_DTYPE::F32) { auto aptr = wptr->template SPtr(); - kernel::wrapper::Memcpy2D::template forward( + kernel::wrapper::Memcpy2D::template forward( aptr + k_offset / wptr->mBlockSize * wptr->CStep() + n_offset, *dstptr, utils::updiv(k_size, wptr->mBlockSize), n_size, wptr->CStep(), n_size); *dststep = n_size; } - if (wptr->SDtype() == JBLAS_DTYPE::BF16) { + if (wptr->SDtype() == BTLA_DTYPE::BF16) { auto aptr = wptr->template SPtr(); kernel::wrapper::Memcpy2DBf16CvtFp32::forward( aptr + k_offset / wptr->mBlockSize * wptr->CStep() + n_offset, *dstptr, utils::updiv(k_size, wptr->mBlockSize), n_size, wptr->CStep() * 2, n_size * 4, false); *dststep = n_size; } - return JblasSuccess; + return BTLA_CODE::Success; } - static inline JBLAS_CODE getReduce(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getReduce(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; - if (wptr->RDtype() == JBLAS_DTYPE::F32) { + if (wptr->RDtype() == BTLA_DTYPE::F32) { auto aptr = wptr->template RPtr(); - kernel::wrapper::Memcpy2D::template forward( + kernel::wrapper::Memcpy2D::template forward( aptr + k_offset / wptr->mBlockSize * wptr->CStep() + n_offset, *dstptr, utils::updiv(k_size, wptr->mBlockSize), n_size, wptr->CStep(), n_size); *dststep = n_size; } - if (wptr->RDtype() == JBLAS_DTYPE::BF16) { + if (wptr->RDtype() == BTLA_DTYPE::BF16) { auto aptr = wptr->template RPtr(); kernel::wrapper::Memcpy2DBf16CvtFp32::forward( aptr + k_offset / wptr->mBlockSize * wptr->CStep() + n_offset, *dstptr, utils::updiv(k_size, wptr->mBlockSize), n_size, wptr->CStep() * 2, n_size * 4, false); *dststep = n_size; } - return JblasSuccess; + return BTLA_CODE::Success; } protected: template - static inline JBLAS_CODE getFpKBlockWeight(T** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getFpKBlockWeight(T** dstptr, int* dststep, int k_size, int n_size, int k_offset, + int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; auto NPad = wptr->mNPad; auto KPad = wptr->mKPad; int constexpr 
ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if (wptr->SDtype() == JBLAS_DTYPE::F32) { + if (wptr->SDtype() == BTLA_DTYPE::F32) { auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == JBLAS_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S8) { + } else if (wptr->mDType == BTLA_DTYPE::S8) { kernel::wrapper::DecompressKBlockS8S8Fp::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); } - } else if (wptr->SDtype() == JBLAS_DTYPE::BF16) { + } else if (wptr->SDtype() == BTLA_DTYPE::BF16) { auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == JBLAS_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S8) { + } else if (wptr->mDType == BTLA_DTYPE::S8) { kernel::wrapper::DecompressKBlockS8S8Fp::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); @@ -666,37 +666,37 @@ class WeightKBlockNInteger { } } *dststep = k_size; - return JblasSuccess; + return BTLA_CODE::Success; } template - static inline JBLAS_CODE getFpWeight(_T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getFpWeight(_T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; auto NPad = wptr->mNPad; auto KPad = wptr->mKPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { auto zptr = wptr->template ZPtr(); - if (wptr->SDtype() == JBLAS_DTYPE::F32) { + if 
(wptr->SDtype() == BTLA_DTYPE::F32) { auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == JBLAS_DTYPE::S4_CLIP) { + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::S4_CLIP>( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S4_FULLRANGE) { + } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::S4_FULLRANGE>( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S8) { + } else if (wptr->mDType == BTLA_DTYPE::S8) { kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, @@ -705,25 +705,25 @@ class WeightKBlockNInteger { } else { assert(0); } - } else if (wptr->SDtype() == JBLAS_DTYPE::BF16) { + } else if (wptr->SDtype() == BTLA_DTYPE::BF16) { auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == JBLAS_DTYPE::S4_CLIP) { + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::S4_CLIP>( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S4_FULLRANGE) { + } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::S4_FULLRANGE>( wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? 
zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::S8) { + } else if (wptr->mDType == BTLA_DTYPE::S8) { kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, @@ -735,69 +735,69 @@ class WeightKBlockNInteger { } } *dststep = k_size; - return JblasSuccess; + return BTLA_CODE::Success; } - static inline JBLAS_CODE getQ8Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getQ8Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; auto KPad = wptr->mKPad; auto bptr = wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE; - kernel::wrapper::Memcpy2D::template forward( + kernel::wrapper::Memcpy2D::template forward( bptr, *dstptr, n_size / _GemmCore_T::NTILE, _GemmCore_T::NTILE * k_size, _GemmCore_T::NTILE * KPad, _GemmCore_T::NTILE * k_size); *dststep = k_size; - return JblasSuccess; + return BTLA_CODE::Success; } - static inline JBLAS_CODE getQ4Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getQ4Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; auto KPad = wptr->mKPad; auto bptr = wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if (wptr->mDType == JBLAS_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8::template forward( + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { + kernel::wrapper::DecompressKBlockS4S8::template forward( bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize); - } else if (wptr->mDType == JBLAS_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8::template forward( + } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { + kernel::wrapper::DecompressKBlockS4S8::template forward( bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize); } } *dststep = k_size; - return JblasSuccess; + return BTLA_CODE::Success; } virtual inline void quantRowBlock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, void* stor) { auto ptr = reinterpret_cast(stor); auto quant_dtype = ptr->mDType; - if (quant_dtype == JBLAS_DTYPE::S8) { - kernel::wrapper::QuantizeSignIntRowBlock::forward( + if (quant_dtype == BTLA_DTYPE::S8) { + kernel::wrapper::QuantizeSignIntRowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); - } else if (quant_dtype == JBLAS_DTYPE::S4_FULLRANGE) { - kernel::wrapper::QuantizeSignIntRowBlock::forward( + } else if (quant_dtype == BTLA_DTYPE::S4_FULLRANGE) { + kernel::wrapper::QuantizeSignIntRowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, 
ptr->mBlockSize); - } else if (quant_dtype == JBLAS_DTYPE::S4_CLIP) { - kernel::wrapper::QuantizeSignIntRowBlock::forward( + } else if (quant_dtype == BTLA_DTYPE::S4_CLIP) { + kernel::wrapper::QuantizeSignIntRowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); } } - static inline JBLAS_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst, - JBLAS_DTYPE quant_dtype) { - if (quant_dtype == JBLAS_DTYPE::S4_CLIP || quant_dtype == JBLAS_DTYPE::S4_FULLRANGE) { + static inline BTLA_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst, + BTLA_DTYPE quant_dtype) { + if (quant_dtype == BTLA_DTYPE::S4_CLIP || quant_dtype == BTLA_DTYPE::S4_FULLRANGE) { return kernel::wrapper::CompressS8S4<_GemmCore_T::NTILE>::template forward( srcptr, reinterpret_cast(dstptr), row, col, ld_src, ld_dst); - } else if (quant_dtype == JBLAS_DTYPE::F4_BNB || quant_dtype == JBLAS_DTYPE::F4_NF4 || - quant_dtype == JBLAS_DTYPE::F4_E2M1) { + } else if (quant_dtype == BTLA_DTYPE::F4_BNB || quant_dtype == BTLA_DTYPE::F4_NF4 || + quant_dtype == BTLA_DTYPE::F4_E2M1) { return kernel::wrapper::CompressFp4<_GemmCore_T::NTILE>::template forward( srcptr, reinterpret_cast(dstptr), row, col, ld_src, ld_dst); // ld_dst here not stride } else { assert(0); - return JblasNotSupport; + return BTLA_CODE::NotSupport; } } }; @@ -806,13 +806,13 @@ struct ParamWeightKBlockNFloat { storage::gemm::StorageWeightKBlockNFloat* packedW; }; -template +template class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> { public: using Param = ParamWeightKBlockNInteger; // NFloat storage Param same with NInteger storage. using StorageWeight = storage::gemm::StorageWeightKBlockNFloat; - StorageWeight createStorage(const int N, const int K, int blocksize, JBLAS_DTYPE fT, JBLAS_DTYPE scaT) { + StorageWeight createStorage(const int N, const int K, int blocksize, BTLA_DTYPE fT, BTLA_DTYPE scaT) { int KPad = utils::padto(K, _GemmCore_T::KTILE); int NPad = utils::padto(N, _GemmCore_T::NTILE); StorageWeight tmp(_GemmCore_T::ID); @@ -820,81 +820,81 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> { return tmp; } - inline JBLAS_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) override { + inline BTLA_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) override { return getFpWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } - inline JBLAS_CODE getWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) override { + inline BTLA_CODE getWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) override { return getFpWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } template - inline JBLAS_CODE getFpWeight(_DST_T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { + inline BTLA_CODE getFpWeight(_DST_T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = 
reinterpret_cast(_param.packedW); auto NPad = wptr->mNPad; auto KPad = wptr->mKPad; char* bptr; - if (wptr->mDType == JBLAS_DTYPE::F8_E5M2 || wptr->mDType == JBLAS_DTYPE::F8_E4M3) { + if (wptr->mDType == BTLA_DTYPE::F8_E5M2 || wptr->mDType == BTLA_DTYPE::F8_E4M3) { bptr = wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE; } else { bptr = wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; } int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if (wptr->SDtype() == JBLAS_DTYPE::F8_E8M0) { - assert(wptr->mDType == JBLAS_DTYPE::F8_E4M3 || wptr->mDType == JBLAS_DTYPE::F8_E5M2); + if (wptr->SDtype() == BTLA_DTYPE::F8_E8M0) { + assert(wptr->mDType == BTLA_DTYPE::F8_E4M3 || wptr->mDType == BTLA_DTYPE::F8_E5M2); auto sptr = wptr->template SPtr() + n_offset + i; kernel::wrapper::DecompressKBlockF8FP<_GemmCore_T::PACK_ROW>::template forward( reinterpret_cast(bptr) + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mDType); - } else if (wptr->SDtype() == JBLAS_DTYPE::F32) { + } else if (wptr->SDtype() == BTLA_DTYPE::F32) { auto sptr = wptr->template SPtr() + n_offset + i; auto f4ptr = reinterpret_cast(bptr + i * KPad / 2); auto fp_ptr = *dstptr + i * k_size; - if (wptr->mDType == JBLAS_DTYPE::F8_E4M3 || wptr->mDType == JBLAS_DTYPE::F8_E5M2) { + if (wptr->mDType == BTLA_DTYPE::F8_E4M3 || wptr->mDType == BTLA_DTYPE::F8_E5M2) { kernel::wrapper::DecompressKBlockF8FP<_GemmCore_T::PACK_ROW>::template forward( reinterpret_cast(bptr) + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mDType); - } else if (wptr->mDType == JBLAS_DTYPE::F4_NF4) { + } else if (wptr->mDType == BTLA_DTYPE::F4_NF4) { kernel::wrapper::DecompressKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::F4_NF4>( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::F4_E2M1) { + } else if (wptr->mDType == BTLA_DTYPE::F4_E2M1) { kernel::wrapper::DecompressKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::F4_E2M1>( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::F4_BNB) { + } else if (wptr->mDType == BTLA_DTYPE::F4_BNB) { kernel::wrapper::DecompressKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::F4_BNB>( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else { assert(0); } - } else if (wptr->SDtype() == JBLAS_DTYPE::BF16) { + } else if (wptr->SDtype() == BTLA_DTYPE::BF16) { auto sptr = wptr->template SPtr() + n_offset + i; auto f4ptr = reinterpret_cast(bptr + i * KPad / 2); auto fp_ptr = *dstptr + i * k_size; - if (wptr->mDType == JBLAS_DTYPE::F4_NF4) { + if (wptr->mDType == BTLA_DTYPE::F4_NF4) { kernel::wrapper::DecompressKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::F4_NF4>( 
f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::F4_E2M1) { + } else if (wptr->mDType == BTLA_DTYPE::F4_E2M1) { kernel::wrapper::DecompressKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::F4_E2M1>( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::F4_BNB) { + } else if (wptr->mDType == BTLA_DTYPE::F4_BNB) { kernel::wrapper::DecompressKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward( + BTLA_DTYPE::F4_BNB>( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else { @@ -905,38 +905,38 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> { } } *dststep = k_size; - return JblasSuccess; + return BTLA_CODE::Success; } template - static inline JBLAS_CODE getKBlockWeight(T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { + static inline BTLA_CODE getKBlockWeight(T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = reinterpret_cast(_param.packedW); auto NPad = wptr->mNPad; auto KPad = wptr->mKPad; char* bptr; - if (wptr->mDType == JBLAS_DTYPE::F8_E4M3 || wptr->mDType == JBLAS_DTYPE::F8_E5M2) { + if (wptr->mDType == BTLA_DTYPE::F8_E4M3 || wptr->mDType == BTLA_DTYPE::F8_E5M2) { bptr = wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE; } else { bptr = wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; } int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if (wptr->mDType == JBLAS_DTYPE::F8_E4M3 || wptr->mDType == JBLAS_DTYPE::F8_E5M2) { + if (wptr->mDType == BTLA_DTYPE::F8_E4M3 || wptr->mDType == BTLA_DTYPE::F8_E5M2) { kernel::wrapper::DecompressKBlockF8FpNoScale::template forward( reinterpret_cast(bptr) + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize, wptr->mDType); } else { auto f4ptr = reinterpret_cast(bptr + i * KPad / 2); auto fp_ptr = *dstptr + i * k_size; - if (wptr->mDType == JBLAS_DTYPE::F4_NF4) { - kernel::wrapper::DecompressKBlockF4FpNoscale::template forward( + if (wptr->mDType == BTLA_DTYPE::F4_NF4) { + kernel::wrapper::DecompressKBlockF4FpNoscale::template forward( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::F4_E2M1) { - kernel::wrapper::DecompressKBlockF4FpNoscale::template forward( + } else if (wptr->mDType == BTLA_DTYPE::F4_E2M1) { + kernel::wrapper::DecompressKBlockF4FpNoscale::template forward( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == JBLAS_DTYPE::F4_BNB) { - kernel::wrapper::DecompressKBlockF4FpNoscale::template forward( + } else if (wptr->mDType == BTLA_DTYPE::F4_BNB) { + kernel::wrapper::DecompressKBlockF4FpNoscale::template forward( f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, 
ColSize, tmpcache, cachesize); } else { assert(0); @@ -944,7 +944,7 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> { } } *dststep = k_size; - return JblasSuccess; + return BTLA_CODE::Success; } protected: @@ -952,20 +952,20 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> { int8_t* zero_points, void* stor) override { auto ptr = reinterpret_cast(stor); auto quant_dtype = ptr->mDType; - if (quant_dtype == JBLAS_DTYPE::F8_E4M3) { - kernel::wrapper::QuantizeF8RowBlock::forward( + if (quant_dtype == BTLA_DTYPE::F8_E4M3) { + kernel::wrapper::QuantizeF8RowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, ptr->mBlockSize, ptr->SDtype()); - } else if (quant_dtype == JBLAS_DTYPE::F8_E5M2) { - kernel::wrapper::QuantizeF8RowBlock::forward( + } else if (quant_dtype == BTLA_DTYPE::F8_E5M2) { + kernel::wrapper::QuantizeF8RowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, ptr->mBlockSize, ptr->SDtype()); - } else if (quant_dtype == JBLAS_DTYPE::F4_BNB) { - kernel::wrapper::QuantizeF4RowBlock::forward(srcptr, dstptr, row, col, ld_src, ld_dst, + } else if (quant_dtype == BTLA_DTYPE::F4_BNB) { + kernel::wrapper::QuantizeF4RowBlock::forward(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); - } else if (quant_dtype == JBLAS_DTYPE::F4_E2M1) { - kernel::wrapper::QuantizeF4RowBlock::forward( + } else if (quant_dtype == BTLA_DTYPE::F4_E2M1) { + kernel::wrapper::QuantizeF4RowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); - } else if (quant_dtype == JBLAS_DTYPE::F4_NF4) { - kernel::wrapper::QuantizeF4RowBlock::forward(srcptr, dstptr, row, col, ld_src, ld_dst, + } else if (quant_dtype == BTLA_DTYPE::F4_NF4) { + kernel::wrapper::QuantizeF4RowBlock::forward(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); } else { assert(0); @@ -974,4 +974,4 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> { }; } // namespace gemm } // namespace prologue_b -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_storage.h b/bestla/bestla/bestla_storage.h similarity index 84% rename from bestla/jblas/jit_blas_storage.h rename to bestla/bestla/bestla_storage.h index e3c49aa4d..5f4993f72 100644 --- a/bestla/jblas/jit_blas_storage.h +++ b/bestla/bestla/bestla_storage.h @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
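Migration note before the storage header continues: the prologue hunks above are a mechanical rename, but two spellings change shape. JBLAS_DTYPE becomes BTLA_DTYPE (values such as S4_CLIP and S4_FULLRANGE are unchanged), and the bare status constants JblasSuccess/JblasNotSupport become members of the scoped enum BTLA_CODE. A minimal caller-side sketch of the new spellings (check_quant_dtype is a hypothetical illustration, not part of this patch):

    #include "bestla.h"  // was: #include "jblas/jit_blas.h"

    // Hypothetical helper showing the renamed dtype and status enums,
    // mirroring the dispatch style used by the prologue code above.
    static BTLA_CODE check_quant_dtype(BTLA_DTYPE dt) {
      if (dt == BTLA_DTYPE::S4_CLIP || dt == BTLA_DTYPE::S4_FULLRANGE) {
        return BTLA_CODE::Success;   // was: return JblasSuccess;
      }
      return BTLA_CODE::NotSupport;  // was: return JblasNotSupport;
    }
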
#pragma once -#include "jit_base.h" -#include "jit_blas.h" -#include "jit_blas_gemm.h" -#include "jit_blas_utils.h" +#include "bestla.h" +#include "bestla_gemm.h" +#include "bestla_utils.h" -namespace jblas { +namespace bestla { namespace storage { constexpr size_t Alignment = 64; @@ -154,7 +153,7 @@ class ObjectQuantCorrection : public ISerialObject { public: size_t mCSize = 0; int mCStep = 0; - JBLAS_DTYPE mScaT = JBLAS_DTYPE::F32, mZpT = JBLAS_DTYPE::F32, mRedT = JBLAS_DTYPE::F32; + BTLA_DTYPE mScaT = BTLA_DTYPE::F32, mZpT = BTLA_DTYPE::F32, mRedT = BTLA_DTYPE::F32; ObjectAlignedBuffer mScaleBuf; ObjectOptionalBuffer mZpBuf, mRedBuf; @@ -162,7 +161,7 @@ class ObjectQuantCorrection : public ISerialObject { public: int mScaEleSize = 0, mZpEleSize = 0, mRedEleSize = 0; - size_t resize(int Rows, int Step, JBLAS_DTYPE scalet, JBLAS_DTYPE zpt, JBLAS_DTYPE redt, bool _is_asym, + size_t resize(int Rows, int Step, BTLA_DTYPE scalet, BTLA_DTYPE zpt, BTLA_DTYPE redt, bool _is_asym, bool _has_reduce) { mScaT = scalet; mZpT = zpt; @@ -199,16 +198,16 @@ class ObjectQuantCorrection : public ISerialObject { } virtual void deserializeBuffer(int8_t*& rptr, bool locate_buf) override { if (!locate_buf) { - mScaT = utils::deserialize(rptr); - mZpT = utils::deserialize(rptr); - mRedT = utils::deserialize(rptr); + mScaT = utils::deserialize(rptr); + mZpT = utils::deserialize(rptr); + mRedT = utils::deserialize(rptr); updateSize(); mCStep = utils::deserialize(rptr); mCSize = utils::deserialize(rptr); } else { - utils::serialize(rptr, mScaT); - utils::serialize(rptr, mZpT); - utils::serialize(rptr, mRedT); + utils::serialize(rptr, mScaT); + utils::serialize(rptr, mZpT); + utils::serialize(rptr, mRedT); utils::serialize(rptr, mCStep); utils::serialize(rptr, mCSize); } @@ -219,9 +218,9 @@ class ObjectQuantCorrection : public ISerialObject { protected: inline void updateSize() { - mScaEleSize = int(utils::jblas_dtype_size(mScaT)); - mZpEleSize = int(utils::jblas_dtype_size(mZpT)); - mRedEleSize = int(utils::jblas_dtype_size(mRedT)); + mScaEleSize = int(utils::bestla_dtype_size(mScaT)); + mZpEleSize = int(utils::bestla_dtype_size(mZpT)); + mRedEleSize = int(utils::bestla_dtype_size(mRedT)); } inline constexpr size_t getMiscSize() { @@ -237,9 +236,9 @@ class ObjectQuantCorrection : public ISerialObject { class IWeightBase : public storage::ISerializable { public: - JBLAS_PROLOGUEB_IDS mPrologueID = JBLAS_PROLOGUEB_IDS::Undef; + BTLA_PROLOGUEB_IDS mPrologueID = BTLA_PROLOGUEB_IDS::Undef; uint64_t mCoreId = 0; - JBLAS_DTYPE mDType = JBLAS_DTYPE::F32; + BTLA_DTYPE mDType = BTLA_DTYPE::F32; int mNPad = 0, mKPad = 0; int mN = 0, mK = 0; @@ -249,7 +248,7 @@ class IWeightBase : public storage::ISerializable { static constexpr inline size_t offset() { return sizeof(mSize); } protected: - void resize(int NPad, int KPad, int N, int K, JBLAS_DTYPE dtype) { + void resize(int NPad, int KPad, int N, int K, BTLA_DTYPE dtype) { mNPad = NPad; mKPad = KPad; mN = N; @@ -273,21 +272,21 @@ class IWeightBase : public storage::ISerializable { virtual void deserializeBuffer(int8_t*& rptr, bool map_buf) { ISerializable::deserializeBuffer(rptr, map_buf); if (!map_buf) { - mPrologueID = utils::deserialize(rptr); + mPrologueID = utils::deserialize(rptr); mCoreId = utils::deserialize(rptr); mNPad = utils::deserialize(rptr); mKPad = utils::deserialize(rptr); mN = utils::deserialize(rptr); mK = utils::deserialize(rptr); - mDType = utils::deserialize(rptr); + mDType = utils::deserialize(rptr); } else { - utils::serialize(rptr, mPrologueID); + 
utils::serialize(rptr, mPrologueID); utils::serialize(rptr, mCoreId); utils::serialize(rptr, mNPad); utils::serialize(rptr, mKPad); utils::serialize(rptr, mN); utils::serialize(rptr, mK); - utils::serialize(rptr, mDType); + utils::serialize(rptr, mDType); } } @@ -308,7 +307,7 @@ class IWeightKBlockBase : public IWeightBase { public: int mBlockSize = 1; IWeightKBlockBase(uint64_t _id) : IWeightBase(_id) {} - void resize(int NPad, int KPad, int Block, int N, int K, JBLAS_DTYPE dtype) { + void resize(int NPad, int KPad, int Block, int N, int K, BTLA_DTYPE dtype) { IWeightBase::resize(NPad, KPad, N, K, dtype); mBlockSize = Block; } @@ -341,9 +340,9 @@ class IWeightKBlockBase : public IWeightBase { class IActivationBase : public storage::ISerializable { public: - JBLAS_PROLOGUEB_IDS mPrologueID = JBLAS_PROLOGUEB_IDS::Undef; + BTLA_PROLOGUEB_IDS mPrologueID = BTLA_PROLOGUEB_IDS::Undef; uint64_t mCoreId = 0; - JBLAS_DTYPE mDType = JBLAS_DTYPE::F32; + BTLA_DTYPE mDType = BTLA_DTYPE::F32; int mMPad = 0, mKPad = 0; int mM = 0, mK = 0; @@ -353,7 +352,7 @@ class IActivationBase : public storage::ISerializable { static constexpr inline size_t offset() { return sizeof(mSize); } protected: - void resize(int NPad, int KPad, int N, int K, JBLAS_DTYPE dtype) { + void resize(int NPad, int KPad, int N, int K, BTLA_DTYPE dtype) { mMPad = NPad; mKPad = KPad; mM = N; @@ -377,21 +376,21 @@ class IActivationBase : public storage::ISerializable { virtual void deserializeBuffer(int8_t*& rptr, bool map_buf) { ISerializable::deserializeBuffer(rptr, map_buf); if (!map_buf) { - mPrologueID = utils::deserialize(rptr); + mPrologueID = utils::deserialize(rptr); mCoreId = utils::deserialize(rptr); mMPad = utils::deserialize(rptr); mKPad = utils::deserialize(rptr); mM = utils::deserialize(rptr); mK = utils::deserialize(rptr); - mDType = utils::deserialize(rptr); + mDType = utils::deserialize(rptr); } else { - utils::serialize(rptr, mPrologueID); + utils::serialize(rptr, mPrologueID); utils::serialize(rptr, mCoreId); utils::serialize(rptr, mMPad); utils::serialize(rptr, mKPad); utils::serialize(rptr, mM); utils::serialize(rptr, mK); - utils::serialize(rptr, mDType); + utils::serialize(rptr, mDType); } } @@ -412,7 +411,7 @@ class IActivationKBlockBase : public IActivationBase { public: int mBlockSize = 1; IActivationKBlockBase(uint64_t _id) : IActivationBase(_id) {} - void resize(int MPad, int KPad, int Block, int N, int K, JBLAS_DTYPE dtype) { + void resize(int MPad, int KPad, int Block, int N, int K, BTLA_DTYPE dtype) { IActivationBase::resize(MPad, KPad, N, K, dtype); mBlockSize = Block; } @@ -446,11 +445,11 @@ class IActivationKBlockBase : public IActivationBase { class StoragePackedWeight : public IWeightBase { public: ObjectAlignedBuffer mWBuf; - StoragePackedWeight(uint64_t _id) : IWeightBase(_id) { mPrologueID = JBLAS_PROLOGUEB_IDS::WeightPack; } + StoragePackedWeight(uint64_t _id) : IWeightBase(_id) { mPrologueID = BTLA_PROLOGUEB_IDS::WeightPack; } - size_t resize(int NPad, int KPad, int N, int K, JBLAS_DTYPE dtype) { + size_t resize(int NPad, int KPad, int N, int K, BTLA_DTYPE dtype) { IWeightBase::resize(NPad, KPad, N, K, dtype); - auto bsize = static_cast(NPad) * KPad * jblas::utils::jblas_dtype_size(dtype); + auto bsize = static_cast(NPad) * KPad * utils::bestla_dtype_size(dtype); mWBuf.resize(bsize); mSize = IWeightBase::getSerializedSize() + mWBuf.getSerializedSize(); mSize = utils::padto(mSize, Alignment); @@ -483,12 +482,12 @@ class StorageReduce : public ISerializable { using CorrectionType = 
ObjectQuantCorrection; int m = 0, k = 0, lda = 0, kblock = 1; ObjectAlignedBuffer mRedBuf; - size_t resize(int _m, int _k, int _kblock, JBLAS_DTYPE redt) { + size_t resize(int _m, int _k, int _kblock, BTLA_DTYPE redt) { kblock = _kblock; m = _m; k = _k; lda = utils::updiv(_k, _kblock); - size_t bufsize = static_cast(m) * lda * utils::jblas_dtype_size(redt); + size_t bufsize = static_cast(m) * lda * utils::bestla_dtype_size(redt); mRedBuf.resize(bufsize); mSize = getSerializedSize(); mSize = utils::padto(mSize, Alignment); @@ -555,11 +554,11 @@ class StorageReduce : public ISerializable { class StorageReorderActivation : public IActivationKBlockBase { public: ObjectAlignedBuffer mABuf; - StorageReorderActivation(uint64_t _id) : IActivationKBlockBase(_id) { mPrologueID = JBLAS_PROLOGUEB_IDS::WeightPack; } + StorageReorderActivation(uint64_t _id) : IActivationKBlockBase(_id) { mPrologueID = BTLA_PROLOGUEB_IDS::WeightPack; } - size_t resize(int MPad, int KPad, int M, int K, int KBlock, JBLAS_DTYPE dtype) { + size_t resize(int MPad, int KPad, int M, int K, int KBlock, BTLA_DTYPE dtype) { IActivationKBlockBase::resize(MPad, KPad, KBlock, M, K, dtype); - auto bsize = static_cast(MPad) * KPad * jblas::utils::jblas_dtype_size(dtype); + auto bsize = static_cast(MPad) * KPad * utils::bestla_dtype_size(dtype); mABuf.resize(bsize); mSize = IActivationKBlockBase::getSerializedSize() + mABuf.getSerializedSize(); mSize = utils::padto(mSize, Alignment); @@ -593,14 +592,14 @@ class StorageQuantActivation : public IActivationKBlockBase { CorrectionType mCorrection; ObjectAlignedBuffer mQBuf; StorageQuantActivation(uint64_t _id = 0) : IActivationKBlockBase(_id) { - mPrologueID = JBLAS_PROLOGUEB_IDS::WeightPack; + mPrologueID = BTLA_PROLOGUEB_IDS::WeightPack; } - size_t resize(int _mpad, int _kpad, int _m, int _k, int _kblock, JBLAS_DTYPE buft, JBLAS_DTYPE scalet, - JBLAS_DTYPE zpt, JBLAS_DTYPE redt, bool is_asym, bool has_reduce) { + size_t resize(int _mpad, int _kpad, int _m, int _k, int _kblock, BTLA_DTYPE buft, BTLA_DTYPE scalet, BTLA_DTYPE zpt, + BTLA_DTYPE redt, bool is_asym, bool has_reduce) { IActivationKBlockBase::resize(_mpad, _kpad, _kblock, _m, _k, buft); mCorrection.resize(_mpad, utils::updiv(_kpad, _kblock), scalet, zpt, redt, is_asym, has_reduce); - size_t bufsize = static_cast(_mpad) * _kpad * utils::jblas_dtype_size(buft); + size_t bufsize = static_cast(_mpad) * _kpad * utils::bestla_dtype_size(buft); mQBuf.resize(bufsize); mSize = getSerializedSize(); mSize = utils::padto(mSize, Alignment); @@ -626,9 +625,9 @@ class StorageQuantActivation : public IActivationKBlockBase { return mCorrection.mRedBuf.get(); } - inline constexpr JBLAS_DTYPE RDtype() { return mCorrection.mRedT; } - inline constexpr JBLAS_DTYPE ZDtype() { return mCorrection.mZpT; } - inline constexpr JBLAS_DTYPE SDtype() { return mCorrection.mScaT; } + inline constexpr BTLA_DTYPE RDtype() { return mCorrection.mRedT; } + inline constexpr BTLA_DTYPE ZDtype() { return mCorrection.mZpT; } + inline constexpr BTLA_DTYPE SDtype() { return mCorrection.mScaT; } inline constexpr bool IsAsym() { return mCorrection.mZpBuf.mNotEmpty; } inline constexpr bool HasReduce() { return mCorrection.mRedBuf.mNotEmpty; } inline constexpr size_t CSize() { return mCorrection.mCSize; } @@ -680,20 +679,20 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase { CorrectionType mCorrection; ObjectOptionalBuffer mShuffleIndices; StorageWeightKBlockNInteger(uint64_t _type) : IWeightKBlockBase(_type) { - mPrologueID = 
JBLAS_PROLOGUEB_IDS::WeightKBlockNInteger; + mPrologueID = BTLA_PROLOGUEB_IDS::WeightKBlockNInteger; } - size_t resize(int NPad, int KPad, int Block, int N, int K, JBLAS_DTYPE qtype, JBLAS_DTYPE scalet, JBLAS_DTYPE redt, + size_t resize(int NPad, int KPad, int Block, int N, int K, BTLA_DTYPE qtype, BTLA_DTYPE scalet, BTLA_DTYPE redt, bool IsAsym) { - JBLAS_DTYPE zpt = JBLAS_DTYPE::S8; + BTLA_DTYPE zpt = BTLA_DTYPE::S8; InfoType::resize(NPad, KPad, Block, N, K, qtype); - auto bits = utils::jblas_dtype_bits(qtype); + auto bits = utils::bestla_dtype_bits(qtype); auto elesize = static_cast(NPad) * KPad; auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5bits, 7bits size calculation here mQBuf.resize(bytes); int nk_scale = utils::updiv(KPad, Block); - auto gemm_comp = jblas::gemm::CoreAttr::get_comp(mCoreId); - auto is_cint = jblas::gemm::CompTypeHelper::is_integer(gemm_comp); + auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId); + auto is_cint = bestla::gemm::CompTypeHelper::is_integer(gemm_comp); mCorrection.resize(nk_scale, NPad, scalet, zpt, redt, IsAsym, is_cint); update_size(); return mSize; @@ -705,9 +704,9 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase { update_size(); } - inline constexpr JBLAS_DTYPE RDtype() { return mCorrection.mRedT; } - inline constexpr JBLAS_DTYPE ZDtype() { return mCorrection.mZpT; } - inline constexpr JBLAS_DTYPE SDtype() { return mCorrection.mScaT; } + inline constexpr BTLA_DTYPE RDtype() { return mCorrection.mRedT; } + inline constexpr BTLA_DTYPE ZDtype() { return mCorrection.mZpT; } + inline constexpr BTLA_DTYPE SDtype() { return mCorrection.mScaT; } inline constexpr bool IsAsym() { return mCorrection.mZpBuf.mNotEmpty; } inline constexpr bool HasReduce() { return mCorrection.mRedBuf.mNotEmpty; } inline constexpr size_t CSize() { return mCorrection.mCSize; } @@ -771,18 +770,18 @@ class StorageWeightKBlockNFloat : public StorageWeightKBlockNInteger { public: StorageWeightKBlockNFloat(uint64_t _type) : StorageWeightKBlockNInteger(_type) { - mPrologueID = JBLAS_PROLOGUEB_IDS::WeightKBlockNFloat; + mPrologueID = BTLA_PROLOGUEB_IDS::WeightKBlockNFloat; } - size_t resize(int NPad, int KPad, int Block, int N, int K, JBLAS_DTYPE ftype, JBLAS_DTYPE scalet) { + size_t resize(int NPad, int KPad, int Block, int N, int K, BTLA_DTYPE ftype, BTLA_DTYPE scalet) { StorageWeightKBlockNInteger::InfoType::resize(NPad, KPad, Block, N, K, ftype); - auto bits = utils::jblas_dtype_bits(ftype); + auto bits = utils::bestla_dtype_bits(ftype); auto elesize = static_cast(NPad) * KPad; auto bytes = utils::updiv(elesize * bits, 8); // add fp6 size calculation here StorageWeightKBlockNInteger::mQBuf.resize(bytes); int nk_scale = utils::updiv(KPad, Block); - StorageWeightKBlockNInteger::mCorrection.resize(nk_scale, NPad, scalet, JBLAS_DTYPE::EleBitsUndef, - JBLAS_DTYPE::EleBitsUndef, false, false); + StorageWeightKBlockNInteger::mCorrection.resize(nk_scale, NPad, scalet, BTLA_DTYPE::EleBitsUndef, + BTLA_DTYPE::EleBitsUndef, false, false); mSize = StorageWeightKBlockNInteger::InfoType::getSerializedSize() + StorageWeightKBlockNInteger::mQBuf.getSerializedSize() + StorageWeightKBlockNInteger::mCorrection.getSerializedSize(); @@ -801,17 +800,17 @@ class PackedWeightParser { rptr += IWeightBase::offset(); int mProID = utils::deserialize(rptr); IWeightBase* ptr = NULL; - if (mProID >= 
int(BTLA_PROLOGUEB_IDS::Begin) && mProID < int(BTLA_PROLOGUEB_IDS::End)) { rptr = reinterpret_cast(serialized_buf); - auto type = static_cast(mProID); + auto type = static_cast(mProID); switch (type) { - case JBLAS_PROLOGUEB_IDS::WeightPack: + case BTLA_PROLOGUEB_IDS::WeightPack: ptr = new gemm::StoragePackedWeight(0); break; - case JBLAS_PROLOGUEB_IDS::WeightKBlockNInteger: + case BTLA_PROLOGUEB_IDS::WeightKBlockNInteger: ptr = new gemm::StorageWeightKBlockNInteger(0); break; - case JBLAS_PROLOGUEB_IDS::WeightKBlockNFloat: + case BTLA_PROLOGUEB_IDS::WeightKBlockNFloat: ptr = new gemm::StorageWeightKBlockNFloat(0); break; default: @@ -826,4 +825,4 @@ class PackedWeightParser { }; } // namespace gemm } // namespace storage -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_utils.h b/bestla/bestla/bestla_utils.h similarity index 87% rename from bestla/jblas/jit_blas_utils.h rename to bestla/bestla/bestla_utils.h index eed0c3bda..1acdc24a6 100644 --- a/bestla/jblas/jit_blas_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -47,7 +47,7 @@ #define ARCH_REQ_XCOMP_PERM 0x1023 #endif -#include "jit_blas.h" +#include "bestla.h" // As long as the compiler supports the ISA, we will enable it. // Only the ISA you use in your project will be compiled. @@ -72,7 +72,7 @@ #include #endif -namespace jblas { +namespace bestla { namespace utils { template @@ -243,15 +243,15 @@ struct GemmProblem { }; template -inline constexpr JBLAS_DTYPE jblas_dtype = std::is_same_v ? JBLAS_DTYPE::F64 - : std::is_same_v ? JBLAS_DTYPE::F32 - : std::is_same_v ? JBLAS_DTYPE::BF16 - : std::is_same_v ? JBLAS_DTYPE::F16 - : std::is_same_v ? JBLAS_DTYPE::S8 - : std::is_same_v ? JBLAS_DTYPE::U8 - : std::is_same_v ? JBLAS_DTYPE::S32 - : std::is_same_v ? JBLAS_DTYPE::F8_E8M0 - : (assert(0), JBLAS_DTYPE::F32); +inline constexpr BTLA_DTYPE bestla_dtype = std::is_same_v ? BTLA_DTYPE::F64 + : std::is_same_v ? BTLA_DTYPE::F32 + : std::is_same_v ? BTLA_DTYPE::BF16 + : std::is_same_v ? BTLA_DTYPE::F16 + : std::is_same_v ? BTLA_DTYPE::S8 + : std::is_same_v ? BTLA_DTYPE::U8 + : std::is_same_v ? BTLA_DTYPE::S32 + : std::is_same_v ? BTLA_DTYPE::F8_E8M0 + : (assert(0), BTLA_DTYPE::F32); template inline constexpr const char* type_str = std::is_same_v ? "double" : std::is_same_v ? "float" @@ -262,75 +262,75 @@ inline constexpr const char* type_str = std::is_same_v ? "double" : std::is_same_v ? "f8" // TODO(zhe): more f8 cases? 
: (assert(0), "undef"); -inline const char* jblas_dtype_str(JBLAS_DTYPE dtype) { +inline const char* bestla_dtype_str(BTLA_DTYPE dtype) { switch (dtype) { - case JBLAS_DTYPE::F64: + case BTLA_DTYPE::F64: return "float64"; - case JBLAS_DTYPE::F32: + case BTLA_DTYPE::F32: return "float32"; - case JBLAS_DTYPE::F16: + case BTLA_DTYPE::F16: return "float16"; - case JBLAS_DTYPE::BF16: + case BTLA_DTYPE::BF16: return "bfloat16"; - case JBLAS_DTYPE::F8_E4M3: + case BTLA_DTYPE::F8_E4M3: return "fp8_e4m3"; - case JBLAS_DTYPE::F8_E5M2: + case BTLA_DTYPE::F8_E5M2: return "fp8_e5m2"; - case JBLAS_DTYPE::F8_E3M4: + case BTLA_DTYPE::F8_E3M4: return "fp8_e3m4"; - case JBLAS_DTYPE::S8: + case BTLA_DTYPE::S8: return "signed_int8"; - case JBLAS_DTYPE::U8: + case BTLA_DTYPE::U8: return "unsigned_int8"; - case JBLAS_DTYPE::S4_CLIP: + case BTLA_DTYPE::S4_CLIP: return "int4_clip"; - case JBLAS_DTYPE::S4_FULLRANGE: + case BTLA_DTYPE::S4_FULLRANGE: return "int4_fullrange"; - case JBLAS_DTYPE::F4_E2M1: + case BTLA_DTYPE::F4_E2M1: return "fp4_e2m1"; - case JBLAS_DTYPE::F4_BNB: + case BTLA_DTYPE::F4_BNB: return "fp4_bitsandbytes"; - case JBLAS_DTYPE::F4_NF4: + case BTLA_DTYPE::F4_NF4: return "fp4_nf4"; - case JBLAS_DTYPE::S32: + case BTLA_DTYPE::S32: return "signed_int32"; - case JBLAS_DTYPE::U32: + case BTLA_DTYPE::U32: return "unsigned_int32"; default: return "ErrType"; } } -template +template inline constexpr const char* dtype_str() { - return jblas_dtype_str(DT); + return bestla_dtype_str(DT); } -inline constexpr uint32_t jblas_dtype_get_mask_val(const JBLAS_DTYPE& t, const JBLAS_DTYPE& mask, - const JBLAS_DTYPE& shift) { +inline constexpr uint32_t bestla_dtype_get_mask_val(const BTLA_DTYPE& t, const BTLA_DTYPE& mask, + const BTLA_DTYPE& shift) { return (static_cast(t) & static_cast(mask)) >> static_cast(shift); } -inline constexpr size_t jblas_dtype_bits(const JBLAS_DTYPE t) { - return jblas_dtype_get_mask_val(t, JBLAS_DTYPE::EleBitsMask, JBLAS_DTYPE::EleBitsShift); +inline constexpr size_t bestla_dtype_bits(const BTLA_DTYPE t) { + return bestla_dtype_get_mask_val(t, BTLA_DTYPE::EleBitsMask, BTLA_DTYPE::EleBitsShift); } -inline constexpr size_t jblas_dtype_type(const JBLAS_DTYPE t) { - return jblas_dtype_get_mask_val(t, JBLAS_DTYPE::TypeMask, JBLAS_DTYPE::TypeShift); +inline constexpr size_t bestla_dtype_type(const BTLA_DTYPE t) { + return bestla_dtype_get_mask_val(t, BTLA_DTYPE::TypeMask, BTLA_DTYPE::TypeShift); } -inline constexpr size_t jblas_dtype_size(const JBLAS_DTYPE t) { - auto bits = jblas_dtype_get_mask_val(t, JBLAS_DTYPE::EleBitsMask, JBLAS_DTYPE::EleBitsShift); +inline constexpr size_t bestla_dtype_size(const BTLA_DTYPE t) { + auto bits = bestla_dtype_get_mask_val(t, BTLA_DTYPE::EleBitsMask, BTLA_DTYPE::EleBitsShift); return bits >> 3; // bits to bytes } -inline int jblas_dtype_get_f8_ebits(const JBLAS_DTYPE t) { +inline int bestla_dtype_get_f8_ebits(const BTLA_DTYPE t) { int ret = -1; switch (t) { - case JBLAS_DTYPE::F8_E4M3: + case BTLA_DTYPE::F8_E4M3: ret = 4; break; - case JBLAS_DTYPE::F8_E5M2: + case BTLA_DTYPE::F8_E5M2: ret = 5; break; default: @@ -339,13 +339,13 @@ inline int jblas_dtype_get_f8_ebits(const JBLAS_DTYPE t) { return ret; } -inline int jblas_dtype_get_f8_quant_mbits(const JBLAS_DTYPE t) { +inline int bestla_dtype_get_f8_quant_mbits(const BTLA_DTYPE t) { int ret = -1; switch (t) { - case JBLAS_DTYPE::F8_E4M3: + case BTLA_DTYPE::F8_E4M3: ret = 5; break; - case JBLAS_DTYPE::F8_E5M2: + case BTLA_DTYPE::F8_E5M2: ret = 4; break; default: @@ -354,11 +354,11 @@ inline int 
jblas_dtype_get_f8_quant_mbits(const JBLAS_DTYPE t) { return ret; } -inline float get_mxfp_maxnorm(const JBLAS_DTYPE t, int ebits, int mantissa_bits) { +inline float get_mxfp_maxnorm(const BTLA_DTYPE t, int ebits, int mantissa_bits) { auto emax = std::pow(2, ebits - 1); - if (t == JBLAS_DTYPE::F8_E5M2) emax -= 1; + if (t == BTLA_DTYPE::F8_E5M2) emax -= 1; auto max_norm = std::pow(2, emax); - if (t != JBLAS_DTYPE::F8_E4M3) { + if (t != BTLA_DTYPE::F8_E4M3) { max_norm *= ((std::pow(2, mantissa_bits - 1) - 1) / std::pow(2, mantissa_bits - 2)); } else { max_norm *= 1.75; @@ -384,16 +384,16 @@ static void request_perm_xtile_data() { static void request_perm_xtile_data() {} #endif -template +template class isa_base { public: - static bool constexpr avx = ISA_T >= JblasAVX; - static bool constexpr avx2 = ISA_T >= JblasAVX2; - static bool constexpr avx512f = ISA_T >= JblasAVX512F; - static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; - static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; - static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; - static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; + static bool constexpr avx = ISA_T >= BTLA_ISA::AVX; + static bool constexpr avx2 = ISA_T >= BTLA_ISA::AVX2; + static bool constexpr avx512f = ISA_T >= BTLA_ISA::AVX512F; + static bool constexpr avx512_vnni = ISA_T >= BTLA_ISA::AVX512_VNNI; + static bool constexpr avx512_fp16 = ISA_T >= BTLA_ISA::AVX512_FP16; + static bool constexpr amx_bf16 = ISA_T >= BTLA_ISA::AMX_BF16; + static bool constexpr amx_int8 = ISA_T >= BTLA_ISA::AMX_INT8; }; static inline int padto_le(int src, int padding) { return src / padding * padding; } @@ -680,4 +680,4 @@ static float nf4_dequant_fp32_LUT[] = {0.f, 0.5626170039176941f, 0.7229568362236023f, 1.0f}; -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/jit_blas_wrapper.h b/bestla/bestla/bestla_wrapper.h similarity index 95% rename from bestla/jblas/jit_blas_wrapper.h rename to bestla/bestla/bestla_wrapper.h index f27a1745a..d300aa03c 100644 --- a/bestla/jblas/jit_blas_wrapper.h +++ b/bestla/bestla/bestla_wrapper.h @@ -14,21 +14,21 @@ #pragma once #include -#include "jit_blas_epilogue.h" -#include "jit_blas_gemm.h" -#include "jit_blas_prologue_a.h" -#include "jit_blas_prologue_b.h" -#include "jit_blas_utils.h" +#include "bestla_epilogue.h" +#include "bestla_gemm.h" +#include "bestla_prologue_a.h" +#include "bestla_prologue_b.h" +#include "bestla_utils.h" -namespace jblas { +namespace bestla { namespace wrapper { namespace gemm { -template class _PrologueA_T, - template class _PrologueB_T, template class _Epilogue_T> +template class _PrologueA_T, + template class _PrologueB_T, template class _Epilogue_T> class LauncherBase { public: using GemmCore = _GemmCore_T; - static constexpr JBLAS_ISA ISA = _RT_ISA_T; + static constexpr BTLA_ISA ISA = _RT_ISA_T; using PrologueA = _PrologueA_T; using PrologueB = _PrologueB_T; using Epilogue = _Epilogue_T<_RT_ISA_T>; @@ -114,13 +114,13 @@ class LauncherBase { } }; -template class _PrologueA_T, - template class _PrologueB_T, template class _BlockEpilogue_T, - template class _Epilogue_T> +template class _PrologueA_T, + template class _PrologueB_T, template class _BlockEpilogue_T, + template class _Epilogue_T> class LauncherKBlock { public: using GemmCore = _GemmCore_T; - static constexpr JBLAS_ISA ISA = _RT_ISA_T; + static constexpr BTLA_ISA ISA = _RT_ISA_T; using PrologueA = _PrologueA_T; using PrologueB = _PrologueB_T; using Epilogue = _Epilogue_T<_RT_ISA_T>; @@ -281,12 +281,12 @@ class 
LauncherKBlock { } }; -template class _PrologueA_T, - template class _PrologueB_T, template class _Epilogue_T> +template class _PrologueA_T, + template class _PrologueB_T, template class _Epilogue_T> class LauncherIntKBlock { public: using GemmCore = _GemmCore_T; - static constexpr JBLAS_ISA ISA = _RT_ISA_T; + static constexpr BTLA_ISA ISA = _RT_ISA_T; using PrologueA = _PrologueA_T; using PrologueB = _PrologueB_T; using Epilogue = _Epilogue_T<_RT_ISA_T>; @@ -394,9 +394,8 @@ class LauncherIntKBlock { mProA.getScale(&scaleA_cache, &ldsa_cache, _param.paramA, m_remain, k_padded, (blk_m + i + _config.loc[0]), iterk, tmp_, _config.tmpcachesize); mGemmCore.forward(aptr_cache, bptr_cache, cptr_cache, zpA_cache, scaleA_cache, ldsa_cache, scaleB_cache, - reduceB_cache, ldsb_cache, m_remain, n_padded, k_padded, KBlock, - acache_step * sizeof(AType), bcache_stride, ccache_stride, iterk, 1.f, tmp_, - _config.tmpcachesize); + reduceB_cache, ldsb_cache, m_remain, n_padded, k_padded, KBlock, acache_step * sizeof(AType), + bcache_stride, ccache_stride, iterk, 1.f, tmp_, _config.tmpcachesize); } } mEpilogue.forward(tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, blk_nsize, @@ -475,4 +474,4 @@ class LauncherIntKBlock { }; } // namespace gemm } // namespace wrapper -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/kernel_avx2.h b/bestla/bestla/kernel_avx2.h similarity index 82% rename from bestla/jblas/kernel_avx2.h rename to bestla/bestla/kernel_avx2.h index 1e9fdf287..c5b66b4bf 100644 --- a/bestla/jblas/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once -#include "jblas/jit_blas.h" +#include "bestla.h" +#include "bestla_utils.h" #include "kernel_ref.h" -#include "jit_blas_utils.h" #if CompileAVX2() #include #endif -namespace jblas { +namespace bestla { namespace kernel { namespace avx2 { #if CompileAVX2() @@ -31,7 +31,7 @@ namespace avx2 { static uint8_t shuffle_map[] = {0x00, 0x01, 0x02, 0x03, 0xff, 0xff, 0xff, 0xff, 0x04, 0x05, 0x06, 0x07, 0xff, 0xff, 0xff, 0xff}; -template +template static inline __m128i unpack_4bits_sse(void* srcptr) { auto shuffle_v = _mm_loadu_si128(reinterpret_cast<__m128i*>(shuffle_map)); auto raw_data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr)); @@ -43,7 +43,7 @@ static inline __m128i unpack_4bits_sse(void* srcptr) { auto xmm2 = _mm_unpacklo_epi8(xmm0, xmm1); auto xmm3 = _mm_unpackhi_epi8(xmm0, xmm1); xmm2 = _mm_unpacklo_epi64(xmm2, xmm3); - if constexpr (S4_T != JBLAS_DTYPE::S4_FULLRANGE) xmm2 = _mm_slli_epi32(xmm2, 4); + if constexpr (S4_T != BTLA_DTYPE::S4_FULLRANGE) xmm2 = _mm_slli_epi32(xmm2, 4); return xmm2; } @@ -70,10 +70,10 @@ inline __m128i ymm_cvt_fp32_bf16(__m256 vfp32) { return ymm_cvtepi32_epi16(_mm256_bsrli_epi128(_mm256_castps_si256(vfp32), 2)); } -template +template static inline void convert_s4_s8_16_sse(int8_t* dstptr, int8_t* srcptr) { auto dst0 = unpack_4bits_sse(srcptr); - if constexpr (S4_T == JBLAS_DTYPE::S4_FULLRANGE) { + if constexpr (S4_T == BTLA_DTYPE::S4_FULLRANGE) { auto s8 = _mm_set1_epi8(8); dst0 = _mm_sub_epi8(dst0, s8); } @@ -94,7 +94,7 @@ static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) { } static inline void fp4_pad_4bit(int8_t* dstptr, int8_t* srcptr) { - auto dst0 = unpack_4bits_sse(srcptr); + auto dst0 = unpack_4bits_sse(srcptr); _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0); } @@ -112,9 +112,9 @@ static inline void 
dequant_s8_N_avx2(float* dstptr, int8_t* srcptr, __m256* vsca } } -static inline JBLAS_CODE alphabeta_f32_f32(const float alpha, const float* srcptr, const int srcstep, const float beta, - const float* src1ptr, const int src1step, float* dstptr, const int dststep, - const int M, const int N) { +static inline BTLA_CODE alphabeta_f32_f32(const float alpha, const float* srcptr, const int srcstep, const float beta, + const float* src1ptr, const int src1step, float* dstptr, const int dststep, + const int M, const int N) { int constexpr Vlen = 8; auto vN = utils::padto_le(N, Vlen); auto valpha = _mm256_set1_ps(alpha); @@ -144,12 +144,12 @@ static inline JBLAS_CODE alphabeta_f32_f32(const float alpha, const float* srcpt } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -JBLAS_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { +BTLA_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { const int Vlen = 8; size_t simd_process_num = utils::padto_le(col, Vlen); auto packrow4_permute_idx = _mm256_setr_epi32(0, 0, 0, 0, 1, 1, 1, 1); @@ -185,12 +185,12 @@ JBLAS_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int dstptr[i * ld_dst + j] = tmp * sptr[j / PACK_ROW]; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE dequant_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { +static inline BTLA_CODE dequant_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { if (zero_points == nullptr) return dequant_kblock_s8_fp_fwd(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad); @@ -200,9 +200,9 @@ static inline JBLAS_CODE dequant_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, in } template -static inline JBLAS_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, - const int row, const int col, const float* scaleA, const int ldsa, - const SCAB_T* scaleB) { +static inline BTLA_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, + const int row, const int col, const float* scaleA, const int ldsa, + const SCAB_T* scaleB) { int col8 = utils::padto_le(col, 8); for (int irow = 0; irow < row; irow++) { auto scale = scaleA[irow * ldsa]; @@ -226,11 +226,11 @@ static inline JBLAS_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcst dstptr[irow * dststep + icol] = scale * scaleB[icol] * srcptr[irow * srcstep + icol]; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, - float* scales, int lds, const float* reduce) { +static inline BTLA_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, + float* scales, int lds, const float* reduce) { int constexpr VLen = 8; auto col8 = utils::padto_le(col, VLen); for (int i = 0; i < row; i++) { @@ -249,11 +249,11 @@ static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int } } } - return JblasSuccess; + return BTLA_CODE::Success; } -static 
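// ---- editor's sketch (not part of this patch) -----------------------------
// Scalar form of the k-block dequantization that dequant_kblock_s8_fp_fwd
// above vectorizes: rows within one kblock share a scale row of NPad entries
// (offset by k_offset); the optional zero point is subtracted before
// scaling. The PACK_ROW grouping is dropped here for clarity, and zero
// points are assumed to share the scale layout.
#include <cstdint>

static void dequant_kblock_s8_scalar(const int8_t* src, float* dst, int row, int col,
                                     int ld_src, int ld_dst, const float* scales,
                                     const int8_t* zps, int k_offset, int kblock,
                                     int NPad) {
  for (int i = 0; i < row; ++i) {
    const int kpos = (k_offset + i) / kblock;  // which k-block this row is in
    const float* sptr = scales + kpos * NPad;
    for (int j = 0; j < col; ++j) {
      float v = static_cast<float>(src[i * ld_src + j]);
      if (zps) v -= static_cast<float>(zps[kpos * NPad + j]);
      dst[i * ld_dst + j] = v * sptr[j];
    }
  }
}
// ---------------------------------------------------------------------------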
inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, - float* scales, int lds, const float* reduce) { +static inline BTLA_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, + float* scales, int lds, const float* reduce) { int constexpr VLen = 8; auto col8 = utils::padto_le(col, VLen); const int32_t mask[] = {-1, -1, 0, 0}; @@ -275,12 +275,12 @@ static inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int } } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, - float* scalea, float* scaleb, int lds, int k, const float* reducea, - const float* reduceb) { +static inline BTLA_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, + float* scalea, float* scaleb, int lds, int k, const float* reducea, + const float* reduceb) { int constexpr VLen = 8; auto col8 = utils::padto_le(col, VLen); auto vk = _mm256_set1_ps(static_cast(k)); @@ -311,12 +311,12 @@ static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row } } } - return JblasSuccess; + return BTLA_CODE::Success; } -template -static inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst) { +template +static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, + int ld_dst) { uint32_t mask = 0xf0f0f0f0; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { @@ -328,17 +328,17 @@ static inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, } for (; i < elesize; i += 2) { auto tmp = srcptr[i / 2]; - dstptr[i + 0] = jblas::kernel::ref::get_s8(tmp.x); - dstptr[i + 1] = jblas::kernel::ref::get_s8(tmp.y); + dstptr[i + 0] = kernel::ref::get_s8(tmp.x); + dstptr[i + 1] = kernel::ref::get_s8(tmp.y); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } -template -inline JBLAS_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { uint32_t mask = 0xf0f0f0f0; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { @@ -356,17 +356,17 @@ inline JBLAS_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstpt dstptr[i + 0] = static_cast<_DST_T>(static_cast(ref::get_s8(tmp.x))); dstptr[i + 1] = static_cast<_DST_T>(static_cast(ref::get_s8(tmp.y))); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasSuccess; + return BTLA_CODE::Success; } template -inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int k_offset, int kblock, int NPad, JBLAS_DTYPE src_f8_type) { +inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _S_T* scales, int k_offset, int kblock, int NPad, BTLA_DTYPE src_f8_type) { int align_col = col / 16 * 16; int col_tail = col - align_col; - auto ebits = utils::jblas_dtype_get_f8_ebits(src_f8_type); + auto ebits = utils::bestla_dtype_get_f8_ebits(src_f8_type); auto mantissabit = 7 - ebits; auto 
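// ---- editor's sketch (not part of this patch) -----------------------------
// The algebra behind remove_act/wei/zeropoint_bias above, written scalar.
// With A = sa*(qa - za) and B = sb*(qb - zb), the integer GEMM only
// accumulates qa*qb, so the cross terms are removed afterwards:
//   acc[i][j] -= za[i]*sa[i]*reduceB[j] + zb[j]*sb[j]*reduceA[i]
//                - k*za[i]*sa[i]*zb[j]*sb[j]
// where reduceA/reduceB are the scaled row/column sums. The exact lds
// indexing of scales and reduces is an assumption for illustration.
#include <cstdint>

static void remove_zeropoint_bias_scalar(float* acc, int ldacc, int row, int col,
                                         const uint8_t* zpa, const int8_t* zpb,
                                         const float* sa, const float* sb, int lds,
                                         int k, const float* reduceA,
                                         const float* reduceB) {
  for (int i = 0; i < row; ++i) {
    const float za = zpa[i * lds] * sa[i * lds];
    for (int j = 0; j < col; ++j) {
      const float zb = zpb[j] * sb[j];
      acc[i * ldacc + j] -= za * reduceB[j] + zb * reduceA[i * lds] - k * za * zb;
    }
  }
}
// ---------------------------------------------------------------------------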
sign_revert_and_mask = _mm256_set1_epi32(0x80000000); auto e_revert_and_mask = _mm256_set1_epi32(0x0000007f); @@ -423,11 +423,11 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -inline JBLAS_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { +inline BTLA_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { if (col == ld_src) { size_t elesize = (size_t)row * col; size_t ele64 = utils::padto_le(elesize, 64); @@ -443,14 +443,14 @@ inline JBLAS_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int r auto tmp = srcptr[i]; dstptr[i] = static_cast(static_cast(tmp)); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } template -static inline JBLAS_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, - const int dststep, const int M, const int N) { +static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, + const int dststep, const int M, const int N) { int constexpr Vlen = 8; auto vN = utils::padto_le(N, Vlen); int j = 0; @@ -482,20 +482,20 @@ static inline JBLAS_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* s } } } - return JblasSuccess; + return BTLA_CODE::Success; } -template +template static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps) { static_assert(N % 8 == 0); float* LUT; - static_assert(F4_T == JBLAS_DTYPE::F4_BNB || F4_T == JBLAS_DTYPE::F4_NF4 || F4_T == JBLAS_DTYPE::F4_E2M1, + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); - if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) { + if constexpr (F4_T == BTLA_DTYPE::F4_BNB) { LUT = fp4_bnb_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { LUT = nf4_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { LUT = fp4_e2m1_dequant_fp32_LUT; } int constexpr VLoop = N / 8; @@ -513,17 +513,17 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, } } -template +template static inline void unpack_f4_N(_DST_T* dstptr, int8_t* srcptr) { static_assert(N % 8 == 0); float* LUT; - static_assert(F4_T == JBLAS_DTYPE::F4_BNB || F4_T == JBLAS_DTYPE::F4_NF4 || F4_T == JBLAS_DTYPE::F4_E2M1, + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); - if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) { + if constexpr (F4_T == BTLA_DTYPE::F4_BNB) { LUT = fp4_bnb_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { LUT = nf4_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { LUT = fp4_e2m1_dequant_fp32_LUT; } int constexpr VLoop = N / 8; @@ -540,9 +540,9 @@ static inline void unpack_f4_N(_DST_T* dstptr, int8_t* srcptr) { } } -template -inline JBLAS_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* 
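// ---- editor's sketch (not part of this patch) -----------------------------
// Per-element form of dequant_f4_N above: a 4-bit code indexes a 16-entry
// fp32 lookup table selected by format (F4_BNB / F4_NF4 / F4_E2M1), and the
// result is multiplied by the block scale. The LUT contents are the
// library's constants and are not reproduced here.
static inline float dequant_f4_scalar(unsigned char code, const float lut[16],
                                      float scale) {
  return lut[code & 0x0f] * scale;
}
// ---------------------------------------------------------------------------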
srcptr, DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { uint32_t mask = 0xf0f0f0f0; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { @@ -559,13 +559,13 @@ inline JBLAS_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* ds dstptr[i + 0] = static_cast(ref::f4_unpack(tmp.x)); dstptr[i + 1] = static_cast(ref::f4_unpack(tmp.y)); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE decompress_kblock_bit4_packrow1( +static inline BTLA_CODE decompress_kblock_bit4_packrow1( utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*), void (*pad_bit4_16)(int8_t*, int8_t*), void (*pad_bit4_8)(int8_t*, int8_t*), int8_t* tmpbuf, size_t tmpsize) { @@ -659,23 +659,23 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1( dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps); } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, - int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, - void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*), - void (*pad_bit4)(int8_t*, int8_t*), int8_t* tmp, - size_t tmpsize) { - return JblasNotSupport; -} - -template -static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, - int8_t* tmp, size_t tmpsize) { +static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, + void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*), + void (*pad_bit4)(int8_t*, int8_t*), int8_t* tmp, + size_t tmpsize) { + return BTLA_CODE::NotSupport; +} + +template +static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, + int8_t* tmp, size_t tmpsize) { if constexpr (_PACK_ROW == 1) { if (col == 24) { return decompress_kblock_bit4_packrow1( @@ -692,7 +692,7 @@ static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* ds k_offset, kblock, NPad, &dequant_f4_N<64, _DST_T, _F4_T>, fp4_pad_4bit, tmp, tmpsize); } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } enum class AVX2_REDUCE_TYPE { MAX, MIN, ADD }; @@ -745,9 +745,9 @@ inline __m128i avx2_cvtepi32_epu8(__m256i x) { } template -static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, - int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, - float* blkreduce) { +static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, + int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, + float* blkreduce) { int constexpr VLen = 8; auto vff = _mm256_set1_epi32(255); auto v0 = _mm256_set1_epi32(0); @@ -763,8 +763,8 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* __m256 vsrc; if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); if 
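// ---- editor's note (not part of this patch) -------------------------------
// Reading of the _PACK_ROW dispatch above (an interpretation, not library
// documentation): PACK_ROW is how many logical rows are interleaved along
// the packed axis, so packrow1 kernels look up one scale per element while
// packrow2 shares one scale between the two interleaved rows -- hence the
// j / PACK_ROW in the scale indexing. A toy lookup under that assumption:
static inline float kblock_scale(const float* scales, int NPad, int kblock,
                                 int k_offset, int row, int packed_col,
                                 int pack_row) {
  const int kpos = (k_offset + row) / kblock;  // k-block of this row
  return scales[kpos * NPad + packed_col / pack_row];
}
// ---------------------------------------------------------------------------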
constexpr (std::is_same_v) { - auto vtmp = _mm_loadu_si128( - reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); + auto vtmp = + _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); vsrc = ymm_cvt_bf16_fp32(vtmp); } vmaxval = _mm256_max_ps(vmaxval, vsrc); @@ -793,8 +793,8 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* __m256 vsrc; if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); if constexpr (std::is_same_v) { - auto vtmp = _mm_loadu_si128( - reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); + auto vtmp = + _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); vsrc = ymm_cvt_bf16_fp32(vtmp); } vsrc = _mm256_mul_ps(vsrc, vrscale); @@ -811,8 +811,8 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* __m256 vsrc; if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); if constexpr (std::is_same_v) { - auto vtmp = _mm_loadu_si128( - reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); + auto vtmp = + _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); vsrc = ymm_cvt_bf16_fp32(vtmp); } vsrc = _mm256_mul_ps(vsrc, vrscale); @@ -864,12 +864,12 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, - float* reduce, int ldr) { +static inline BTLA_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, + float* reduce, int ldr) { int constexpr VLen = 8; auto vblock2_ = utils::padto_le(blocksize, VLen * 2); auto vblock_ = utils::padto_le(blocksize, VLen); @@ -901,11 +901,11 @@ static inline JBLAS_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, in reduce[i * ldr + j / blocksize] = tmp; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, - int src_step, int dst_step, bool zeropadding) { +static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, + int src_step, int dst_step, bool zeropadding) { const int npadding = (dst_step - col) * sizeof(float); constexpr int simd_proc_elt = 8; auto col_body = col / simd_proc_elt * simd_proc_elt; @@ -923,7 +923,7 @@ static inline JBLAS_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, } if (zeropadding && npadding) std::memset(dst + col, 0, npadding); } - return JblasSuccess; + return BTLA_CODE::Success; } static const uint8_t avx2_bf16_convert_maigc_num[32] = { @@ -941,8 +941,8 @@ static inline __m128i cvt_fp32_to_bf16(const __m256 src, __m256i* and_helper, __ return _mm256_castsi256_si128(ordered); } -static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, - int srcstride, int dststride, bool zeropadding) { +static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, + int srcstride, int dststride, bool zeropadding) { auto srcptr = reinterpret_cast(raw_srcptr); auto dstptr = reinterpret_cast(raw_dstptr); constexpr int simd_proc_elt = 8; @@ -957,16 +957,16 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* 
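// ---- editor's sketch (not part of this patch) -----------------------------
// Scalar form of the per-block asymmetric u8 quantization that
// quantize_fp_u8_colblock vectorizes: each blocksize-wide run of a row gets
// scale = (max - min) / 255 and a zero point, then values are rounded and
// saturated into [0, 255]. The rounding/zero-point formulation is a common
// one and only approximates the kernel; the blkreduce output is omitted.
#include <algorithm>
#include <cmath>
#include <cstdint>

static void quantize_u8_colblock_scalar(int row, int col, const float* src,
                                        int ld_src, uint8_t* dst, int ld_dst,
                                        float* scales, int ld_scale, uint8_t* zps,
                                        int blocksize) {
  for (int i = 0; i < row; ++i) {
    for (int j = 0; j < col; j += blocksize) {
      const int blk = std::min(blocksize, col - j);
      float maxv = 0.f, minv = 0.f;  // range seeded at zero for simplicity
      for (int ij = 0; ij < blk; ++ij) {
        maxv = std::max(maxv, src[i * ld_src + j + ij]);
        minv = std::min(minv, src[i * ld_src + j + ij]);
      }
      const float scale = (maxv - minv) / 255.f;
      const float rscale = scale != 0.f ? 1.f / scale : 0.f;
      const int zp = std::clamp(static_cast<int>(std::roundf(-minv * rscale)), 0, 255);
      scales[i * ld_scale + j / blocksize] = scale;
      zps[i * ld_scale + j / blocksize] = static_cast<uint8_t>(zp);
      for (int ij = 0; ij < blk; ++ij) {
        const int q = static_cast<int>(std::roundf(src[i * ld_src + j + ij] * rscale)) + zp;
        dst[i * ld_dst + j + ij] = static_cast<uint8_t>(std::clamp(q, 0, 255));
      }
    }
  }
}
// ---------------------------------------------------------------------------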
raw_srcptr, voi for (; j < col_body_loop; j += simd_proc_elt) { auto pack_bf16_value = cvt_fp32_to_bf16(_mm256_loadu_ps(reinterpret_cast(src) + j), &bf16_and_helper, &bf16_add_helper); - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + j * sizeof(jblas::utils::bf16)), pack_bf16_value); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + j * sizeof(utils::bf16)), pack_bf16_value); } for (; j < col; j++) { - (reinterpret_cast(dst) + j)->fromfloat(*(reinterpret_cast(src) + j)); + (reinterpret_cast(dst) + j)->fromfloat(*(reinterpret_cast(src) + j)); } if (zeropadding && npadding) { std::memset(dst + col * sizeof(utils::bf16), 0, npadding); } } - return JblasSuccess; + return BTLA_CODE::Success; } #ifdef __GNUC__ @@ -976,4 +976,4 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, voi #endif } // namespace avx2 } // namespace kernel -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/kernel_avx512_bf16.h b/bestla/bestla/kernel_avx512_bf16.h similarity index 82% rename from bestla/jblas/kernel_avx512_bf16.h rename to bestla/bestla/kernel_avx512_bf16.h index 70cea4749..453b88afd 100644 --- a/bestla/jblas/kernel_avx512_bf16.h +++ b/bestla/bestla/kernel_avx512_bf16.h @@ -14,17 +14,17 @@ #pragma once #include #include "kernel_avx512f.h" -#include "jit_blas_utils.h" +#include "bestla_utils.h" -namespace jblas { +namespace bestla { namespace kernel { namespace avx512_bf16 { #if CompileBF16() #pragma GCC push_options #pragma GCC target("avx512bf16", "avx512vl", "avx512bw") #endif -static inline JBLAS_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, - int src_step, int dst_step, bool zeropadding) { +static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, + int src_step, int dst_step, bool zeropadding) { #if CompileBF16() const int npadding = (dst_step - col) * sizeof(float); constexpr int simd_proc_elt = 16; @@ -45,13 +45,13 @@ static inline JBLAS_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, reinterpret_cast<__m512>(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(_mm256_loadu_epi16(src + j)), 2))); if (zeropadding && npadding) std::memset(dst + col, 0, npadding); } - return JblasSuccess; + return BTLA_CODE::Success; #endif return avx512f::bf16_cvt_fp32_2D_write_back(src_ptr, dst_ptr, row, col, src_step, dst_step, zeropadding); } -static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, - int srcstride, int dststride, bool zeropadding) { +static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, + int srcstride, int dststride, bool zeropadding) { #if CompileBF16() auto srcptr = reinterpret_cast(raw_srcptr); auto dstptr = reinterpret_cast(raw_dstptr); @@ -66,13 +66,13 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, voi int j = 0; for (; j < col_body_loop; j++) { _mm512_storeu_epi16( - (dst + (j * simd_proc_elt) * sizeof(jblas::utils::bf16)), + (dst + (j * simd_proc_elt) * sizeof(utils::bf16)), (__m512i)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(src + sizeof(float) * simd_proc_elt * j + sizeof(float) * 16), _mm512_loadu_ps(src + sizeof(float) * simd_proc_elt * j + sizeof(float) * 0))); } if (col_tail > 0) { _mm512_mask_storeu_epi16( - (dst + (j * simd_proc_elt) * sizeof(jblas::utils::bf16)), tail_mask, // + (dst + (j * simd_proc_elt) * sizeof(utils::bf16)), tail_mask, // (__m512i)_mm512_cvtne2ps_pbh( 
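// ---- editor's sketch (not part of this patch) -----------------------------
// Scalar form of the bf16 conversions these 2D kernels vectorize. bf16 is
// the upper 16 bits of an IEEE-754 fp32, so widening is a shift; narrowing
// with round-to-nearest-even adds 0x7fff plus the lowest kept bit before
// truncating, which is what the avx2_bf16_convert_maigc_num helpers encode.
// (NaN handling is omitted in this sketch.)
#include <cstdint>
#include <cstring>

static float bf16_to_fp32_scalar(uint16_t v) {
  const uint32_t bits = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

static uint16_t fp32_to_bf16_rne_scalar(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7fffu + ((bits >> 16) & 1u);  // round to nearest, ties to even
  return static_cast<uint16_t>(bits >> 16);
}
// ---------------------------------------------------------------------------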
_mm512_maskz_loadu_ps(tail_mask >> 16, src + sizeof(float) * simd_proc_elt * j + sizeof(float) * 16), _mm512_maskz_loadu_ps(tail_mask >> 0, src + sizeof(float) * simd_proc_elt * j + sizeof(float) * 0))); @@ -81,6 +81,7 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, voi std::memset(dst + col * sizeof(utils::bf16), 0, npadding); } } + return BTLA_CODE::Success; #endif return avx512f::fp32_cvt_bf16_2D_write_back(raw_srcptr, raw_dstptr, row, col, srcstride, dststride, zeropadding); } @@ -89,4 +90,4 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, voi #endif } // namespace avx512_bf16 } // namespace kernel -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h similarity index 85% rename from bestla/jblas/kernel_avx512f.h rename to bestla/bestla/kernel_avx512f.h index 503287ff6..2e5c3f3e3 100644 --- a/bestla/jblas/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once -#include "jit_blas_utils.h" +#include "bestla_utils.h" #include "kernel_ref.h" #include @@ -23,7 +23,7 @@ #include #endif -namespace jblas { +namespace bestla { namespace kernel { namespace avx512f { #if CompileAVX512F() @@ -62,11 +62,11 @@ static inline __m512i unpack_4bits(__m256i v4bits, __m512i vmask) { return zmm1; } -template +template static inline void convert_s4_s8(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int LoadMask) { auto ymm = _mm256_maskz_loadu_epi32(__mmask8(LoadMask), reinterpret_cast(srcptr)); auto zmm = unpack_4bits(ymm, vmask); - if constexpr (S4_T == JBLAS_DTYPE::S4_FULLRANGE) { + if constexpr (S4_T == BTLA_DTYPE::S4_FULLRANGE) { zmm = _mm512_srli_epi32(zmm, 4); auto s8 = _mm512_set1_epi8(8); zmm = _mm512_sub_epi8(zmm, s8); @@ -74,12 +74,12 @@ static inline void convert_s4_s8(int8_t* dstptr, int8_t* srcptr, __m512i vmask, _mm512_mask_storeu_epi64(dstptr, __mmask8(LoadMask), zmm); } -template +template static inline void convert_s4_s8_v32(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int LoadMask) { auto xmm = _mm_maskz_loadu_epi32(__mmask8(LoadMask), reinterpret_cast(srcptr)); auto ymm = _mm256_castsi128_si256(xmm); auto zmm = unpack_4bits(ymm, vmask); - if constexpr (S4_T == JBLAS_DTYPE::S4_FULLRANGE) { + if constexpr (S4_T == BTLA_DTYPE::S4_FULLRANGE) { zmm = _mm512_srli_epi32(zmm, 4); auto s8 = _mm512_set1_epi8(8); zmm = _mm512_sub_epi8(zmm, s8); @@ -101,7 +101,7 @@ static inline void convert_s8_fp_v16(T* dstptr, int8_t* srcptr) { } } -constexpr void (*pad_fp4)(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int) = &convert_s4_s8; +constexpr void (*pad_fp4)(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int) = &convert_s4_s8; template static inline void dequant_s8_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, __m512i* vzps = nullptr) { @@ -124,18 +124,18 @@ static inline void dequant_s8_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, } } -template +template static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, __m512i* vzps = nullptr) { static_assert(N % 16 == 0); int constexpr VLoop = N / 16; float* LUT; - static_assert(F4_T == JBLAS_DTYPE::F4_BNB || F4_T == JBLAS_DTYPE::F4_NF4 || F4_T == JBLAS_DTYPE::F4_E2M1, + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); - if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) { + if constexpr (F4_T 
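// ---- editor's note (not part of this patch) -------------------------------
// The "return BTLA_CODE::Success;" added above fixes a fall-through: the
// AVX512-BF16 path previously finished its loop and then also executed the
// avx512f:: fallback after the #endif. The pattern in miniature (macro name
// swapped for a standard compiler macro so this compiles standalone):
enum class Code { Success, NotSupport };

static Code fallback_impl() { return Code::NotSupport; }

static Code isa_gated_impl() {
#ifdef __AVX512BF16__
  // ... fast avx512bf16 path ...
  return Code::Success;  // without this return, control would also run the
                         // generic fallback below -- the bug fixed above
#endif
  return fallback_impl();
}
// ---------------------------------------------------------------------------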
== BTLA_DTYPE::F4_BNB) { LUT = fp4_bnb_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { LUT = nf4_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { LUT = fp4_e2m1_dequant_fp32_LUT; } for (int iv = 0; iv < VLoop; iv += 1) { @@ -156,18 +156,18 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, } } -template +template static inline void unpack_f4_N(_DST_T* dstptr, int8_t* srcptr) { static_assert(N % 16 == 0); int constexpr VLoop = N / 16; float* LUT; - static_assert(F4_T == JBLAS_DTYPE::F4_BNB || F4_T == JBLAS_DTYPE::F4_NF4 || F4_T == JBLAS_DTYPE::F4_E2M1, + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); - if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) { + if constexpr (F4_T == BTLA_DTYPE::F4_BNB) { LUT = fp4_bnb_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { LUT = nf4_dequant_fp32_LUT; - } else if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) { + } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { LUT = fp4_e2m1_dequant_fp32_LUT; } for (int iv = 0; iv < VLoop; iv += 1) { @@ -228,12 +228,12 @@ static inline void vec_broadcast_epi32_2_4(__m512i* dst4regs, __m512i* src2regs) } template -static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DT* dstptr, int row, int col, - int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, - void (*dequantize)(_DT*, int8_t*, __m512*, __m512i*), - void (*pad_bit4)(int8_t*, int8_t*, __m512i, int), - int8_t* tmpbuf, size_t tmpsize) { +static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DT* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, + void (*dequantize)(_DT*, int8_t*, __m512*, __m512i*), + void (*pad_bit4)(int8_t*, int8_t*, __m512i, int), + int8_t* tmpbuf, size_t tmpsize) { uint32_t mask = 0xf0f0f0f0; auto zmm_mask = _mm512_set1_epi32(*reinterpret_cast(&mask)); if (col == 48) { @@ -345,22 +345,22 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, } } } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } -template -static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DT* dstptr, int row, int col, - int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, int8_t* tmpbuf, - size_t tmpsize) { +template +static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DT* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, int8_t* tmpbuf, + size_t tmpsize) { uint32_t mask = 0xf0f0f0f0; auto zmm_mask = _mm512_set1_epi32(*reinterpret_cast(&mask)); auto broadcast_idx = _mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7); auto broadcast_idx_128 = _mm_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7); auto constexpr SRC_TYPE = - static_cast(utils::jblas_dtype_get_mask_val(_SRCT, JBLAS_DTYPE::TypeMask, JBLAS_DTYPE::TypeShift)); + static_cast(utils::bestla_dtype_get_mask_val(_SRCT, BTLA_DTYPE::TypeMask, BTLA_DTYPE::TypeShift)); if (col % 64 == 0) { constexpr int ColTile = 64; constexpr int NRegs = ColTile / 16; @@ -388,9 +388,9 @@ static inline 
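// ---- editor's sketch (not part of this patch) -----------------------------
// The tail-handling idiom used throughout these AVX512 kernels, shown
// standalone: a k-bit mask covers the last col % 16 lanes, so no scalar
// remainder loop is needed. Assumes AVX512F at compile time.
#include <immintrin.h>

static void scale_f32_masked(const float* src, float* dst, int n, float s) {
  const __m512 vs = _mm512_set1_ps(s);
  const int body = n / 16 * 16;
  int i = 0;
  for (; i < body; i += 16)
    _mm512_storeu_ps(dst + i, _mm512_mul_ps(_mm512_loadu_ps(src + i), vs));
  if (i < n) {
    const __mmask16 m = _cvtu32_mask16((1u << (n - i)) - 1);
    _mm512_mask_storeu_ps(dst + i, m,
                          _mm512_mul_ps(_mm512_maskz_loadu_ps(m, src + i), vs));
  }
}
// ---------------------------------------------------------------------------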
JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } for (; irow < row0; irow++) { - if constexpr (SRC_TYPE == JBLAS_DTYPE::TypeFloat) { - convert_s4_s8( - tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, LoadMask64); + if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { + convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), + zmm_mask, LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } else { convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, @@ -414,8 +414,8 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } for (int irr = 0; irr < kblock; irr += 1) { - if constexpr (SRC_TYPE == JBLAS_DTYPE::TypeFloat) { - convert_s4_s8( + if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { + convert_s4_s8( tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), zmm_mask, LoadMask64); dequant_f4_N(dstptr + (irow + irr) * ld_dst + icol, tmpbuf, vscales, vzps); } else { @@ -438,9 +438,9 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } } for (; irow < row; irow++) { - if constexpr (SRC_TYPE == JBLAS_DTYPE::TypeFloat) { - convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), - zmm_mask, LoadMask64); + if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { + convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), + zmm_mask, LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } else { convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, @@ -449,7 +449,7 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } } } - return JblasSuccess; + return BTLA_CODE::Success; } else if (col % 96 == 0) { constexpr int ColTile = 96; constexpr int NRegs = ColTile / 16; @@ -477,10 +477,10 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } for (; irow < row0; irow++) { - if constexpr (SRC_TYPE == JBLAS_DTYPE::TypeFloat) { - convert_s4_s8( - tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, LoadMask64); - convert_s4_s8_v32( + if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { + convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), + zmm_mask, LoadMask64); + convert_s4_s8_v32( tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), zmm_mask, LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); @@ -508,10 +508,10 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } for (int irr = 0; irr < kblock; irr += 1) { - if constexpr (SRC_TYPE == JBLAS_DTYPE::TypeFloat) { - convert_s4_s8( + if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { + convert_s4_s8( tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), zmm_mask, LoadMask64); - convert_s4_s8_v32( + convert_s4_s8_v32( tmpbuf + 64, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2 + 32), zmm_mask, LoadMask64); dequant_f4_N(dstptr + (irow + irr) * ld_dst + icol, tmpbuf, vscales, vzps); @@ -538,10 +538,10 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } } for (; irow < row; irow++) { - if constexpr (SRC_TYPE == JBLAS_DTYPE::TypeFloat) { - convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), - zmm_mask, LoadMask64); - convert_s4_s8_v32( + if constexpr (SRC_TYPE == 
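// ---- editor's sketch (not part of this patch) -----------------------------
// How the SRC_TYPE dispatch above works: BTLA_DTYPE packs a category field
// into the enum value, and bestla_dtype_get_mask_val extracts it so kernels
// can branch at compile time (TypeFloat selects the f4/LUT path, otherwise
// the signed-int path). The field layout below is invented for illustration.
#include <cstdint>

constexpr uint32_t dtype_field(uint32_t v, uint32_t mask, uint32_t shift) {
  return (v & mask) >> shift;
}

// example layout: low 8 bits = subtype, bits 8..15 = category
constexpr uint32_t kTypeMask = 0xff00, kTypeShift = 8;
constexpr uint32_t kTypeFloat = 1;
static_assert(dtype_field((kTypeFloat << kTypeShift) | 3, kTypeMask, kTypeShift) == kTypeFloat,
              "category extraction");
// ---------------------------------------------------------------------------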
BTLA_DTYPE::TypeFloat) { + convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), + zmm_mask, LoadMask64); + convert_s4_s8_v32( tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), zmm_mask, LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } else { @@ -554,17 +554,17 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, } } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } template -inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int k_offset, int kblock, int NPad, JBLAS_DTYPE src_f8_type) { +inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _S_T* scales, int k_offset, int kblock, int NPad, BTLA_DTYPE src_f8_type) { int align_col = col / 16 * 16; int col_tail = col - align_col; - auto ebits = utils::jblas_dtype_get_f8_ebits(src_f8_type); + auto ebits = utils::bestla_dtype_get_f8_ebits(src_f8_type); auto mantissabit = 7 - ebits; auto sign_revert_and_mask = _mm512_set1_epi32(0x80000000); auto e_revert_and_mask = _mm512_set1_epi32(0x0000007f); @@ -614,13 +614,13 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int for (; j < align_col; j += 16) quant(_cvtu32_mask16(0xffff)); if (col_tail > 0) quant(_cvtu32_mask16(0xffff >> (16 - col_tail))); } - return JblasSuccess; + return BTLA_CODE::Success; } -template -static inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad, int8_t* tmp, size_t tmpsize) { +template +static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, int8_t* tmp, size_t tmpsize) { if constexpr (_PACK_ROW == 1) { if (zero_points == nullptr) { return decompress_kblock_bit4_packrow1<_ST, _DST_T, true>( @@ -640,13 +640,13 @@ static inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, tmp, tmpsize); } } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } -template -static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, - int8_t* tmp, size_t tmpsize) { +template +static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, + int8_t* tmp, size_t tmpsize) { if constexpr (_PACK_ROW == 1) { return decompress_kblock_bit4_packrow1<_ST, _DST_T, true>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, &dequant_f4_N<48, _DST_T, _F4_T>, @@ -655,12 +655,12 @@ static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* ds return decompress_kblock_bit4_packrow2<_F4_T, _ST, _DST_T, true>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize); } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } -template -inline JBLAS_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int 
ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { uint32_t mask = 0xf0f0f0f0; auto zmm_mask = _mm512_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { @@ -690,14 +690,14 @@ inline JBLAS_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* ds dstptr[i + 0] = static_cast(ref::f4_unpack(tmp.x)); dstptr[i + 1] = static_cast(ref::f4_unpack(tmp.y)); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } -template -static inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst) { +template +static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, + int ld_dst) { uint32_t mask = 0xf0f0f0f0; auto zmm_mask = _mm512_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { @@ -719,16 +719,16 @@ static inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, } for (; i < elesize; i += 2) { auto tmp = srcptr[i / 2]; - dstptr[i + 0] = jblas::kernel::ref::get_s8(tmp.x); - dstptr[i + 1] = jblas::kernel::ref::get_s8(tmp.y); + dstptr[i + 0] = kernel::ref::get_s8(tmp.x); + dstptr[i + 1] = kernel::ref::get_s8(tmp.y); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } -static inline JBLAS_CODE quantize_f32_sign_int_rowblock_sym(const float* srcptr, int8_t* dstptr, int row, int col, - int ld_src, int ld_dst, float* scales, int blocksize) { +static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym(const float* srcptr, int8_t* dstptr, int row, int col, + int ld_src, int ld_dst, float* scales, int blocksize) { int constexpr VLen = 16; auto v127 = _mm512_set1_ps(127.f); int col16 = utils::padto_le(col, 16); @@ -775,12 +775,12 @@ static inline JBLAS_CODE quantize_f32_sign_int_rowblock_sym(const float* srcptr, for (; j < align_row; j += blocksize) scalar_process_block(blocksize); if (j < row) scalar_process_block(row - align_row); } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE quantize_f32_sign_int_rowblock_asym(const float* srcptr, int8_t* dstptr, int row, int col, - int ld_src, int ld_dst, float* scales, int8_t* zero_points, - int blocksize) { +static inline BTLA_CODE quantize_f32_sign_int_rowblock_asym(const float* srcptr, int8_t* dstptr, int row, int col, + int ld_src, int ld_dst, float* scales, int8_t* zero_points, + int blocksize) { int constexpr VLen = 16; auto v255 = _mm512_set1_ps(255.f); auto v2 = _mm512_set1_ps(2.f); @@ -842,13 +842,13 @@ static inline JBLAS_CODE quantize_f32_sign_int_rowblock_asym(const float* srcptr for (; j < align_row; j += blocksize) scalar_process_block(blocksize); if (j < row) scalar_process_block(row - align_row); } - return JblasSuccess; + return BTLA_CODE::Success; } -template -static inline JBLAS_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, - int ld_src, int ld_dst, float* scales, int8_t* zero_points, - int blocksize) { +template +static inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, + int ld_src, int ld_dst, float* scales, int8_t* zero_points, + int blocksize) { if (zero_points == nullptr) return quantize_f32_sign_int_rowblock_sym(srcptr, dstptr, row, col, ld_src, ld_dst, scales, blocksize); else 
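// ---- editor's sketch (not part of this patch) -----------------------------
// Scalar equivalent of quantize_f32_sign_int_rowblock_sym above: each
// blocksize-long run of a column gets scale = absmax / 127 and values are
// rounded into [-127, 127]. The scale output layout is an assumption for
// illustration.
#include <algorithm>
#include <cmath>
#include <cstdint>

static void quantize_s8_rowblock_sym_scalar(const float* src, int8_t* dst, int row,
                                            int col, int ld_src, int ld_dst,
                                            float* scales, int blocksize) {
  for (int j = 0; j < col; ++j) {
    for (int i = 0; i < row; i += blocksize) {
      const int blk = std::min(blocksize, row - i);
      float absmax = 0.f;
      for (int ii = 0; ii < blk; ++ii)
        absmax = std::max(absmax, std::fabs(src[(i + ii) * ld_src + j]));
      const float scale = absmax / 127.f;
      const float rscale = scale != 0.f ? 1.f / scale : 0.f;
      scales[(i / blocksize) * col + j] = scale;  // layout assumed
      for (int ii = 0; ii < blk; ++ii)
        dst[(i + ii) * ld_dst + j] = static_cast<int8_t>(
            std::clamp(static_cast<int>(std::roundf(src[(i + ii) * ld_src + j] * rscale)),
                       -127, 127));
    }
  }
}
// ---------------------------------------------------------------------------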
@@ -878,7 +878,7 @@ constexpr auto broadcast_N_2_Nx16(const int8_t* arr) { return broadcast_N_2_Nx16(arr, std::make_index_sequence{}); } -template +template inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src, int ld_dst, const int8_t* broadcast_f4_v, float* scales, __mmask16 ls_mask) { __m128i xmm0{}, xmm1{}, xmm2{}, xmm3{}; @@ -895,7 +895,7 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src zmm1 = _mm512_mul_ps(zmm1, zmm_scale); zmm2 = _mm512_mul_ps(zmm2, zmm_scale); zmm3 = _mm512_mul_ps(zmm3, zmm_scale); - if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) { + if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { auto zmm_zp = _mm512_set1_ps(0.8480964004993439f); zmm0 = _mm512_add_ps(zmm0, zmm_zp); zmm1 = _mm512_add_ps(zmm1, zmm_zp); @@ -912,12 +912,12 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src zmm2 = _mm512_abs_ps(zmm2); zmm3 = _mm512_abs_ps(zmm3); } - constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8; + constexpr int loop_num = F4_T == BTLA_DTYPE::F4_NF4 ? 16 : 8; for (int i = 0; i < loop_num; i++) { __m512 sub_v; - if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) sub_v = _mm512_set1_ps(F4_NF4_quant_sub_helper[i]); - if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]); - if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]); + if constexpr (F4_T == BTLA_DTYPE::F4_NF4) sub_v = _mm512_set1_ps(F4_NF4_quant_sub_helper[i]); + if constexpr (F4_T == BTLA_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]); + if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]); zmm4 = _mm512_sub_ps(zmm0, sub_v); zmm5 = _mm512_sub_ps(zmm1, sub_v); zmm6 = _mm512_sub_ps(zmm2, sub_v); @@ -935,7 +935,7 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src zmm2 = _mm512_mask_add_ps(zmm2, mask2, zmm2, avoid_double_cmp); zmm3 = _mm512_mask_add_ps(zmm3, mask3, zmm3, avoid_double_cmp); } - if constexpr (F4_T != JBLAS_DTYPE::F4_NF4) { + if constexpr (F4_T != BTLA_DTYPE::F4_NF4) { auto xmm_bias = _mm_set1_epi8(0x08); xmm0 = _mm_mask_add_epi8(xmm0, mask4, xmm0, xmm_bias); xmm1 = _mm_mask_add_epi8(xmm1, mask5, xmm1, xmm_bias); @@ -948,7 +948,7 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src _mm_mask_storeu_epi8(dstptr + 3 * ld_dst, ls_mask, xmm3); } -template +template inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src, int ld_dst, const int8_t* broadcast_f4_v, float* scales, __mmask16 ls_mask) { __m512 zmm0{}, zmm1, zmm_scale{}; @@ -959,25 +959,25 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src __mmask16 mask0, mask1; zmm0 = _mm512_mask_loadu_ps(zmm0, ls_mask, srcptr); zmm0 = _mm512_mul_ps(zmm0, zmm_scale); - if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) { + if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { auto zp = _mm512_set1_ps(0.8480964004993439f); zmm0 = _mm512_add_ps(zmm0, zp); } else { mask1 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1); zmm0 = _mm512_abs_ps(zmm0); } - constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8; + constexpr int loop_num = F4_T == BTLA_DTYPE::F4_NF4 ? 
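// ---- editor's sketch (not part of this patch) -----------------------------
// Scalar view of the f4 quantization search performed by
// f32_f4_quantize_4x16/1x16 above: the scaled input is compared against a
// sorted list of decision boundaries (the F4_*_quant_sub_helper tables) and
// the matching 4-bit code is taken from a parallel table, with the sign
// placed in bit 3 (0x08) for the non-NF4 formats. Boundary/code contents
// are the library's constants and are not reproduced here.
static int8_t quantize_f4_scalar(float x, float rscale, const float* bounds,
                                 const int8_t* codes, int n, bool has_sign_bit) {
  float v = x * rscale;
  const bool neg = has_sign_bit && v < 0.f;
  if (neg) v = -v;
  int idx = 0;
  while (idx < n - 1 && v > bounds[idx]) ++idx;  // first bucket that fits
  int8_t code = codes[idx];
  if (neg) code |= 0x08;  // sign bit, as in the xmm_bias step above
  return code;
}
// ---------------------------------------------------------------------------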
16 : 8; for (int i = 0; i < loop_num; i++) { __m512 sub_v; - if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) sub_v = _mm512_set1_ps(F4_NF4_quant_sub_helper[i]); - if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]); - if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]); + if constexpr (F4_T == BTLA_DTYPE::F4_NF4) sub_v = _mm512_set1_ps(F4_NF4_quant_sub_helper[i]); + if constexpr (F4_T == BTLA_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]); + if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]); zmm1 = _mm512_sub_ps(zmm0, sub_v); mask0 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 2); xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast(broadcast_f4_v + i * 16))); zmm0 = _mm512_mask_add_ps(zmm0, mask0, zmm0, avoid_double_cmp); } - if constexpr (F4_T != JBLAS_DTYPE::F4_NF4) { + if constexpr (F4_T != BTLA_DTYPE::F4_NF4) { auto xmm_bias = _mm_set1_epi8(0x08); xmm0 = _mm_mask_add_epi8(xmm0, mask1, xmm0, xmm_bias); } @@ -997,17 +997,16 @@ constexpr auto broadcast_F4_NF4_quantv = broadcast_N_2_Nx16<16>(F4_NF4_simd_quan constexpr auto broadcast_F4_BNB_quantv = broadcast_N_2_Nx16<8>(F4_BNB_simd_quant_v); constexpr auto broadcast_F4_E2M1_quantv = broadcast_N_2_Nx16<8>(F4_E2M1_simd_quant_v); -template -inline JBLAS_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int8_t* zero_points, int blocksize) { +template +inline BTLA_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales, int8_t* zero_points, int blocksize) { // assert(col % 16 == 0); auto align_row = row / blocksize * blocksize; auto align_blk = blocksize / 4 * 4; int8_t* broadcast_f4_quantv; - if constexpr (F4_T == JBLAS_DTYPE::F4_NF4) broadcast_f4_quantv = const_cast(broadcast_F4_NF4_quantv.data()); - if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) broadcast_f4_quantv = const_cast(broadcast_F4_BNB_quantv.data()); - if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) - broadcast_f4_quantv = const_cast(broadcast_F4_E2M1_quantv.data()); + if constexpr (F4_T == BTLA_DTYPE::F4_NF4) broadcast_f4_quantv = const_cast(broadcast_F4_NF4_quantv.data()); + if constexpr (F4_T == BTLA_DTYPE::F4_BNB) broadcast_f4_quantv = const_cast(broadcast_F4_BNB_quantv.data()); + if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) broadcast_f4_quantv = const_cast(broadcast_F4_E2M1_quantv.data()); int i = 0; int align_col = col / 16 * 16; @@ -1045,13 +1044,13 @@ inline JBLAS_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, for (; i < align_col; i += 16) process_row_blk(i, 16); if (i < col) process_row_blk(i, col - i); - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, - int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, - float* blkreduce) { +static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, + int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, + float* blkreduce) { int constexpr VLen = 16; auto vff = _mm512_set1_epi32(255); auto v0 = _mm512_set1_epi32(0); @@ -1153,13 +1152,12 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE 
quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, - int ld_dst, float* scales, int ld_scale, int blocksize, - float* reduce) { +static inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, + int ld_dst, float* scales, int ld_scale, int blocksize, float* reduce) { int constexpr VLen = 16; auto vpos = _mm512_set1_epi32(127); auto vneg = _mm512_set1_epi32(-128); @@ -1238,12 +1236,12 @@ static inline JBLAS_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* if (reduce) reduce[j / blocksize + i * ld_scale] = sum * scale; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE alphabeta_f32_f32(const float alpha, const float* srcptr, const int srcstep, const float beta, - const float* src1ptr, const int src1step, float* dstptr, const int dststep, - const int M, const int N) { +static inline BTLA_CODE alphabeta_f32_f32(const float alpha, const float* srcptr, const int srcstep, const float beta, + const float* src1ptr, const int src1step, float* dstptr, const int dststep, + const int M, const int N) { int constexpr Vlen = 16; auto vN = utils::padto_le(N, Vlen); auto valpha = _mm512_set1_ps(alpha); @@ -1273,12 +1271,12 @@ static inline JBLAS_CODE alphabeta_f32_f32(const float alpha, const float* srcpt } } } - return JblasSuccess; + return BTLA_CODE::Success; } -template -inline JBLAS_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { uint32_t mask = 0xf0f0f0f0; auto zmm_mask = _mm512_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { @@ -1307,16 +1305,16 @@ inline JBLAS_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstpt } for (; i < elesize; i += 2) { auto tmp = srcptr[i / 2]; - dstptr[i + 0] = static_cast<_DST_T>(static_cast(jblas::kernel::ref::get_s8(tmp.x))); - dstptr[i + 1] = static_cast<_DST_T>(static_cast(jblas::kernel::ref::get_s8(tmp.y))); + dstptr[i + 0] = static_cast<_DST_T>(static_cast(kernel::ref::get_s8(tmp.x))); + dstptr[i + 1] = static_cast<_DST_T>(static_cast(kernel::ref::get_s8(tmp.y))); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } template -inline JBLAS_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { +inline BTLA_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { if (col == ld_src) { size_t elesize = (size_t)row * col; size_t ele64 = utils::padto_le(elesize, 64); @@ -1332,14 +1330,14 @@ inline JBLAS_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int r auto tmp = srcptr[i]; dstptr[i] = static_cast(static_cast(tmp)); } - return JblasSuccess; + return BTLA_CODE::Success; } - return JblasNotSupport; + return BTLA_CODE::NotSupport; } template -static inline JBLAS_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, - const int dststep, const int M, const int N) { +static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, + const int dststep, const int M, const int N) { int constexpr Vlen = 16; auto vN = utils::padto_le(N, Vlen); int j = 0; @@ -1371,11 +1369,11 @@ static 
inline JBLAS_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* s } } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE accum_f32_f32(const float* srcptr, const int srcstep, float* dstptr, const int dststep, - const int M, const int N) { +static inline BTLA_CODE accum_f32_f32(const float* srcptr, const int srcstep, float* dstptr, const int dststep, + const int M, const int N) { int constexpr Vlen = 16; auto vN = utils::padto_le(N, Vlen); int j = 0; @@ -1392,7 +1390,7 @@ static inline JBLAS_CODE accum_f32_f32(const float* srcptr, const int srcstep, f dstptr[i * dststep + j] += srcptr[i * srcstep + j]; } } - return JblasSuccess; + return BTLA_CODE::Success; } static inline void vec_quanout_s32_u32_v16(const int32_t* srcptr, __m512& vfactor, __m512i& vzp, __m512i& vzeros, @@ -1407,9 +1405,9 @@ static inline void vec_quanout_s32_u32_v16(const int32_t* srcptr, __m512& vfacto _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), vdstb); } -static inline JBLAS_CODE quanout_s32_u32(const float alpha, const int32_t* srcptr, const int srcstep, uint8_t* dstptr, - const int dststep, const int M, const int N, float scaleSrc, float scaleDst, - int zpDst) { +static inline BTLA_CODE quanout_s32_u32(const float alpha, const int32_t* srcptr, const int srcstep, uint8_t* dstptr, + const int dststep, const int M, const int N, float scaleSrc, float scaleDst, + int zpDst) { float factor = alpha * scaleSrc / scaleDst; auto vfactor = _mm512_set1_ps(factor); auto vzp = _mm512_set1_epi32(zpDst); @@ -1444,12 +1442,12 @@ static inline JBLAS_CODE quanout_s32_u32(const float alpha, const int32_t* srcpt dstptr[i * dststep + j] = utils::cast(fsrc + static_cast(zpDst)); } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE accumulate_dequantize_s32_f32(const int32_t* srcptr, float* dstptr, float alpha, float beta, - int row, int col, int ld_src, int ld_dst, float* ascales, - int ldas, float* wscales) { +static inline BTLA_CODE accumulate_dequantize_s32_f32(const int32_t* srcptr, float* dstptr, float alpha, float beta, + int row, int col, int ld_src, int ld_dst, float* ascales, + int ldas, float* wscales) { auto vbeta = _mm512_set1_ps(beta); int col16 = utils::padto_le(col, 16); for (int irow = 0; irow < row; irow++) { @@ -1471,13 +1469,13 @@ static inline JBLAS_CODE accumulate_dequantize_s32_f32(const int32_t* srcptr, fl scale * wscales[icol] * srcptr[irow * ld_src + icol] + beta * dstptr[irow * ld_dst + icol]; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, - const int row, const int col, const float* scaleA, const int ldsa, - const SCAB_T* scaleB) { +static inline BTLA_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, + const int row, const int col, const float* scaleA, const int ldsa, + const SCAB_T* scaleB) { int col16 = utils::padto_le(col, 16); int col64 = utils::padto_le(col, 64); for (int irow = 0; irow < row; irow++) { @@ -1520,10 +1518,10 @@ static inline JBLAS_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcst dstptr[irow * dststep + icol] = scale * scaleB[icol] * srcptr[irow * srcstep + icol]; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE broadcast_u8(int num, const uint8_t& srcval, uint8_t* dstptr) { +static inline BTLA_CODE broadcast_u8(int num, const uint8_t& srcval, uint8_t* dstptr) { int i = 0; int constexpr 
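// ---- editor's sketch (not part of this patch) -----------------------------
// Scalar form of the s32 -> u8 requantization in quanout_s32_u32 above: the
// int32 accumulator is scaled by alpha * scaleSrc / scaleDst (the factor
// computed in the kernel), shifted by the destination zero point, and
// saturated to [0, 255].
#include <algorithm>
#include <cmath>
#include <cstdint>

static void quanout_s32_u8_scalar(float alpha, const int32_t* src, int srcstep,
                                  uint8_t* dst, int dststep, int M, int N,
                                  float scaleSrc, float scaleDst, int zpDst) {
  const float factor = alpha * scaleSrc / scaleDst;
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      const int v = static_cast<int>(std::roundf(factor * src[i * srcstep + j])) + zpDst;
      dst[i * dststep + j] = static_cast<uint8_t>(std::clamp(v, 0, 255));
    }
}
// ---------------------------------------------------------------------------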
VN = 64 / sizeof(srcval); int numv = utils::padto_le(num, VN); @@ -1540,11 +1538,11 @@ static inline JBLAS_CODE broadcast_u8(int num, const uint8_t& srcval, uint8_t* d for (; i < num; i++) { dstptr[i] = srcval; } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, - float* scales, int lds, const float* reduce) { +static inline BTLA_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, + float* scales, int lds, const float* reduce) { int constexpr VLen = 16; auto col16 = utils::padto_le(col, VLen); for (int i = 0; i < row; i++) { @@ -1563,11 +1561,11 @@ static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int } } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, - float* scales, int lds, const float* reduce) { +static inline BTLA_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, + float* scales, int lds, const float* reduce) { int constexpr VLen = 16; auto col16 = utils::padto_le(col, VLen); for (int i = 0; i < row; i++) { @@ -1587,12 +1585,12 @@ static inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int } } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, - float* scalea, float* scaleb, int lds, int k, const float* reducea, - const float* reduceb) { +static inline BTLA_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, + float* scalea, float* scaleb, int lds, int k, const float* reducea, + const float* reduceb) { int constexpr VLen = 16; auto col16 = utils::padto_le(col, VLen); auto vk = _mm512_set1_ps(static_cast(k)); @@ -1622,11 +1620,11 @@ static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row } } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, - int srcstride, int dststride, bool zeropadding) { +static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, + int srcstride, int dststride, bool zeropadding) { auto srcptr = reinterpret_cast(raw_srcptr); auto dstptr = reinterpret_cast(raw_dstptr); constexpr int simd_proc_elt = 16; @@ -1646,8 +1644,7 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, voi round_bias = _mm512_add_epi32(round_bias, bf16_add_helper); auto round_fp32_v = _mm512_add_epi32(round_bias, _mm512_loadu_si512(src + sizeof(float) * simd_proc_elt * j)); auto pack_bf16_value = _mm512_cvtepi32_epi16(_mm512_srli_epi32(round_fp32_v, 16)); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + (j * simd_proc_elt) * sizeof(jblas::utils::bf16)), - pack_bf16_value); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + (j * simd_proc_elt) * sizeof(utils::bf16)), pack_bf16_value); } if (col_tail > 0) { auto round_bias = _mm512_maskz_loadu_epi32(tail_mask, src + sizeof(float) * simd_proc_elt * j); @@ -1656,19 +1653,19 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, voi auto round_fp32_v = _mm512_add_epi32(round_bias, _mm512_maskz_loadu_epi32(tail_mask, src + sizeof(float) * simd_proc_elt * j)); auto pack_bf16_tail = 
_mm512_cvtepi32_epi16(_mm512_srli_epi32(round_fp32_v, 16)); - _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(dst + (j * simd_proc_elt) * sizeof(jblas::utils::bf16)), - tail_mask, pack_bf16_tail); + _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(dst + (j * simd_proc_elt) * sizeof(utils::bf16)), tail_mask, + pack_bf16_tail); } if (zeropadding && npadding) { std::memset(dst + col * sizeof(utils::bf16), 0, npadding); } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, - float* reduce, int ldr) { +static inline BTLA_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, + float* reduce, int ldr) { int constexpr VLen = 16; auto vblock2_ = utils::padto_le(blocksize, VLen * 2); auto vblock_ = utils::padto_le(blocksize, VLen); @@ -1700,11 +1697,11 @@ static inline JBLAS_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, in reduce[i * ldr + j / blocksize] = tmp; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE fp32_cvt_fp16_2D_write_back(const float* src_ptr, utils::fp16* dst_ptr, int row, int col, - int src_step, int dst_step, bool zeropadding) { +static inline BTLA_CODE fp32_cvt_fp16_2D_write_back(const float* src_ptr, utils::fp16* dst_ptr, int row, int col, + int src_step, int dst_step, bool zeropadding) { #if CompileFP16() const int npadding = (dst_step - col) * sizeof(utils::fp16); constexpr int simd_proc_elt = 16; @@ -1724,14 +1721,14 @@ static inline JBLAS_CODE fp32_cvt_fp16_2D_write_back(const float* src_ptr, utils } if (zeropadding && npadding) std::memset(dst + col, 0, npadding); } - return JblasSuccess; + return BTLA_CODE::Success; #else - return JblasNotSupport; + return BTLA_CODE::NotSupport; #endif } -static inline JBLAS_CODE fp16_cvt_fp32_2D_write_back(const utils::fp16* src_ptr, float* dst_ptr, int row, int col, - int src_step, int dst_step, bool zeropadding) { +static inline BTLA_CODE fp16_cvt_fp32_2D_write_back(const utils::fp16* src_ptr, float* dst_ptr, int row, int col, + int src_step, int dst_step, bool zeropadding) { #if CompileFP16() const int npadding = (dst_step - col) * sizeof(float); constexpr int simd_proc_elt = 16; @@ -1751,14 +1748,14 @@ static inline JBLAS_CODE fp16_cvt_fp32_2D_write_back(const utils::fp16* src_ptr, } if (zeropadding && npadding) std::memset(dst + col, 0, npadding); } - return JblasSuccess; + return BTLA_CODE::Success; #else - return JblasNotSupport; + return BTLA_CODE::NotSupport; #endif } -static inline JBLAS_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, - int src_step, int dst_step, bool zeropadding) { +static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, + int src_step, int dst_step, bool zeropadding) { const int npadding = (dst_step - col) * sizeof(float); constexpr int simd_proc_elt = 16; auto col_body = col / simd_proc_elt * simd_proc_elt; @@ -1780,7 +1777,7 @@ static inline JBLAS_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, _mm512_cvtepu16_epi32(_mm256_castps_si256(_mm256_loadu_ps(reinterpret_cast(src + j)))), 2))); if (zeropadding && npadding) std::memset(dst + col, 0, npadding); } - return JblasSuccess; + return BTLA_CODE::Success; } #ifdef __GNUC__ @@ -1946,9 +1943,9 @@ static constexpr decltype(load_maskz_fp16_bf16_tr_x16_dword<1>)* load_maskz_fp16 template struct padding_interleave_cvt { 
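// ---- editor's sketch (not part of this patch) -----------------------------
// The fp16 kernels above return NotSupport unless CompileFP16() holds; an
// editor's alternative sketch of the same row conversion using the older
// F16C instructions (assumes -mf16c; not the library's fallback):
#include <cstdint>
#include <immintrin.h>

static void fp32_to_fp16_row(const float* src, uint16_t* dst, int n) {
  const int body = n / 8 * 8;
  int j = 0;
  for (; j < body; j += 8) {
    const __m128i h = _mm256_cvtps_ph(_mm256_loadu_ps(src + j), _MM_FROUND_TO_NEAREST_INT);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + j), h);
  }
  for (; j < n; ++j) {  // scalar tail through the 128-bit form
    const __m128i h = _mm_cvtps_ph(_mm_set_ss(src[j]), _MM_FROUND_TO_NEAREST_INT);
    dst[j] = static_cast<uint16_t>(_mm_extract_epi16(h, 0));
  }
}
// ---------------------------------------------------------------------------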
padding_interleave_cvt() = delete; - static JBLAS_CODE forward(const T_SRC* src, T_DST* dst, int NTile, int row, int col, int row_pad, int col_pad, - int src_step, int dst_step) { - return JblasNotSupport; + static BTLA_CODE forward(const T_SRC* src, T_DST* dst, int NTile, int row, int col, int row_pad, int col_pad, + int src_step, int dst_step) { + return BTLA_CODE::NotSupport; } }; #if CompileBF16() && CompileFP16() @@ -1958,8 +1955,8 @@ struct padding_interleave_cvt { padding_interleave_cvt() = delete; // M x N ===> N/NTile x M/RowPack x NTile x RowPack (leading dim stride = NTile * dststride) - static JBLAS_CODE forward(const utils::fp16* src, utils::bf16* dst, int NTile, int row, int col, int row_pad, - int col_pad, int src_step, int dst_step) { + static BTLA_CODE forward(const utils::fp16* src, utils::bf16* dst, int NTile, int row, int col, int row_pad, + int col_pad, int src_step, int dst_step) { int i = 0; for (; i < row / RowPack * RowPack; i += RowPack) { int j = 0; @@ -2034,7 +2031,7 @@ struct padding_interleave_cvt { memset(dst + i * NTile + j * dst_step, 0, sizeof(utils::bf16) * NTile * RowPack); } } - return JblasSuccess; + return BTLA_CODE::Success; } }; #endif @@ -2042,9 +2039,9 @@ struct padding_interleave_cvt { template struct padding_trans_interleave_cvt { padding_trans_interleave_cvt() = delete; - static JBLAS_CODE forward(const T_SRC* src, T_DST* dst, int MTile, int row, int col, int row_pad, int col_pad, - int src_step, int dst_step) { - return JblasNotSupport; + static BTLA_CODE forward(const T_SRC* src, T_DST* dst, int MTile, int row, int col, int row_pad, int col_pad, + int src_step, int dst_step) { + return BTLA_CODE::NotSupport; } }; #if CompileBF16() && CompileFP16() @@ -2053,8 +2050,8 @@ struct padding_trans_interleave_cvt { static constexpr int ColPack = 2; padding_trans_interleave_cvt() = delete; - static JBLAS_CODE forward(const utils::fp16* src, utils::bf16* dst, int MTile, int row, int col, int row_pad, - int col_pad, int src_step, int dst_step) { + static BTLA_CODE forward(const utils::fp16* src, utils::bf16* dst, int MTile, int row, int col, int row_pad, + int col_pad, int src_step, int dst_step) { assert(row_pad % 16 == 0 && col_pad % 32 == 0); int i = 0; for (; i < row / MTile * MTile; i += MTile) { @@ -2147,7 +2144,7 @@ struct padding_trans_interleave_cvt { memset(dst + i * dst_step + j * MTile, 0, 2 * sizeof(utils::bf16) * MTile); } } - return JblasSuccess; + return BTLA_CODE::Success; } }; #endif @@ -2159,4 +2156,4 @@ struct padding_trans_interleave_cvt { #endif } // namespace avx512f } // namespace kernel -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/kernel_jit.h b/bestla/bestla/kernel_jit.h similarity index 90% rename from bestla/jblas/kernel_jit.h rename to bestla/bestla/kernel_jit.h index 4a711736e..718c47e25 100644 --- a/bestla/jblas/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -20,20 +20,19 @@ #include #include -#include "jblas/jit_blas.h" -#include "jblas/jit_blas_device.h" -#include "jblas/xbyak/xbyak.h" -#include "jit_base.h" -#include "jit_blas_utils.h" +#include "bestla.h" +#include "bestla_device.h" +#include "bestla_utils.h" +#include "bestla_jit.h" #include "kernel_jit_injector.h" -namespace jblas { +namespace bestla { namespace kernel { namespace jit { class DequanS8FP { public: - class MicroKernelAVX512F : protected jblas::xbyak::JitAvx512f { + class MicroKernelAVX512F : protected xbyak::JitAvx512f { public: struct params { void *srcptr, *dstptr; @@ -47,15 +46,15 @@ class DequanS8FP { static int constexpr 
RegScale = 0; static int constexpr RegZP = 4; static int constexpr RegTmp = RegScale + 8; - MicroKernelAVX512F(JBLAS_DTYPE dst_dt, bool is_sym_, int pack_row) { - assert(dst_dt == JBLAS_DTYPE::F32 || dst_dt == JBLAS_DTYPE::BF16); + MicroKernelAVX512F(BTLA_DTYPE dst_dt, bool is_sym_, int pack_row) { + assert(dst_dt == BTLA_DTYPE::F32 || dst_dt == BTLA_DTYPE::BF16); is_sym = is_sym_; generate(dst_dt, pack_row); this->ready(); mKernel = this->getCode(); } - void generate(JBLAS_DTYPE dst_dt, int pack_row) { + void generate(BTLA_DTYPE dst_dt, int pack_row) { assert(pack_row == 1 || pack_row == 2 || pack_row == 4); int scale_step = 64 / pack_row; Xbyak::Label data_label; @@ -101,11 +100,11 @@ class DequanS8FP { } auto get_dst_step = [&] { - if (dst_dt == JBLAS_DTYPE::BF16) return 2; + if (dst_dt == BTLA_DTYPE::BF16) return 2; return 4; // f32 case. }; - auto generateNTile = [&](int N, JBLAS_DTYPE dst_dt, int scale_step, std::string row_label) { + auto generateNTile = [&](int N, BTLA_DTYPE dst_dt, int scale_step, std::string row_label) { if (pack_row == 2) { vmovups(Xbyak::Zmm(RegTmp), ptr[rip + data_label + 8]); } else if (pack_row == 4) { @@ -131,10 +130,10 @@ class DequanS8FP { } vcvtdq2ps(Xbyak::Zmm(RegTmp), Xbyak::Zmm(RegTmp)); vmulps(Xbyak::Zmm(RegTmp), Xbyak::Zmm(RegScale + i)); - if (dst_dt == JBLAS_DTYPE::F32) { + if (dst_dt == BTLA_DTYPE::F32) { vmovups(ptr[reg_tmp1 + i * 64], Xbyak::Zmm(RegTmp)); } - if (dst_dt == JBLAS_DTYPE::BF16) { + if (dst_dt == BTLA_DTYPE::BF16) { Xbyak::Ymm ymm_v = Xbyak::Ymm(RegTmp); Xbyak::Zmm zmm_v = Xbyak::Zmm(RegTmp); if (device::CpuDevice::getInstance()->AVX512_BF16()) { @@ -230,8 +229,8 @@ class DequanS8FP { template static void forward_avx512f(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points) { - static MicroKernelAVX512F mAVX512FSym(utils::jblas_dtype<_DST_T>, true, PACK_ROW); - static MicroKernelAVX512F mAVX512FASym(utils::jblas_dtype<_DST_T>, false, PACK_ROW); + static MicroKernelAVX512F mAVX512FSym(utils::bestla_dtype<_DST_T>, true, PACK_ROW); + static MicroKernelAVX512F mAVX512FASym(utils::bestla_dtype<_DST_T>, false, PACK_ROW); auto param = MicroKernelAVX512F::params{srcptr, dstptr, row, @@ -251,8 +250,8 @@ class DequanS8FP { class DequanKBlockS8Fp { public: template - static inline JBLAS_CODE forward_avx512f(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _ST* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { + static inline BTLA_CODE forward_avx512f(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _ST* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { int row0 = kblock - k_offset % kblock; row0 = row0 == kblock ? 0 : row0; row0 = row0 > row ? 
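DequanKBlockS8Fp::forward_avx512f below splits the row range so that each call into the dequant microkernel stays inside one k-block and can reuse a single scale row. The row1/row2 terms are elided in this hunk, so the middle and tail computations here are an inference from the visible row0 logic, using the same names:

static inline void split_kblock_rows(int row, int k_offset, int kblock,
                                     int& row0, int& row1, int& row2) {
  row0 = kblock - k_offset % kblock;      // rows left in the current (partial) k-block
  row0 = row0 == kblock ? 0 : row0;       // already block-aligned: no head segment
  row0 = row0 > row ? row : row0;         // clamp to the rows we actually have
  row1 = (row - row0) / kblock * kblock;  // whole k-blocks in the middle
  row2 = row - row0 - row1;               // trailing partial block
}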
row : row0; @@ -279,7 +278,7 @@ class DequanKBlockS8Fp { if (row2 > 0) { DequanS8FP::forward_avx512f<_PACK_ROW>(srcptr, dstptr, row2, col, ld_src, ld_dst, sptr, zptr); } - return JblasSuccess; + return BTLA_CODE::Success; } }; @@ -292,25 +291,24 @@ struct DataConvertConfig { FP32_TO_F16, }; - DataConvertConfig(JBLAS_DTYPE src_t, JBLAS_DTYPE dst_t, - std::vector injectors) { + DataConvertConfig(BTLA_DTYPE src_t, BTLA_DTYPE dst_t, std::vector injectors) { input_dt = src_t; output_dt = dst_t; if (injectors.size() != 0) { - assert(src_t == JBLAS_DTYPE::F32 || src_t == JBLAS_DTYPE::BF16 || src_t == JBLAS_DTYPE::F16); - if (src_t == JBLAS_DTYPE::BF16) before_postop = DataConvertConfig::cvt_direct::BF16_TO_FP32; - if (src_t == JBLAS_DTYPE::F16) before_postop = DataConvertConfig::cvt_direct::F16_TO_FP32; + assert(src_t == BTLA_DTYPE::F32 || src_t == BTLA_DTYPE::BF16 || src_t == BTLA_DTYPE::F16); + if (src_t == BTLA_DTYPE::BF16) before_postop = DataConvertConfig::cvt_direct::BF16_TO_FP32; + if (src_t == BTLA_DTYPE::F16) before_postop = DataConvertConfig::cvt_direct::F16_TO_FP32; } // once a postop is present, the data type before store will be fp32. - if (injectors.size() != 0 || src_t == JBLAS_DTYPE::F32) { - if (dst_t == JBLAS_DTYPE::BF16) before_store = DataConvertConfig::cvt_direct::FP32_TO_BF16; - if (dst_t == JBLAS_DTYPE::F16) { + if (injectors.size() != 0 || src_t == BTLA_DTYPE::F32) { + if (dst_t == BTLA_DTYPE::BF16) before_store = DataConvertConfig::cvt_direct::FP32_TO_BF16; + if (dst_t == BTLA_DTYPE::F16) { if (!device::CpuDevice::getInstance()->AVX512_FP16()) assert(0); before_store = DataConvertConfig::cvt_direct::FP32_TO_F16; } - } else if (src_t == JBLAS_DTYPE::BF16 && dst_t == JBLAS_DTYPE::F32) { + } else if (src_t == BTLA_DTYPE::BF16 && dst_t == BTLA_DTYPE::F32) { before_store = DataConvertConfig::cvt_direct::BF16_TO_FP32; - } else if (src_t == JBLAS_DTYPE::F16 && dst_t == JBLAS_DTYPE::F32) { + } else if (src_t == BTLA_DTYPE::F16 && dst_t == BTLA_DTYPE::F32) { assert(device::CpuDevice::getInstance()->AVX512_FP16()); before_store = DataConvertConfig::cvt_direct::F16_TO_FP32; } @@ -328,7 +326,7 @@ struct DataConvertConfig { cvt_direct before_postop = cvt_direct::NO_CVT; cvt_direct before_store = cvt_direct::NO_CVT; - JBLAS_DTYPE input_dt, output_dt; + BTLA_DTYPE input_dt, output_dt; }; template @@ -339,7 +337,7 @@ struct MemcpyStoreParam { Xbyak::Opmask store_mask = Xbyak::util::k1; }; -class JitMemcpy2DAvx2 : protected jblas::xbyak::JitAvx2 { +class JitMemcpy2DAvx2 : protected xbyak::JitAvx2 { public: struct params { void *srcptr, *dstptr, *elt_const_v; @@ -349,29 +347,29 @@ class JitMemcpy2DAvx2 : protected jblas::xbyak::JitAvx2 { public: static int constexpr VBytes = 32; - JitMemcpy2DAvx2(int unroll_row, JBLAS_DTYPE src_t, JBLAS_DTYPE dst_t, + JitMemcpy2DAvx2(int unroll_row, BTLA_DTYPE src_t, BTLA_DTYPE dst_t, std::vector injectors = {}) { DataConvertConfig dt_cvt_cfg(src_t, dst_t, injectors); generate(unroll_row, injectors, dt_cvt_cfg); } template - static JBLAS_CODE forward(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, - void* elt_const_v = nullptr) { - static JitMemcpy2DAvx2 instance_withops(1, utils::jblas_dtype<_SRC_T>, utils::jblas_dtype<_DST_T>); + static BTLA_CODE forward(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, + void* elt_const_v = nullptr) { + static JitMemcpy2DAvx2 instance_withops(1, utils::bestla_dtype<_SRC_T>, utils::bestla_dtype<_DST_T>); for (int i = 0; i < row; i++) { auto param =
params{reinterpret_cast(const_cast<_SRC_T*>(srcptr)) + i * srcstep * sizeof(_SRC_T), reinterpret_cast(dstptr) + i * dststep * sizeof(_DST_T), elt_const_v, static_cast(col * sizeof(_SRC_T))}; instance_withops.mKernel(&param); } - return JblasSuccess; + return BTLA_CODE::Success; } - template - static JBLAS_CODE forward1(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, - void* elt_const_v = nullptr) { - static JitMemcpy2DAvx2 instance_withops(1, utils::jblas_dtype<_SRC_T>, utils::jblas_dtype<_DST_T>, + template + static BTLA_CODE forward1(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, + void* elt_const_v = nullptr) { + static JitMemcpy2DAvx2 instance_withops(1, utils::bestla_dtype<_SRC_T>, utils::bestla_dtype<_DST_T>, {kernel::jit_injector::eltwise_injector(Op)}); for (int i = 0; i < row; i++) { auto param = params{reinterpret_cast(const_cast<_SRC_T*>(srcptr)) + i * srcstep * sizeof(_SRC_T), @@ -379,7 +377,7 @@ class JitMemcpy2DAvx2 : protected jblas::xbyak::JitAvx2 { static_cast(col * sizeof(_SRC_T))}; instance_withops.mKernel(&param); } - return JblasSuccess; + return BTLA_CODE::Success; } protected: @@ -428,16 +426,16 @@ auto unpack_ymm_16bit_withfunc = [&](MemcpyStoreParam p, std::function)> func, - JBLAS_DTYPE BIT16_DT) { + BTLA_DTYPE BIT16_DT) { vmovups(ymm_tmps[0], p.vmm_v); Xbyak::Ymm ymm_v = Xbyak::Ymm(p.vmm_v.getIdx()); - if (BIT16_DT == JBLAS_DTYPE::BF16) { + if (BIT16_DT == BTLA_DTYPE::BF16) { vpmovzxwd(p.vmm_v, ymm_v); vpslld(p.vmm_v, p.vmm_v, 16); } func(p); vextractf128(Xbyak::Xmm(ymm_tmps[0].getIdx()), ymm_tmps[0], 1); - if (BIT16_DT == JBLAS_DTYPE::BF16) { + if (BIT16_DT == BTLA_DTYPE::BF16) { vpmovzxwd(ymm_tmps[0], Xbyak::Ymm(ymm_tmps[0].getIdx())); vpslld(ymm_tmps[0], ymm_tmps[0], 16); } @@ -451,7 +449,7 @@ class JitMemcpy2DAvx2 : protected jblas::xbyak::JitAvx2 { if (dt_cvt_cfg.before_store == DataConvertConfig::cvt_direct::NO_CVT) { store_ymm_v(p); } else if (dt_cvt_cfg.before_store == DataConvertConfig::cvt_direct::BF16_TO_FP32) { - unpack_ymm_16bit_withfunc(p, store_ymm_v, JBLAS_DTYPE::BF16); + unpack_ymm_16bit_withfunc(p, store_ymm_v, BTLA_DTYPE::BF16); } else if (dt_cvt_cfg.before_store == DataConvertConfig::cvt_direct::FP32_TO_BF16) { Xbyak::Xmm xmm_v = Xbyak::Xmm(p.vmm_v.getIdx()); Xbyak::Xmm xmm_tmp = Xbyak::Xmm(ymm_tmps[1].getIdx()); @@ -479,7 +477,7 @@ class JitMemcpy2DAvx2 : protected jblas::xbyak::JitAvx2 { if (dt_cvt_cfg.before_postop == DataConvertConfig::cvt_direct::NO_CVT) { apply_postop_and_store({ymm_v, store_addr}); } else if (dt_cvt_cfg.before_postop == DataConvertConfig::cvt_direct::BF16_TO_FP32) { - unpack_ymm_16bit_withfunc({ymm_v, store_addr}, apply_postop_and_store, JBLAS_DTYPE::BF16); + unpack_ymm_16bit_withfunc({ymm_v, store_addr}, apply_postop_and_store, BTLA_DTYPE::BF16); } else { assert(0); } @@ -549,7 +547,7 @@ class JitMemcpy2DAvx2 : protected jblas::xbyak::JitAvx2 { std::set used_ymm_idx; }; -class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { +class JitMemcpy2DAvx512f : protected xbyak::JitAvx512f { public: struct params { void *srcptr, *dstptr, *elt_const_v; @@ -559,16 +557,16 @@ class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { public: static int constexpr VBytes = 64; - JitMemcpy2DAvx512f(int unroll_row, JBLAS_DTYPE src_t, JBLAS_DTYPE dst_t, + JitMemcpy2DAvx512f(int unroll_row, BTLA_DTYPE src_t, BTLA_DTYPE dst_t, std::vector injectors = {}) { DataConvertConfig dt_cvt_cfg(src_t, dst_t,
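forward1 above caches one JIT kernel per eltwise op in a function-local static, so codegen cost is paid once per (op, src, dst) combination. A hypothetical call site, assuming the elided template parameter list begins with the BTLA_ELTWISEOP (step arguments are in elements, as in forward):

void copy_with_gelu_example() {
  static float src[8 * 64], dst[8 * 64];  // illustrative buffers
  // copy an 8x64 tile and apply GELU in the same pass
  bestla::kernel::jit::JitMemcpy2DAvx2::forward1<BTLA_ELTWISEOP::GELU>(
      src, dst, /*row=*/8, /*col=*/64, /*srcstep=*/64, /*dststep=*/64);
}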
injectors); generate(unroll_row, injectors, dt_cvt_cfg); } template - static JBLAS_CODE forward(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, - void* elt_const_v = nullptr) { - static JitMemcpy2DAvx512f instance_withops(1, utils::jblas_dtype<_SRC_T>, utils::jblas_dtype<_DST_T>); + static BTLA_CODE forward(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, + void* elt_const_v = nullptr) { + static JitMemcpy2DAvx512f instance_withops(1, utils::bestla_dtype<_SRC_T>, utils::bestla_dtype<_DST_T>); for (int i = 0; i < row; i++) { auto param = params{reinterpret_cast(const_cast<_SRC_T*>(srcptr)) + i * srcstep * sizeof(_SRC_T), @@ -576,13 +574,13 @@ class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { static_cast(col * sizeof(_SRC_T))}; instance_withops.mKernel(&param); } - return JblasSuccess; + return BTLA_CODE::Success; } - template - static JBLAS_CODE forward1(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, - void* elt_const_v = nullptr) { - static JitMemcpy2DAvx512f instance_withops(1, utils::jblas_dtype<_SRC_T>, utils::jblas_dtype<_DST_T>, + template + static BTLA_CODE forward1(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, + void* elt_const_v = nullptr) { + static JitMemcpy2DAvx512f instance_withops(1, utils::bestla_dtype<_SRC_T>, utils::bestla_dtype<_DST_T>, {kernel::jit_injector::eltwise_injector(Op)}); for (int i = 0; i < row; i++) { auto param = params{reinterpret_cast(const_cast<_SRC_T*>(srcptr)) + i * srcstep * sizeof(_SRC_T), @@ -590,7 +588,7 @@ class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { static_cast(col * sizeof(_SRC_T))}; instance_withops.mKernel(&param); } - return JblasSuccess; + return BTLA_CODE::Success; } protected: @@ -634,22 +632,22 @@ class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { auto unpack_zmm_16bit_withfunc = [&](MemcpyStoreParam p, std::function)> func, - JBLAS_DTYPE BIT16_DT) { + BTLA_DTYPE BIT16_DT) { vmovups(zmm_tmps[0], p.vmm_v); Xbyak::Ymm ymm_v = Xbyak::Ymm(p.vmm_v.getIdx()); - if (BIT16_DT == JBLAS_DTYPE::BF16) { + if (BIT16_DT == BTLA_DTYPE::BF16) { vpmovzxwd(p.vmm_v, ymm_v); vpslld(p.vmm_v, p.vmm_v, 16); } - if (BIT16_DT == JBLAS_DTYPE::F16) vcvtph2psx(p.vmm_v, ymm_v); + if (BIT16_DT == BTLA_DTYPE::F16) vcvtph2psx(p.vmm_v, ymm_v); p.store_mask = k3; func(p); vextractf32x8(Xbyak::Ymm(zmm_tmps[0].getIdx()), zmm_tmps[0], 1); - if (BIT16_DT == JBLAS_DTYPE::BF16) { + if (BIT16_DT == BTLA_DTYPE::BF16) { vpmovzxwd(zmm_tmps[0], Xbyak::Ymm(zmm_tmps[0].getIdx())); vpslld(zmm_tmps[0], zmm_tmps[0], 16); } - if (BIT16_DT == JBLAS_DTYPE::F16) vcvtph2psx(zmm_tmps[0], Xbyak::Ymm(zmm_tmps[0].getIdx())); + if (BIT16_DT == BTLA_DTYPE::F16) vcvtph2psx(zmm_tmps[0], Xbyak::Ymm(zmm_tmps[0].getIdx())); p.vmm_v = zmm_tmps[0]; p.store_addr = p.store_addr + VBytes; p.store_mask = k4; @@ -661,9 +659,9 @@ class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { if (dt_cvt_cfg.before_store == DataConvertConfig::cvt_direct::NO_CVT) { store_zmm_v(p); } else if (dt_cvt_cfg.before_store == DataConvertConfig::cvt_direct::BF16_TO_FP32) { - unpack_zmm_16bit_withfunc(p, store_zmm_v, JBLAS_DTYPE::BF16); + unpack_zmm_16bit_withfunc(p, store_zmm_v, BTLA_DTYPE::BF16); } else if (dt_cvt_cfg.before_store == DataConvertConfig::cvt_direct::F16_TO_FP32) { - unpack_zmm_16bit_withfunc(p, store_zmm_v, JBLAS_DTYPE::F16); + unpack_zmm_16bit_withfunc(p, store_zmm_v, BTLA_DTYPE::F16); } else if (dt_cvt_cfg.before_store ==
DataConvertConfig::cvt_direct::FP32_TO_BF16) { Xbyak::Ymm ymm_v = Xbyak::Ymm(p.vmm_v.getIdx()); if (device::CpuDevice::getInstance()->AVX512_BF16()) { @@ -705,9 +703,9 @@ class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { if (dt_cvt_cfg.before_postop == DataConvertConfig::cvt_direct::NO_CVT) { apply_postop_and_store({zmm_v, store_addr, tail}); } else if (dt_cvt_cfg.before_postop == DataConvertConfig::cvt_direct::BF16_TO_FP32) { - unpack_zmm_16bit_withfunc({zmm_v, store_addr, tail}, apply_postop_and_store, JBLAS_DTYPE::BF16); + unpack_zmm_16bit_withfunc({zmm_v, store_addr, tail}, apply_postop_and_store, BTLA_DTYPE::BF16); } else if (dt_cvt_cfg.before_postop == DataConvertConfig::cvt_direct::F16_TO_FP32) { - unpack_zmm_16bit_withfunc({zmm_v, store_addr, tail}, apply_postop_and_store, JBLAS_DTYPE::F16); + unpack_zmm_16bit_withfunc({zmm_v, store_addr, tail}, apply_postop_and_store, BTLA_DTYPE::F16); } else { assert(0); } @@ -738,7 +736,7 @@ class JitMemcpy2DAvx512f : protected jblas::xbyak::JitAvx512f { int vbytes = VBytes; // consider the case where input==bf16 but a postop is applied: before store will be fp32_to_bf16, yet the // mask still needs to be generated normally. - if (dt_cvt_cfg.input_dt == JBLAS_DTYPE::F32) { + if (dt_cvt_cfg.input_dt == BTLA_DTYPE::F32) { shr(reg_iter, 1); shr(reg_size, 1); vbytes /= 2; @@ -807,7 +805,7 @@ static inline Xbyak::Zmm unpack_4bit_2regs(Xbyak::CodeGenerator* jit, Xbyak::Ymm return dst; } -class DecompressS4S8_AVX512F : protected jblas::xbyak::JitAvx512f { +class DecompressS4S8_AVX512F : protected xbyak::JitAvx512f { public: struct params { void *srcptr, *dstptr; @@ -879,24 +877,24 @@ class DecompressS4S8_AVX512F : protected jblas::xbyak::JitAvx512f { mKernel = this->getCode(); } - static JBLAS_CODE forward(void* srcptr, void* dstptr, size_t size) { + static BTLA_CODE forward(void* srcptr, void* dstptr, size_t size) { static DecompressS4S8_AVX512F instance; auto param = params{srcptr, dstptr, size}; instance.mKernel(&param); - return JblasSuccess; + return BTLA_CODE::Success; } private: func_t mKernel = nullptr; }; -static inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst) { +static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, + int ld_dst) { if (col != ld_src) { // memory is not continuous - return JblasNotSupport; + return BTLA_CODE::NotSupport; } DecompressS4S8_AVX512F::forward(srcptr, dstptr, (size_t)row * col); - return JblasSuccess; + return BTLA_CODE::Success; } // src: row x col => dst: ⌈col/n_tile⌉ x ⌈row/row_pack⌉ x n_tile x row_pack (zero-padded) @@ -920,11 +918,11 @@ class PaddingInterleaveCvt : protected xbyak::JitAvx512f { 12, 28, 13, 29, 14, 30, 15, 31, // }; - PaddingInterleaveCvt(int n_tile, JBLAS_DTYPE dst_t) : PaddingInterleaveCvt(n_tile, dst_t, dst_t) {} - PaddingInterleaveCvt(int n_tile, JBLAS_DTYPE dst_t, JBLAS_DTYPE src_t, int row_pack = 0) : xbyak::JitAvx512f() { inLocalLabel(); // use local labels for multiple instances - const auto src_bytes = static_cast(utils::jblas_dtype_size(src_t)); - const auto dst_bytes = static_cast(utils::jblas_dtype_size(dst_t)); + PaddingInterleaveCvt(int n_tile, BTLA_DTYPE dst_t) : PaddingInterleaveCvt(n_tile, dst_t, dst_t) {} + PaddingInterleaveCvt(int n_tile, BTLA_DTYPE dst_t, BTLA_DTYPE src_t, int row_pack = 0) : xbyak::JitAvx512f() { inLocalLabel(); // use local labels for multiple instances + const auto src_bytes = static_cast(utils::bestla_dtype_size(src_t)); + const auto dst_bytes = static_cast(utils::bestla_dtype_size(dst_t)); if (row_pack == 0)
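decompress_s4_s8 above only accepts the contiguous case (col == ld_src) and otherwise reports BTLA_CODE::NotSupport. A scalar reference of the byte-level unpacking it JIT-compiles, following the S4_CLIP convention used elsewhere in this file (each nibble lands in the high four bits, so values carry an implicit x16 that is folded into the scales):

#include <cstddef>
#include <cstdint>
static inline void decompress_s4_s8_ref(const uint8_t* src, int8_t* dst, size_t n_pairs) {
  for (size_t i = 0; i < n_pairs; i++) {
    dst[2 * i + 0] = static_cast<int8_t>(src[i] << 4);    // low nibble -> high bits
    dst[2 * i + 1] = static_cast<int8_t>(src[i] & 0xF0);  // high nibble already in place
  }
}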
row_pack = 4 / dst_bytes; // default value const auto ne_zmm = 64 / std::max(src_bytes, dst_bytes); const auto src_bytes_vmm = ne_zmm * src_bytes; @@ -1005,7 +1003,7 @@ class PaddingInterleaveCvt : protected xbyak::JitAvx512f { vmovdqu32(reg_srcs_ii | mask_rd | T_z, ptr[reg_tmp1 + ii * reg_srcstride + jj * src_bytes]); } } - if (src_t == JBLAS_DTYPE::F32 && dst_t == JBLAS_DTYPE::BF16) { + if (src_t == BTLA_DTYPE::F32 && dst_t == BTLA_DTYPE::BF16) { vcvtne2ps2bf16(reg_tmps[0], reg_srcs[1], reg_srcs[0]); vpermt2w(reg_tmps[0], vreg_idx0, reg_tmps[0]); vmovups(ptr[reg_tmp + jj * row_pack * dst_bytes], reg_tmps[0]); @@ -1049,7 +1047,7 @@ class PaddingInterleaveCvt : protected xbyak::JitAvx512f { } else { assert(false); } - if (src_t == JBLAS_DTYPE::F32 && dst_t == JBLAS_DTYPE::BF16) { + if (src_t == BTLA_DTYPE::F32 && dst_t == BTLA_DTYPE::BF16) { vcvtne2ps2bf16(reg_tmps[0], reg_srcs[1], reg_srcs[0]); vpermt2w(reg_tmps[0], vreg_idx0, reg_tmps[0]); vmovups(ptr[reg_tmp + jj * row_pack * dst_bytes], reg_tmps[0]); @@ -1078,7 +1076,7 @@ class PaddingInterleaveCvt : protected xbyak::JitAvx512f { const auto src_stride = static_cast(sizeof(T_SRC)) * src_step; const auto dst_stride = static_cast(sizeof(T_DST)) * dst_step; params param = {src, dst, row, col, src_stride, dst_stride}; - static const PaddingInterleaveCvt kern(NTile, utils::jblas_dtype, utils::jblas_dtype, RowPack); + static const PaddingInterleaveCvt kern(NTile, utils::bestla_dtype, utils::bestla_dtype, RowPack); kern(&param); // extra row and col pad @@ -1123,11 +1121,11 @@ class PaddingTransInterleaveCvt : protected xbyak::JitAvx512f { const int trans_cell; // transpose matrices of size trans_cell x trans_cell (in terms of #elements or #packs) private: - PaddingTransInterleaveCvt(int m_tile, JBLAS_DTYPE dst_t) : PaddingTransInterleaveCvt(m_tile, dst_t, dst_t) {} - PaddingTransInterleaveCvt(int m_tile, JBLAS_DTYPE dst_t, JBLAS_DTYPE src_t, int col_pack = 0) - : xbyak::JitAvx512f(), trans_cell(64 / col_pack / int(utils::jblas_dtype_size(dst_t))) { - const auto src_bytes = static_cast(utils::jblas_dtype_size(src_t)); - const auto dst_bytes = static_cast(utils::jblas_dtype_size(dst_t)); + PaddingTransInterleaveCvt(int m_tile, BTLA_DTYPE dst_t) : PaddingTransInterleaveCvt(m_tile, dst_t, dst_t) {} + PaddingTransInterleaveCvt(int m_tile, BTLA_DTYPE dst_t, BTLA_DTYPE src_t, int col_pack = 0) + : xbyak::JitAvx512f(), trans_cell(64 / col_pack / int(utils::bestla_dtype_size(dst_t))) { + const auto src_bytes = static_cast(utils::bestla_dtype_size(src_t)); + const auto dst_bytes = static_cast(utils::bestla_dtype_size(dst_t)); if (col_pack == 0) col_pack = 4 / dst_bytes; // default value // const auto src_bytes_vmm = ne_zmm * src_bytes; // const auto dst_bytes_vmm = ne_zmm * dst_bytes; @@ -1192,7 +1190,7 @@ class PaddingTransInterleaveCvt : protected xbyak::JitAvx512f { L(".colloop"); generate_Nbitsmask(mask_rd, reg_itercol, ptr[reg_colsize], reg_tmp2, reg_tmp3, 64 / dst_bytes); - if (src_t == JBLAS_DTYPE::F32 && dst_t == JBLAS_DTYPE::BF16) { + if (src_t == BTLA_DTYPE::F32 && dst_t == BTLA_DTYPE::BF16) { kshiftrq(mask_rd2, mask_rd, 16); assert(trans_cell == 16); for (int ii = 0; ii < trans_cell; ++ii) { @@ -1235,7 +1233,7 @@ class PaddingTransInterleaveCvt : protected xbyak::JitAvx512f { auto& tailcolloop = l_tail_case[m_tail]; L(tailcolloop); generate_Nbitsmask(mask_rd, reg_itercol, ptr[reg_colsize], reg_tmp2, reg_tmp3, 64 / dst_bytes); - if (src_t == JBLAS_DTYPE::F32 && dst_t ==
BTLA_DTYPE::BF16) { kshiftrq(mask_rd2, mask_rd, 16); assert(trans_cell == 16); for (int ii = 0; ii < trans_cell; ++ii) { @@ -1277,7 +1275,7 @@ class PaddingTransInterleaveCvt : protected xbyak::JitAvx512f { int dst_step) { assert(utils::padto(row, MTile) <= row_pad && row_pad % MTile == 0); assert(utils::padto(col, ColPack) <= col_pad && col_pad % ColPack == 0); - static const PaddingTransInterleaveCvt kern(MTile, utils::jblas_dtype, utils::jblas_dtype, ColPack); + static const PaddingTransInterleaveCvt kern(MTile, utils::bestla_dtype, utils::bestla_dtype, ColPack); // 0-padded guarantee by jit kern const auto kern_row_pad = utils::padto(row, kern.trans_cell), kern_col_pad = utils::padto(col, kern.trans_cell * ColPack); @@ -1482,4 +1480,4 @@ class CScaleInterleavedBF16FP16 : protected xbyak::JitAvx512_fp16 { } // namespace jit } // namespace kernel -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/kernel_jit_injector.h b/bestla/bestla/kernel_jit_injector.h similarity index 96% rename from bestla/jblas/kernel_jit_injector.h rename to bestla/bestla/kernel_jit_injector.h index d3e49eecd..3db66fa4f 100644 --- a/bestla/jblas/kernel_jit_injector.h +++ b/bestla/bestla/kernel_jit_injector.h @@ -22,11 +22,11 @@ #include #include -#include "jit_blas.h" -#include "jit_blas_utils.h" +#include "bestla.h" +#include "bestla_utils.h" #include "xbyak/xbyak.h" -namespace jblas { +namespace bestla { namespace kernel { namespace jit_injector { using Zmm = Xbyak::Zmm; @@ -34,7 +34,7 @@ using Ymm = Xbyak::Ymm; using Xmm = Xbyak::Xmm; class eltwise_injector { public: - eltwise_injector(JBLAS_ELTWISEOP eltwiseop) : elt_op(eltwiseop) { reigster_table_entries(); } + eltwise_injector(BTLA_ELTWISEOP eltwiseop) : elt_op(eltwiseop) { reigster_table_entries(); } virtual ~eltwise_injector() {} void assign_resources(Xbyak::CodeGenerator* ptr, const std::set& used_zmm_idx, const Xbyak::Reg64& table_reg, @@ -65,25 +65,25 @@ class eltwise_injector { void vector_compute(const Xbyak::Zmm& zmm_src, int const_p_offset = 0) { load_table_addr(); switch (elt_op) { - case EXP: + case BTLA_ELTWISEOP::EXP: exp_compute_vector_fwd(zmm_src); break; - case TANH: + case BTLA_ELTWISEOP::TANH: tanh_compute_vector_fwd(zmm_src); break; - case GELU: + case BTLA_ELTWISEOP::GELU: gelu_compute_vector_fwd(zmm_src); break; - case RELU: + case BTLA_ELTWISEOP::RELU: relu_compute_vector_fwd(zmm_src, const_p_offset); break; - case LINEAR: + case BTLA_ELTWISEOP::LINEAR: linear_compute_vector_fwd(zmm_src, const_p_offset); break; - case LOW_PRECISION_EXP: + case BTLA_ELTWISEOP::LOW_PRECISION_EXP: low_precision_exp_compute_vector_fwd(zmm_src); break; - case SWISH: + case BTLA_ELTWISEOP::SWISH: swish_compute_vector_fwd(zmm_src, const_p_offset); break; default: @@ -94,19 +94,19 @@ class eltwise_injector { void vector_compute(const Xbyak::Ymm& ymm_src, int const_p_offset = 0) { load_table_addr(); switch (elt_op) { - case EXP: + case BTLA_ELTWISEOP::EXP: exp_compute_vector_fwd(ymm_src); break; - case TANH: + case BTLA_ELTWISEOP::TANH: tanh_compute_vector_fwd(ymm_src); break; - case GELU: + case BTLA_ELTWISEOP::GELU: gelu_compute_vector_fwd(ymm_src); break; - case LOW_PRECISION_EXP: + case BTLA_ELTWISEOP::LOW_PRECISION_EXP: low_precision_exp_compute_vector_fwd(ymm_src); break; - case SWISH: + case BTLA_ELTWISEOP::SWISH: swish_compute_vector_fwd(ymm_src, const_p_offset); break; default: @@ -136,9 +136,9 @@ class eltwise_injector { static constexpr std::array exp_approx_f32_coeff{0.35815147f, 0.96963238f, 1.f}; static const table_t 
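The LOW_PRECISION_EXP branch above evaluates exp(x) as a power of two times a degree-2 polynomial whose coefficients are the exp_approx_f32_coeff constants registered into the kernel's constant table just below. A scalar sketch of one way to compose it (the JIT may pick the integer exponent differently):

#include <cmath>
static inline float low_precision_exp_ref(float x) {
  const float log2e = 1.442695041f, ln2 = 0.693147181f;
  const float c0 = 0.35815147f, c1 = 0.96963238f, c2 = 1.f;  // exp_approx_f32_coeff
  int n = static_cast<int>(std::ceil(x * log2e));  // pick n so f = x - n*ln2 lies in (-ln2, 0]
  float f = x - static_cast<float>(n) * ln2;
  float poly = (c0 * f + c1) * f + c2;             // ~= e^f on that interval
  return std::ldexp(poly, n);                      // scale by 2^n
}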
low_precision_exp_consts{ - {low_precision_exp_const_v0, {jblas::utils::bit_cast(exp_approx_f32_coeff[0]), true}}, - {low_precision_exp_const_v1, {jblas::utils::bit_cast(exp_approx_f32_coeff[1]), true}}, - {low_precision_exp_const_v2, {jblas::utils::bit_cast(exp_approx_f32_coeff[2]), true}}, + {low_precision_exp_const_v0, {utils::bit_cast(exp_approx_f32_coeff[0]), true}}, + {low_precision_exp_const_v1, {utils::bit_cast(exp_approx_f32_coeff[1]), true}}, + {low_precision_exp_const_v2, {utils::bit_cast(exp_approx_f32_coeff[2]), true}}, }; static const table_t exp_consts{{exp_log2ef, {0x3fb8aa3b, true}}, @@ -425,12 +425,12 @@ class eltwise_injector { }; struct need_t { - explicit need_t(JBLAS_ELTWISEOP& op) { - if (op == EXP) exp_ = true; - if (op == TANH) tanh_ = true; - if (op == GELU) gelu_ = true; - if (op == SWISH) swish_ = true; - if (op == LOW_PRECISION_EXP) low_precision_exp_ = true; + explicit need_t(BTLA_ELTWISEOP& op) { + if (op == BTLA_ELTWISEOP::EXP) exp_ = true; + if (op == BTLA_ELTWISEOP::TANH) tanh_ = true; + if (op == BTLA_ELTWISEOP::GELU) gelu_ = true; + if (op == BTLA_ELTWISEOP::SWISH) swish_ = true; + if (op == BTLA_ELTWISEOP::LOW_PRECISION_EXP) low_precision_exp_ = true; } bool bf16_ = false; bool exp_ = false; @@ -835,7 +835,7 @@ class eltwise_injector { } private: - JBLAS_ELTWISEOP elt_op; + BTLA_ELTWISEOP elt_op; Xbyak::CodeGenerator* h = nullptr; /*labels*/ @@ -927,4 +927,4 @@ class eltwise_injector { }; } // namespace jit_injector } // namespace kernel -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/kernel_ref.h b/bestla/bestla/kernel_ref.h similarity index 74% rename from bestla/jblas/kernel_ref.h rename to bestla/bestla/kernel_ref.h index 1e0ddccda..f55f4fa45 100644 --- a/bestla/jblas/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -17,28 +17,28 @@ #include #include #include -#include "jblas/jit_blas.h" -#include "jit_blas_utils.h" +#include "bestla.h" +#include "bestla_utils.h" -namespace jblas { +namespace bestla { namespace kernel { namespace ref { template -static inline JBLAS_CODE shuffle_activation(T* src, T* dst, int shuffle_m, int shuffle_k, int m_offset, int k_offset, - int* indices, int src_stride, int dst_stride) { +static inline BTLA_CODE shuffle_activation(T* src, T* dst, int shuffle_m, int shuffle_k, int m_offset, int k_offset, + int* indices, int src_stride, int dst_stride) { T* cur_src = src + m_offset * src_stride; for (int i = 0; i < shuffle_m; i++) { for (int j = 0; j < shuffle_k; j++) { dst[i * dst_stride + j] = cur_src[i * src_stride + indices[k_offset + j]]; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE padding_interleave(const T_SRC* src_ptr, T_DST* dst_ptr, int row, int col, int rowpad, - int colpad, int src_step, int dst_step, int NTile, int RowPack) { +static inline BTLA_CODE padding_interleave(const T_SRC* src_ptr, T_DST* dst_ptr, int row, int col, int rowpad, + int colpad, int src_step, int dst_step, int NTile, int RowPack) { const T_DST dst_0(0); static_assert(sizeof(T_SRC) == sizeof(T_DST), "SRC & DST size should be the same"); for (int i = 0; i < rowpad; i += RowPack) { @@ -53,14 +53,14 @@ static inline JBLAS_CODE padding_interleave(const T_SRC* src_ptr, T_DST* dst_ptr } } } - return JblasSuccess; + return BTLA_CODE::Success; } // revert padding and interleave // row*col <= colpad/NTile*rowpad*NTile template -static inline JBLAS_CODE revert_padding_interleave(const T_SRC* src_ptr, T_DST* dst_ptr, int row, int col, int rowpad, - int colpad, int src_step, int 
dst_step, int NTile, int RowPack) { +static inline BTLA_CODE revert_padding_interleave(const T_SRC* src_ptr, T_DST* dst_ptr, int row, int col, int rowpad, + int colpad, int src_step, int dst_step, int NTile, int RowPack) { static_assert(sizeof(T_SRC) == sizeof(T_DST), "SRC & DST size should be the same"); for (int i = 0; i < rowpad; i += RowPack) { for (int j = 0; j < colpad; j += NTile) { @@ -76,13 +76,13 @@ static inline JBLAS_CODE revert_padding_interleave(const T_SRC* src_ptr, T_DST* } } } - return JblasSuccess; + return BTLA_CODE::Success; } // M x N ===> M/MTile x N/colPack x MTile x colPack (leading dim stride = MTile * dst_stride) template -static inline JBLAS_CODE padding_trans_interleave(const T_SRC* src, T_DST* dst, int row, int col, int rowpad, - int colpad, int src_step, int dst_step, int MTile, int ColPack) { +static inline BTLA_CODE padding_trans_interleave(const T_SRC* src, T_DST* dst, int row, int col, int rowpad, int colpad, + int src_step, int dst_step, int MTile, int ColPack) { // Note: rows/cols and i/j are in terms of src static_assert(sizeof(T_SRC) == sizeof(T_DST), "SRC & DST size should be the same"); const T_DST dst_0(0); @@ -98,12 +98,12 @@ static inline JBLAS_CODE padding_trans_interleave(const T_SRC* src, T_DST* dst, } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE dt_cvt_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, int srcstride, - int dststride, bool zeropadding) { +static inline BTLA_CODE dt_cvt_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, int srcstride, + int dststride, bool zeropadding) { for (int i = 0; i < row; i++) { int j = 0; for (; j < col; j++) { @@ -117,72 +117,72 @@ static inline JBLAS_CODE dt_cvt_2D_write_back(const void* raw_srcptr, void* raw_ } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE dequan_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales) { +static inline BTLA_CODE dequan_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { dstptr[i * ld_dst + j] = static_cast(srcptr[i * ld_src + j]) * scales[j]; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE dequan_s8_bf16(int8_t* srcptr, uint16_t* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales) { +static inline BTLA_CODE dequan_s8_bf16(int8_t* srcptr, uint16_t* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { dstptr[i * ld_dst + j] = - jblas::utils::cast(static_cast(srcptr[i * ld_src + j]) * scales[j]).x; + utils::cast(static_cast(srcptr[i * ld_src + j]) * scales[j]).x; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE transpose2d(const _T* srcptr, _T* dstptr, int row, int col, int ld_src, int ld_dst) { +static inline BTLA_CODE transpose2d(const _T* srcptr, _T* dstptr, int row, int col, int ld_src, int ld_dst) { for (int i = 0; i < col; i++) { for (size_t j = 0; j < row; j++) { dstptr[j + i * ld_dst] = srcptr[j * ld_src + i]; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE compress_s8_s4(const int8_t* srcptr, jblas::utils::int4x2* dstptr, int row, int col, - int ld_src, int ld_dst) { +static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstptr, int row, 
int col, int ld_src, + int ld_dst) { for (int j = 0; j < row; j++) { for (int ii = 0; ii < col; ii += 2) { - jblas::utils::int4x2 tmp; - tmp.x = jblas::utils::int4x2::convert(srcptr[j * ld_src + ii + 0]); - tmp.y = jblas::utils::int4x2::convert(srcptr[j * ld_src + ii + 1]); + utils::int4x2 tmp; + tmp.x = utils::int4x2::convert(srcptr[j * ld_src + ii + 0]); + tmp.y = utils::int4x2::convert(srcptr[j * ld_src + ii + 1]); dstptr[j * ld_dst / 2 + ii / 2] = tmp; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE compress_f4(const int8_t* srcptr, jblas::utils::f4x2* dstptr, int row, int col, int ld_src, - int ld_dst) { +static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, int row, int col, int ld_src, + int ld_dst) { for (int j = 0; j < row; j++) { for (int ii = 0; ii < col; ii += 2) { - jblas::utils::f4x2 tmp; + utils::f4x2 tmp; tmp.x = srcptr[j * ld_src + ii + 0]; tmp.y = srcptr[j * ld_src + ii + 1]; dstptr[j * ld_dst / 2 + ii / 2] = tmp; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE decompress_s4_f32(jblas::utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales) { +static inline BTLA_CODE decompress_s4_f32(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src, + int ld_dst, float* scales) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 2) { auto tmp = srcptr[i * ld_src / 2 + j / 2]; @@ -191,15 +191,15 @@ static inline JBLAS_CODE decompress_s4_f32(jblas::utils::int4x2* srcptr, float* dstptr[i * ld_dst + j + 1] = static_cast(static_cast(tmp.y) << 4) * scales[noffset + 1]; } } - return JblasSuccess; + return BTLA_CODE::Success; } -template +template inline int8_t get_s8(int8_t v) { switch (S4_T) { - case JBLAS_DTYPE::S4_CLIP: + case BTLA_DTYPE::S4_CLIP: return v << 4; - case JBLAS_DTYPE::S4_FULLRANGE: + case BTLA_DTYPE::S4_FULLRANGE: v &= 0x0f; return v - 8; default: @@ -209,7 +209,7 @@ inline int8_t get_s8(int8_t v) { return static_cast(0); } -template +template inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { auto src32 = *reinterpret_cast(srcptr); auto tmp = static_cast(src32 & 0xf) << 4; @@ -251,7 +251,7 @@ inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) { } template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { +inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { convert_s4_s8_8_lowbits(dstptr, srcptr); for (size_t i = 0; i < 8; i++) { dstptr[i] -= 8; @@ -259,22 +259,22 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* s } template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { +inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { convert_s4_s8_8_lowbits(dstptr, srcptr); } template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { +inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { convert_s4_s8_8_lowbits(dstptr, srcptr); } template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { +inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { convert_s4_s8_8_lowbits(dstptr, srcptr); } -template -inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst) { +template +inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 2) { auto tmp = srcptr[i * ld_src / 2 + j / 2]; @@ -282,16 +282,16 @@ inline 
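compress_s8_s4 and get_s8 above fix the round-trip convention for packed nibbles: the quantizer clamps to [-8, 7], the packer stores two values per byte, and S4_CLIP decoding shifts each nibble into the high bits. A byte-level sketch of that round trip:

#include <cstdint>
static inline uint8_t pack_s4x2_ref(int8_t lo, int8_t hi) {  // both in [-8, 7]
  return static_cast<uint8_t>(((hi & 0x0F) << 4) | (lo & 0x0F));
}
static inline void unpack_s4x2_ref(uint8_t v, int8_t& lo, int8_t& hi) {
  lo = static_cast<int8_t>(v << 4);    // S4_CLIP: decodes to value * 16
  hi = static_cast<int8_t>(v & 0xF0);  // likewise value * 16
}
// e.g. pack_s4x2_ref(-3, 5) == 0x5D; unpacking yields -48 and 80, i.e. -3*16 and 5*16.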
JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int ro dstptr[i * ld_dst + j + 1] = get_s8(tmp.y); } } - return JblasSuccess; + return BTLA_CODE::Success; } -inline float f8_to_fp32(utils::f8 v, JBLAS_DTYPE f8_t) { +inline float f8_to_fp32(utils::f8 v, BTLA_DTYPE f8_t) { uint32_t sign_revert = v.x; uint32_t e_revert = v.x; uint32_t mantissa_revert = v.x; sign_revert <<= 24; sign_revert &= 0x80000000; - auto ebits = utils::jblas_dtype_get_f8_ebits(f8_t); + auto ebits = utils::bestla_dtype_get_f8_ebits(f8_t); auto mantissabit = 7 - ebits; e_revert &= 0x7f; e_revert >>= mantissabit; @@ -305,8 +305,8 @@ inline float f8_to_fp32(utils::f8 v, JBLAS_DTYPE f8_t) { } template -inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int k_offset, int kblock, int NPad, JBLAS_DTYPE src_f8_type) { +inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _S_T* scales, int k_offset, int kblock, int NPad, BTLA_DTYPE src_f8_type) { for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; @@ -324,12 +324,12 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int dstptr[i * ld_dst + j] = fp_v * scale; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -inline JBLAS_CODE decompress_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { +inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _S_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; @@ -339,13 +339,13 @@ inline JBLAS_CODE decompress_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int ro dstptr[i * ld_dst + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } } - return JblasSuccess; + return BTLA_CODE::Success; } -template -inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, _S_T* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad, int8_t* tmp, size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _S_T* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, int8_t* tmp, size_t tmpsize) { for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; @@ -370,12 +370,12 @@ inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(dst1); } } - return JblasSuccess; + return BTLA_CODE::Success; } -template -inline JBLAS_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 2) { auto tmp = srcptr[i * ld_src / 2 + j / 2]; @@ -383,18 +383,18 @@ inline JBLAS_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstpt dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(static_cast(get_s8(tmp.y))); } } - 
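The decompress_kblock_* routines above all share one addressing scheme: scales (and zero points) are laid out as one NPad-wide row per k-block, and _PACK_ROW adjacent packed rows share an entry. A sketch of the lookup for source row i:

static inline const float* kblock_scale_row_ref(const float* scales, int k_offset, int i,
                                                int kblock, int NPad) {
  int kpos = (k_offset + i) / kblock;  // which k-block row i belongs to
  return scales + kpos * NPad;         // element j then reads sptr[j / PACK_ROW]
}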
return JblasSuccess; + return BTLA_CODE::Success; } template -inline JBLAS_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { +inline BTLA_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 1) { auto tmp = srcptr[i * ld_src + j]; dstptr[i * ld_dst + j] = static_cast(static_cast(tmp)); } } - return JblasSuccess; + return BTLA_CODE::Success; } inline float fp4_bnb_unpack(uint8_t val) { @@ -558,7 +558,7 @@ inline float nf4_unpack(int8_t val) { inline float nf4_dequantize(int8_t val, float absmax) { return nf4_unpack(val) * absmax; } -// Note: In the BNB Nf4 definition, 0 has a non-zero value after dequantization, but Jblas uses 0 for padding, which +// Note: In the BNB Nf4 definition, 0 has a non-zero value after dequantization, but BTLA uses 0 for padding, which // leads to calculation errors. We ultimately choose to swap the binary bits of -1 and 0 in Nf4 to avoid this // conflict. inline int8_t nf4_quantize(float x) { @@ -603,16 +603,16 @@ inline int8_t nf4_quantize(float x) { return 0b0111; } -template +template inline float f4_unpack(int8_t v) { - static_assert(F4_T == JBLAS_DTYPE::F4_BNB || F4_T == JBLAS_DTYPE::F4_NF4 || F4_T == JBLAS_DTYPE::F4_E2M1, + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); switch (F4_T) { - case JBLAS_DTYPE::F4_BNB: + case BTLA_DTYPE::F4_BNB: return fp4_bnb_unpack(v); - case JBLAS_DTYPE::F4_NF4: + case BTLA_DTYPE::F4_NF4: return nf4_unpack(v); - case JBLAS_DTYPE::F4_E2M1: + case BTLA_DTYPE::F4_E2M1: return fp4_e2m1_unpack(v); default: break; @@ -620,23 +620,23 @@ inline float f4_unpack(int8_t v) { return std::numeric_limits::quiet_NaN(); } -template +template inline float f4_dequantize(int8_t v, float scale) { - static_assert(F4_T == JBLAS_DTYPE::F4_BNB || F4_T == JBLAS_DTYPE::F4_NF4 || F4_T == JBLAS_DTYPE::F4_E2M1, + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); return f4_unpack(v) * scale; } -template +template inline int8_t f4_quantize(float x) { - static_assert(F4_T == JBLAS_DTYPE::F4_BNB || F4_T == JBLAS_DTYPE::F4_NF4 || F4_T == JBLAS_DTYPE::F4_E2M1, + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); switch (F4_T) { - case JBLAS_DTYPE::F4_BNB: + case BTLA_DTYPE::F4_BNB: return fp4_bnb_quantize(x); - case JBLAS_DTYPE::F4_NF4: + case BTLA_DTYPE::F4_NF4: return nf4_quantize(x); - case JBLAS_DTYPE::F4_E2M1: + case BTLA_DTYPE::F4_E2M1: return fp4_e2m1_quantize(x); default: break; @@ -644,10 +644,10 @@ inline int8_t f4_quantize(float x) { return static_cast(0); } -template -inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int k_offset, int kblock, int NPad, int8_t* tmp, - size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _S_T* scales, int k_offset, int kblock, int NPad, int8_t* tmp, + size_t tmpsize) { for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; @@ -665,12 +665,12 @@ inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, i dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(dst1); } } - return 
JblasSuccess; + return BTLA_CODE::Success; } -template -inline JBLAS_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +template +inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 2) { auto tmp = srcptr[i * ld_src / 2 + j / 2]; @@ -678,36 +678,35 @@ inline JBLAS_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, _DST_T* d dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(f4_unpack(tmp.y)); } } - return JblasSuccess; + return BTLA_CODE::Success; } template -inline JBLAS_CODE decompress_kblock_f8_fp_noscale(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, JBLAS_DTYPE src_f8_t) { +inline BTLA_CODE decompress_kblock_f8_fp_noscale(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, BTLA_DTYPE src_f8_t) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { dstptr[i * ld_dst + j] = f8_to_fp32(srcptr[i * ld_src + j], src_f8_t); } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE memcpy2d_dw2highw(const void* srcptr, void* dstptr, int row, int col, int srcstride, - int dststride) { +static inline BTLA_CODE memcpy2d_dw2highw(const void* srcptr, void* dstptr, int row, int col, int srcstride, + int dststride) { auto bsrcptr = (char*)srcptr; auto bdstptr = (char*)dstptr; for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { - std::memcpy(bdstptr + i * dststride + j * sizeof(jblas::utils::bf16), - bsrcptr + i * srcstride + j * sizeof(float) + 2, sizeof(jblas::utils::bf16)); + std::memcpy(bdstptr + i * dststride + j * sizeof(utils::bf16), bsrcptr + i * srcstride + j * sizeof(float) + 2, + sizeof(utils::bf16)); } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE memcpy2d(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstride, - int dststride) { +static inline BTLA_CODE memcpy2d(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstride, int dststride) { auto bsrcptr = (const char*)srcptr; auto bdstptr = (char*)dstptr; for (int i = 0; i < row; i++) { @@ -727,23 +726,23 @@ static inline JBLAS_CODE memcpy2d(const _SRC_T* srcptr, _DST_T* dstptr, int row, assert(0); } } - return JblasSuccess; + return BTLA_CODE::Success; } -static float postop(float x, JBLAS_ELTWISEOP op, void* const_elt_v) { - if (op == GELU) { +static float postop(float x, BTLA_ELTWISEOP op, void* const_elt_v) { + if (op == BTLA_ELTWISEOP::GELU) { return 0.5f * x * (1.f + tanhf(0.7978845834732056f * (x + 0.044714998453855515f * x * x * x))); } - if (op == SWISH) { + if (op == BTLA_ELTWISEOP::SWISH) { return x / (1 + exp(-x)); } assert(0); return std::numeric_limits::infinity(); } -template -static inline JBLAS_CODE memcpy2d_withop(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstride, - int dststride, void* const_elt_v) { +template +static inline BTLA_CODE memcpy2d_withop(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstride, + int dststride, void* const_elt_v) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += sizeof(_SRC_T)) { float v = srcptr[(i * srcstride + j) / sizeof(_SRC_T)]; @@ -751,11 +750,11 @@ static inline JBLAS_CODE memcpy2d_withop(const _SRC_T* srcptr, _DST_T* dstptr, i dstptr[(i * srcstride + j) / sizeof(_DST_T)] = v; } } - return 
JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE get2d_e8m0_scale(const void* srcptr, void* dstptr, int row, int col, int srcstride, - int dststride) { +static inline BTLA_CODE get2d_e8m0_scale(const void* srcptr, void* dstptr, int row, int col, int srcstride, + int dststride) { auto f8_v = (const utils::f8*)srcptr; auto f32_v = (float*)dstptr; auto f8_stride = srcstride / sizeof(utils::f8); @@ -766,12 +765,12 @@ static inline JBLAS_CODE get2d_e8m0_scale(const void* srcptr, void* dstptr, int f32_v[i * f32_stride + j] = std::pow(2, f8_v[i * f8_stride + j].x); } } - return JblasSuccess; + return BTLA_CODE::Success; } -template -inline JBLAS_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int8_t* zero_points, int blocksize) { +template +inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, + int ld_dst, float* scales, int8_t* zero_points, int blocksize) { int raw_blocksize = blocksize; for (int i = 0; i < col; i++) { int align_row_loop = row / blocksize * blocksize; @@ -848,15 +847,15 @@ inline JBLAS_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* ds auto dispatch_calc = [&](int blocksize) { switch (S4_T) { - case JBLAS_DTYPE::S8: - case JBLAS_DTYPE::S4_CLIP: + case BTLA_DTYPE::S8: + case BTLA_DTYPE::S4_CLIP: if (zero_points == nullptr) { s8_calc_store_scale_and_quantv_sym(blocksize); } else { s8_calc_store_scale_and_quantv_asym(blocksize); } break; - case JBLAS_DTYPE::S4_FULLRANGE: + case BTLA_DTYPE::S4_FULLRANGE: if (zero_points == nullptr) { s4_fullrange_calc_store_scale_and_quantv_sym(blocksize); } else { @@ -872,18 +871,18 @@ inline JBLAS_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* ds for (; j < align_row_loop; j += blocksize) dispatch_calc(blocksize); if (j < row) dispatch_calc(row - align_row_loop); } - return JblasSuccess; + return BTLA_CODE::Success; } -template -int8_t f8_mx_quantize(float v, float scale, JBLAS_DTYPE scale_dtype) { - if (scale_dtype == JBLAS_DTYPE::F8_E8M0) { +template +int8_t f8_mx_quantize(float v, float scale, BTLA_DTYPE scale_dtype) { + if (scale_dtype == BTLA_DTYPE::F8_E8M0) { v /= std::pow(2, scale); } else { v /= scale; } - auto ebits = utils::jblas_dtype_get_f8_ebits(F8_T); - auto quant_mantissa = utils::jblas_dtype_get_f8_quant_mbits(F8_T); + auto ebits = utils::bestla_dtype_get_f8_ebits(F8_T); + auto quant_mantissa = utils::bestla_dtype_get_f8_quant_mbits(F8_T); auto store_mantissa = 7 - ebits; auto private_exp = std::floor(std::log2(std::abs(v == 0 ? 
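quantize_f32_sign_int_rowblock above walks each column in blocks of blocksize rows and dispatches to per-dtype quantizer lambdas whose bodies are elided in this hunk. The symmetric s8 case is conventionally the absmax scheme sketched below; treat it as an assumption, not a copy of the shipped lambda:

#include <algorithm>
#include <cmath>
#include <cstdint>
static inline void s8_quant_block_sym_ref(const float* src, int8_t* dst, int n, float& scale) {
  float absmax = 0.f;
  for (int i = 0; i < n; i++) absmax = std::max(absmax, std::abs(src[i]));
  scale = absmax / 127.f;  // one scale per block
  float rscale = scale != 0.f ? 1.f / scale : 0.f;
  for (int i = 0; i < n; i++) dst[i] = static_cast<int8_t>(std::nearbyint(src[i] * rscale));
}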
v + 1 : v))); auto min_exp = -1 * (std::pow(2, ebits - 1)) + 2; @@ -906,8 +905,8 @@ int8_t f8_mx_quantize(float v, float scale, JBLAS_DTYPE scale_dtype) { *shift_v <<= 1; uint8_t store_ebit = (*(p + 3) & 0xFF); store_ebit = store_ebit - 127 + std::pow(2, ebits - 1) - 1; - if (store_ebit > 15 && F8_T == JBLAS_DTYPE::F8_E4M3) store_ebit = 0; - if (store_ebit > 31 && F8_T == JBLAS_DTYPE::F8_E5M2) store_ebit = 0; + if (store_ebit > 15 && F8_T == BTLA_DTYPE::F8_E4M3) store_ebit = 0; + if (store_ebit > 31 && F8_T == BTLA_DTYPE::F8_E5M2) store_ebit = 0; store_ebit <<= store_mantissa; *shift_v <<= 8; int8_t ox80_shift = -128 >> (store_mantissa - 1); @@ -917,9 +916,9 @@ int8_t f8_mx_quantize(float v, float scale, JBLAS_DTYPE scale_dtype) { return ret; } -template -inline JBLAS_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int blocksize, JBLAS_DTYPE scale_dtype) { +template +inline BTLA_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, + int ld_dst, float* scales, int blocksize, BTLA_DTYPE scale_dtype) { for (int i = 0; i < col; i++) { int align_row_loop = row / blocksize * blocksize; int j = 0; @@ -928,18 +927,18 @@ inline JBLAS_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t* for (size_t ij = 0; ij < blksize; ij++) { scale = std::max(scale, std::abs(srcptr[(j + ij) * ld_src + i])); } - if (scale_dtype == JBLAS_DTYPE::F8_E8M0) { + if (scale_dtype == BTLA_DTYPE::F8_E8M0) { if (scale == 0) scale += std::abs(std::numeric_limits::min()); scale = std::floor(std::log2(scale)); - auto ebits = utils::jblas_dtype_get_f8_ebits(F8_T); + auto ebits = utils::bestla_dtype_get_f8_ebits(F8_T); auto emax = std::pow(2, ebits - 1); - if (F8_T == JBLAS_DTYPE::F8_E5M2) emax -= 1; + if (F8_T == BTLA_DTYPE::F8_E5M2) emax -= 1; scale -= emax; auto scale_max = std::pow(2, 7) - 1; // e8m0 scale type. scale = scale < (-1 * scale_max) ? 
(-1 * scale_max) : scale; - } else if (scale_dtype == JBLAS_DTYPE::F32) { - scale /= utils::get_mxfp_maxnorm(F8_T, utils::jblas_dtype_get_f8_ebits(F8_T), - utils::jblas_dtype_get_f8_quant_mbits(F8_T)); + } else if (scale_dtype == BTLA_DTYPE::F32) { + scale /= utils::get_mxfp_maxnorm(F8_T, utils::bestla_dtype_get_f8_ebits(F8_T), + utils::bestla_dtype_get_f8_quant_mbits(F8_T)); } else { assert(0); } @@ -951,12 +950,12 @@ inline JBLAS_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t* for (; j < align_row_loop; j += blocksize) f8_blk_quant(blocksize); if (j < row) f8_blk_quant(row - align_row_loop); } - return JblasSuccess; + return BTLA_CODE::Success; } -template -inline JBLAS_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int8_t* zero_points, int blocksize) { +template +inline BTLA_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales, int8_t* zero_points, int blocksize) { int raw_blocksize = blocksize; for (int i = 0; i < col; i++) { int align_row_loop = row / blocksize * blocksize; @@ -996,13 +995,12 @@ inline JBLAS_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, for (; j < align_row_loop; j += blocksize) dispatch_calc(blocksize); if (j < row) dispatch_calc(row - align_row_loop); } - return JblasSuccess; + return BTLA_CODE::Success; } template -inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, - int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, - float* blkreduce) { +inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, int ld_dst, + float* scales, int ld_scale, uint8_t* zps, int blocksize, float* blkreduce) { int colblk = utils::padto_le(col, blocksize); for (int i = 0; i < row; i++) { size_t j = 0; @@ -1057,12 +1055,12 @@ inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -inline JBLAS_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, int ld_dst, - float* scales, int ld_scale, int blocksize, float* reduce) { +inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, int ld_dst, + float* scales, int ld_scale, int blocksize, float* reduce) { int colblk = utils::padto_le(col, blocksize); for (int i = 0; i < row; i++) { size_t j = 0; @@ -1102,30 +1100,30 @@ inline JBLAS_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, if (reduce) reduce[j / blocksize + i * ld_scale] = sum * scale; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE alphabeta_f32_f32(const float alpha, const float* srcptr, const int srcstep, const float beta, - const float* src1ptr, const int src1step, float* dstptr, const int dststep, - const int M, const int N) { +static inline BTLA_CODE alphabeta_f32_f32(const float alpha, const float* srcptr, const int srcstep, const float beta, + const float* src1ptr, const int src1step, float* dstptr, const int dststep, + const int M, const int N) { if (beta != 0.f) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { dstptr[i * dststep + j] = alpha * srcptr[i * srcstep + j] + beta * src1ptr[i * src1step + j]; } } - return JblasSuccess; + return BTLA_CODE::Success; } for (int i = 0; i < M; i++) { for (int j = 0; j < N; 
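quantize_fp_u8_colblock above emits a (scale, zero point) pair per block of blocksize columns; the block internals are elided in this hunk, so the following is a standard min/max sketch of what one block plausibly computes rather than the shipped code:

#include <algorithm>
#include <cmath>
#include <cstdint>
static inline void u8_quant_block_asym_ref(const float* src, uint8_t* dst, int n,
                                           float& scale, uint8_t& zp) {
  float maxv = 0.f, minv = 0.f;
  for (int i = 0; i < n; i++) {
    maxv = std::max(maxv, src[i]);
    minv = std::min(minv, src[i]);
  }
  scale = (maxv - minv) / 255.f;
  float rscale = scale != 0.f ? 1.f / scale : 0.f;
  zp = static_cast<uint8_t>(std::nearbyint(-minv * rscale));  // maps minv to 0
  for (int i = 0; i < n; i++) {
    float q = std::nearbyint(src[i] * rscale) + zp;
    dst[i] = static_cast<uint8_t>(std::min(255.f, std::max(0.f, q)));
  }
}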
j++) { dstptr[i * dststep + j] = alpha * srcptr[i * srcstep + j]; } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, - const int dststep, const int M, const int N) { +static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, + const int dststep, const int M, const int N) { for (size_t i = 0; i < M; i++) { for (size_t j = 0; j < N; j++) { if constexpr (!std::is_same_v) { @@ -1136,22 +1134,22 @@ static inline JBLAS_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* s } } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE accum_f32_f32(const float* srcptr, const int srcstep, float* dstptr, const int dststep, - const int M, const int N) { +static inline BTLA_CODE accum_f32_f32(const float* srcptr, const int srcstep, float* dstptr, const int dststep, + const int M, const int N) { for (size_t i = 0; i < M; i++) { for (size_t j = 0; j < N; j++) { dstptr[i * dststep + j] = srcptr[i * srcstep + j] + dstptr[i * dststep + j]; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE quanout_s32_u32(const float alpha, const int32_t* srcptr, const int srcstep, uint8_t* dstptr, - const int dststep, const int M, const int N, float scaleSrc, float scaleDst, - int zpDst) { +static inline BTLA_CODE quanout_s32_u32(const float alpha, const int32_t* srcptr, const int srcstep, uint8_t* dstptr, + const int dststep, const int M, const int N, float scaleSrc, float scaleDst, + int zpDst) { float factor = alpha * scaleSrc / scaleDst; for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { @@ -1159,13 +1157,13 @@ static inline JBLAS_CODE quanout_s32_u32(const float alpha, const int32_t* srcpt dstptr[i * dststep + j] = utils::cast(fsrc + static_cast(zpDst)); } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, - const int M, const int N, const float* scaleA, const int ldsa, - const SCAB_T* scaleB) { +static inline BTLA_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, + const int M, const int N, const float* scaleA, const int ldsa, + const SCAB_T* scaleB) { for (int i = 0; i < M; i++) { float scale = scaleA[i * ldsa]; for (int j = 0; j < N; j++) { @@ -1173,11 +1171,11 @@ static inline JBLAS_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcst dstptr[i * dststep + j] = fsrc; } } - return JblasSuccess; + return BTLA_CODE::Success; } -inline JBLAS_CODE minmax_f32_kblock(const float* srcptr, int row, int col, int ld_src, float* minmaxptr, int ld_minmax, - int fsize_minmax, int blocksize) { +inline BTLA_CODE minmax_f32_kblock(const float* srcptr, int row, int col, int ld_src, float* minmaxptr, int ld_minmax, + int fsize_minmax, int blocksize) { for (int i = 0; i < row; i++) { if (col >= blocksize) { for (int icol = 0; icol < col; icol += blocksize) { @@ -1202,32 +1200,32 @@ inline JBLAS_CODE minmax_f32_kblock(const float* srcptr, int row, int col, int l minmaxptr[i * ld_minmax + 1] = maxval; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE accumulate_dequantize_s32_f32(const int32_t* srcptr, float* dstptr, float alpha, float beta, - int row, int col, int ld_src, int ld_dst, float* ascales, - int ldas, float* wscales) { +static inline 
BTLA_CODE accumulate_dequantize_s32_f32(const int32_t* srcptr, float* dstptr, float alpha, float beta, + int row, int col, int ld_src, int ld_dst, float* ascales, + int ldas, float* wscales) { for (int irow = 0; irow < row; irow++) { for (int icol = 0; icol < col; icol++) { float scale = ascales[irow * ldas] * wscales[icol] * alpha; dstptr[irow * ld_dst + icol] = scale * srcptr[irow * ld_src + icol] + beta * dstptr[irow * ld_dst + icol]; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE broadcast_u8(int num, const uint8_t& srcval, uint8_t* dstptr) { +static inline BTLA_CODE broadcast_u8(int num, const uint8_t& srcval, uint8_t* dstptr) { int i = 0; for (; i < num; i++) { dstptr[i] = srcval; } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE quant_s8_row_reduce_sum(const int8_t* srcptr, int ldsrc, const float* scales, - const int8_t* zero_points, int row, int col, _RT* reduce) { +static inline BTLA_CODE quant_s8_row_reduce_sum(const int8_t* srcptr, int ldsrc, const float* scales, + const int8_t* zero_points, int row, int col, _RT* reduce) { std::memset(reduce, 0, sizeof(reduce[0]) * col); for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { @@ -1239,11 +1237,11 @@ static inline JBLAS_CODE quant_s8_row_reduce_sum(const int8_t* srcptr, int ldsrc } } } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE row_reduce_sum(const float* srcptr, int ldsrc, int row, int col, _RT* reduce) { +static inline BTLA_CODE row_reduce_sum(const float* srcptr, int ldsrc, int row, int col, _RT* reduce) { for (int j = 0; j < col; j++) { float tmp = 0.f; for (int i = 0; i < row; i++) { @@ -1251,12 +1249,12 @@ static inline JBLAS_CODE row_reduce_sum(const float* srcptr, int ldsrc, int row, } reduce[j] = static_cast<_RT>(tmp); } - return JblasSuccess; + return BTLA_CODE::Success; } template -static inline JBLAS_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, - float* reduce, int ldr) { +static inline BTLA_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, + float* reduce, int ldr) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += blocksize) { auto tmp = 0.f; @@ -1268,34 +1266,34 @@ static inline JBLAS_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, in reduce[i * ldr + j / blocksize] = tmp; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, - float* scales, int lds, const float* reduce) { +static inline BTLA_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, + float* scales, int lds, const float* reduce) { for (int i = 0; i < row; i++) { auto zpf = static_cast(zps[i * lds]) * scales[i * lds]; for (int j = 0; j < col; j++) { accptr[i * ldacc + j] -= zpf * reduce[j]; } } - return JblasSuccess; + return BTLA_CODE::Success; } -static inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, - float* scales, int lds, const float* reduce) { +static inline BTLA_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, + float* scales, int lds, const float* reduce) { for (int i = 0; i < row; i++) { auto reducef = reduce[i * lds]; for (int j = 0; j < col; j++) { accptr[i * ldacc + j] -= static_cast(zps[j]) * scales[j] * reducef; } } - return JblasSuccess; + return 
BTLA_CODE::Success; } -static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, - float* scalea, float* scaleb, int lds, int k, const float* reducea, - const float* reduceb) { +static inline BTLA_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, + float* scalea, float* scaleb, int lds, int k, const float* reducea, + const float* reduceb) { for (int i = 0; i < row; i++) { auto reduceaf = reducea[i * lds]; auto zpaf = static_cast(zpa[i * lds]) * scalea[i * lds]; @@ -1306,8 +1304,8 @@ static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row accptr[i * ldacc + j] -= zpaf * zpbf * k; } } - return JblasSuccess; + return BTLA_CODE::Success; } } // namespace ref } // namespace kernel -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h similarity index 71% rename from bestla/jblas/kernel_wrapper.h rename to bestla/bestla/kernel_wrapper.h index 742748e89..d0832bd82 100644 --- a/bestla/jblas/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -16,28 +16,28 @@ #include #include -#include "jblas/jit_blas.h" -#include "jit_blas_utils.h" +#include "bestla.h" +#include "bestla_utils.h" #include "kernel_avx2.h" #include "kernel_avx512f.h" #include "kernel_avx512_bf16.h" #include "kernel_jit.h" #include "kernel_ref.h" -namespace jblas { +namespace bestla { namespace kernel { namespace wrapper { template class PaddingInterleaveMN { // M x N ===> N/NTile x M/RowPack x NTile x RowPack (leading dim stride = NTile * dststride) public: - template - static JBLAS_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, - int dst_step) { + template + static BTLA_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, + int dst_step) { if constexpr (utils::isa_base::avx512f) { const auto kern_ret = kernel::avx512f::padding_interleave_cvt::forward( src, dst, NTile, row, col, row_pad, col_pad, src_step, dst_step); - if (kern_ret != JblasNotSupport) return kern_ret; + if (kern_ret != BTLA_CODE::NotSupport) return kern_ret; } return ref::padding_interleave(src, dst, row, col, row_pad, col_pad, src_step, dst_step, NTile, RowPack); } @@ -47,9 +47,9 @@ template class RevertPaddingInterleaveMN { // M x N ===> N/NTile x M/RowPack x NTile x RowPack (leading dim stride = NTile * dststride) public: - template - static JBLAS_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, - int dst_step) { + template + static BTLA_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, + int dst_step) { return ref::revert_padding_interleave(src, dst, row, col, row_pad, col_pad, src_step, dst_step, NTile, RowPack); } }; @@ -59,14 +59,14 @@ class PaddingTransInterleaveMN { // row and cols are in terms of src // M x N ===> M/MTile x N/ColPack x MTile x ColPack (leading dim stride = MTile * dststride) public: - template - static JBLAS_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, - int dst_step) { + template + static BTLA_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, + int dst_step) { // Note: rows/cols and i/j are in terms of src if constexpr (utils::isa_base::avx512f) { const auto kern_ret = kernel::avx512f::padding_trans_interleave_cvt::forward( src, 
dst, MTile, row, col, row_pad, col_pad, src_step, dst_step); - if (kern_ret != JblasNotSupport) return kern_ret; + if (kern_ret != BTLA_CODE::NotSupport) return kern_ret; } return ref::padding_trans_interleave(src, dst, row, col, row_pad, col_pad, src_step, dst_step, MTile, ColPack); } @@ -74,14 +74,14 @@ class PaddingTransInterleaveMN { class Memcpy2D { public: - template - static JBLAS_CODE forward(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, - void* const_elt_v = nullptr) { - auto ret = JblasNotSupport; + template + static BTLA_CODE forward(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, + void* const_elt_v = nullptr) { + auto ret = BTLA_CODE::NotSupport; if constexpr (utils::isa_base::avx512f) { ret = kernel::jit::JitMemcpy2DAvx512f::forward<_SRC_T, _DST_T>(srcptr, dstptr, row, col, srcstep, dststep, const_elt_v); - if (ret == JblasSuccess) { + if (ret == BTLA_CODE::Success) { return ret; } } @@ -93,7 +93,7 @@ class Memcpy2D { if (col - align_col > 0) ret = kernel::ref::memcpy2d(srcptr + align_col, dstptr + align_col, row, (col - align_col) * sizeof(_SRC_T), srcstep * sizeof(_SRC_T), dststep * sizeof(_DST_T)); - if (ret == JblasSuccess) { + if (ret == BTLA_CODE::Success) { return ret; } } @@ -102,15 +102,15 @@ class Memcpy2D { dststep * sizeof(_DST_T)); } - template - static JBLAS_CODE forward1(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, - void* const_elt_v = nullptr) { - auto ret = JblasNotSupport; + template + static BTLA_CODE forward1(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, + void* const_elt_v = nullptr) { + auto ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { ret = kernel::jit::JitMemcpy2DAvx512f::forward1<_SRC_T, _DST_T, OP_T>(srcptr, dstptr, row, col, srcstep, dststep, const_elt_v); - if (ret == JblasSuccess) { + if (ret == BTLA_CODE::Success) { return ret; } } @@ -124,7 +124,7 @@ class Memcpy2D { ret = kernel::ref::memcpy2d_withop<_SRC_T, _DST_T, OP_T>( srcptr + align_col, dstptr + align_col, row, (col - align_col) * sizeof(_SRC_T), srcstep * sizeof(_SRC_T), dststep * sizeof(_DST_T), const_elt_v); - if (ret == JblasSuccess) { + if (ret == BTLA_CODE::Success) { return ret; } } @@ -135,9 +135,9 @@ class Memcpy2D { class Memcpy2DFp32CvtBf16 { public: - template - static JBLAS_CODE forward(const void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, - bool zeropadding) { + template + static BTLA_CODE forward(const void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, + bool zeropadding) { #if CompileBF16() if constexpr (utils::isa_base::amx_bf16) { return kernel::avx512_bf16::fp32_cvt_bf16_2D_write_back(srcptr, dstptr, row, col, srcstride, dststride, @@ -161,9 +161,9 @@ class Memcpy2DFp32CvtBf16 { class Memcpy2DFp32CvtFp16 { public: - template - static JBLAS_CODE forward(void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, - bool zeropadding) { + template + static BTLA_CODE forward(void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, + bool zeropadding) { #if CompileFP16() if constexpr (utils::isa_base::avx512_fp16) { return kernel::avx512f::fp32_cvt_fp16_2D_write_back( @@ -171,15 +171,15 @@ class Memcpy2DFp32CvtFp16 { srcstride / sizeof(float), dststride / sizeof(utils::fp16), zeropadding); } #endif - return JblasNotSupport; + return BTLA_CODE::NotSupport; } }; class Memcpy2DFp16CvtFp32 { public: - 
template - static JBLAS_CODE forward(void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, - bool zeropadding) { + template + static BTLA_CODE forward(void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, + bool zeropadding) { #if CompileFP16() if constexpr (utils::isa_base::avx512_fp16) { return kernel::avx512f::fp16_cvt_fp32_2D_write_back( // @@ -187,31 +187,31 @@ class Memcpy2DFp16CvtFp32 { srcstride / sizeof(utils::fp16), dststride / sizeof(float), zeropadding); } #endif - return JblasNotSupport; + return BTLA_CODE::NotSupport; } }; class Memcpy2DBf16CvtFp32 { public: - template - static JBLAS_CODE forward(void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, - bool zeropadding) { + template + static BTLA_CODE forward(void* srcptr, void* dstptr, int row, int col, int srcstride, int dststride, + bool zeropadding) { #if CompileBF16() - if constexpr (ISA_T >= JblasAMX_BF16) { + if constexpr (ISA_T >= BTLA_ISA::AMX_BF16) { return kernel::avx512_bf16::bf16_cvt_fp32_2D_write_back( // reinterpret_cast(srcptr), reinterpret_cast(dstptr), row, col, srcstride / sizeof(utils::bf16), dststride / sizeof(float), zeropadding); } #endif #if CompileAVX512F() - if constexpr (ISA_T >= JblasAVX512F) { + if constexpr (ISA_T >= BTLA_ISA::AVX512F) { return kernel::avx512f::bf16_cvt_fp32_2D_write_back( // reinterpret_cast(srcptr), reinterpret_cast(dstptr), row, col, srcstride / sizeof(utils::bf16), dststride / sizeof(float), zeropadding); } #endif #if CompileAVX2() - if constexpr (ISA_T >= JblasAVX2) { + if constexpr (ISA_T >= BTLA_ISA::AVX2) { return kernel::avx2::bf16_cvt_fp32_2D_write_back( reinterpret_cast(srcptr), reinterpret_cast(dstptr), row, col, srcstride / sizeof(utils::bf16), dststride / sizeof(float), zeropadding); @@ -225,9 +225,9 @@ class Memcpy2DBf16CvtFp32 { template class CompressS8S4 { public: - template - static inline JBLAS_CODE forward(const int8_t* srcptr, jblas::utils::int4x2* dstptr, int row, int col, int ld_src, - int ld_dst) { + template + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::int4x2* dstptr, int row, int col, int ld_src, + int ld_dst) { return ref::compress_s8_s4(srcptr, dstptr, row, col, ld_src, ld_dst); } }; @@ -235,9 +235,9 @@ class CompressS8S4 { template class CompressFp4 { public: - template - static inline JBLAS_CODE forward(const int8_t* srcptr, jblas::utils::f4x2* dstptr, int row, int col, int ld_src, - int ld_dst) { + template + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::f4x2* dstptr, int row, int col, int ld_src, + int ld_dst) { return ref::compress_f4(srcptr, dstptr, row, col, ld_src, ld_dst); } }; @@ -245,20 +245,20 @@ class CompressFp4 { template class Transpose2D { public: - template - static inline JBLAS_CODE forward(const _T* srcptr, _T* dstptr, int row, int col, int ld_src, int ld_dst) { + template + static inline BTLA_CODE forward(const _T* srcptr, _T* dstptr, int row, int col, int ld_src, int ld_dst) { return ref::transpose2d(srcptr, dstptr, row, col, ld_src, ld_dst); } }; class QuantizeSignIntRowBlock { public: - template - static inline JBLAS_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int blocksize) { + template + static inline BTLA_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales, int8_t* zero_points, int blocksize) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f && - S4_T != 
JBLAS_DTYPE::S4_FULLRANGE) { // TODO(zhe): support simd version s4_fullrange quantization. + S4_T != BTLA_DTYPE::S4_FULLRANGE) { // TODO(zhe): support simd version s4_fullrange quantization. return avx512f::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, blocksize); } @@ -270,9 +270,9 @@ class QuantizeSignIntRowBlock { class QuantizeF8RowBlock { public: - template - static inline JBLAS_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int blocksize, JBLAS_DTYPE scale_dtype) { + template + static inline BTLA_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales, int blocksize, BTLA_DTYPE scale_dtype) { return ref::quantize_f32_f8_rowblock_mxscale(srcptr, dstptr, row, col, ld_src, ld_dst, scales, blocksize, scale_dtype); } @@ -280,9 +280,9 @@ class QuantizeF8RowBlock { class QuantizeF4RowBlock { public: - template - static inline JBLAS_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int blocksize) { + template + static inline BTLA_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, + float* scales, int8_t* zero_points, int blocksize) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::quantize_f32_f4_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, @@ -296,9 +296,9 @@ class QuantizeF4RowBlock { class QuantizeU8ColBlock { public: - template - static inline JBLAS_CODE forward(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, int ld_dst, - float* scales, int ld_scale, uint8_t* zps, int blocksize, float* blkreduce) { + template + static inline BTLA_CODE forward(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, int ld_dst, + float* scales, int ld_scale, uint8_t* zps, int blocksize, float* blkreduce) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::quantize_fp_u8_colblock(row, col, srcptr, ld_src, dstptr, ld_dst, scales, ld_scale, zps, @@ -318,9 +318,9 @@ class QuantizeU8ColBlock { class QuantizeS8ColBlock { public: - template - static inline JBLAS_CODE forward(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, int ld_dst, - float* scales, int ld_scale, int blocksize, float* reduce) { + template + static inline BTLA_CODE forward(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, int ld_dst, + float* scales, int ld_scale, int blocksize, float* reduce) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::quantize_fp_s8_colblock(row, col, srcptr, ld_src, dstptr, ld_dst, scales, ld_scale, @@ -333,8 +333,8 @@ class QuantizeS8ColBlock { class Broadcast { public: - template - static inline JBLAS_CODE forward(int num, const uint8_t& srcval, uint8_t* dstptr) { + template + static inline BTLA_CODE forward(int num, const uint8_t& srcval, uint8_t* dstptr) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::broadcast_u8(num, srcval, dstptr); @@ -346,9 +346,9 @@ class Broadcast { class AccumulateDequantizeS32F32 { public: - template - static inline JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, float alpha, float beta, int row, int col, - int ld_src, int ld_dst, float* ascales, int ldas, float* wscales) { + template + static inline BTLA_CODE forward(const int32_t* srcptr, float* dstptr, float alpha, float beta, int row, 
int col, + int ld_src, int ld_dst, float* ascales, int ldas, float* wscales) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::accumulate_dequantize_s32_f32(srcptr, dstptr, alpha, beta, row, col, ld_src, ld_dst, ascales, @@ -363,17 +363,17 @@ class AccumulateDequantizeS32F32 { template // zero points always be int8_t, not compressed class DecompressKBlockS4Fp { public: - template - static inline JBLAS_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, - size_t tmpsize) { - JBLAS_CODE ret = JblasNotSupport; + template + static inline BTLA_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, + size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { ret = avx512f::decompress_kblock_s4_fp( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, reinterpret_cast(tmp), tmpsize); - if (ret == JblasSuccess) return ret; + if (ret == BTLA_CODE::Success) return ret; } #endif #if CompileAVX2() @@ -407,7 +407,7 @@ class DecompressKBlockS4Fp { } } - if (ret == JblasSuccess) return ret; + if (ret == BTLA_CODE::Success) return ret; } #endif ret = ref::decompress_kblock_s4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, @@ -420,10 +420,10 @@ class DecompressKBlockS4Fp { template // zero points always be int8_t, not compressed class DecompressKBlockS4S8Fp { public: - template - static inline JBLAS_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - void* tmp, size_t tmpsize) { - JBLAS_CODE ret = JblasNotSupport; + template + static inline BTLA_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::decompress_kblock_s4_s8fp(srcptr, dstptr, row, col, ld_src, ld_dst, @@ -442,16 +442,16 @@ class DecompressKBlockS4S8Fp { template class DecompressKBlockF4Fp { public: - template - static inline JBLAS_CODE forward(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - SCA_T* scales, int k_offset, int kblock, int NPad, void* tmp, size_t tmpsize) { - JBLAS_CODE ret = JblasNotSupport; + template + static inline BTLA_CODE forward(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + SCA_T* scales, int k_offset, int kblock, int NPad, void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { ret = avx512f::decompress_kblock_f4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, scales, k_offset, kblock, NPad, reinterpret_cast(tmp), tmpsize); - if (ret == JblasSuccess) return ret; + if (ret == BTLA_CODE::Success) return ret; } #endif #if CompileAVX2() @@ -459,7 +459,7 @@ class DecompressKBlockF4Fp { ret = avx2::decompress_kblock_f4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, scales, k_offset, kblock, NPad, reinterpret_cast(tmp), tmpsize); - if (ret == JblasSuccess) return ret; + if (ret == BTLA_CODE::Success) return ret; } #endif return ref::decompress_kblock_f4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, @@ -471,10 +471,10 @@ class DecompressKBlockF4Fp { template class DecompressKBlockF4FpNoscale 
{ public: - template - static inline JBLAS_CODE forward(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - void* tmp, size_t tmpsize) { - JBLAS_CODE ret = JblasNotSupport; + template + static inline BTLA_CODE forward(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; if constexpr (utils::isa_base::avx512f) { return avx512f::decompress_kblock_f4_fp_noscale(srcptr, dstptr, row, col, ld_src, ld_dst, reinterpret_cast(tmp), tmpsize); @@ -490,9 +490,9 @@ class DecompressKBlockF4FpNoscale { class DecompressKBlockS4S8 { public: - template - static inline JBLAS_CODE forward(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst) { - if constexpr (utils::isa_base::avx512f && S4_T == JBLAS_DTYPE::S4_CLIP) { + template + static inline BTLA_CODE forward(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst) { + if constexpr (utils::isa_base::avx512f && S4_T == BTLA_DTYPE::S4_CLIP) { return jit::decompress_s4_s8(srcptr, dstptr, row, col, ld_src, ld_dst); } #if CompileAVX512F() @@ -512,9 +512,9 @@ class DecompressKBlockS4S8 { template class DecompressKBlockF8FP { public: - template - static inline JBLAS_CODE forward(utils::f8* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - SCA_T* scales, int k_offset, int kblock, int NPad, JBLAS_DTYPE src_f8_type) { + template + static inline BTLA_CODE forward(utils::f8* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + SCA_T* scales, int k_offset, int kblock, int NPad, BTLA_DTYPE src_f8_type) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::decompress_kblock_f8_fp( @@ -535,10 +535,10 @@ class DecompressKBlockF8FP { template class DecompressKBlockS8Fp { public: - template - static inline JBLAS_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, - size_t tmpsize) { + template + static inline BTLA_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, + size_t tmpsize) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f && std::is_same_v) { // TODO Scale type support return jit::DequanKBlockS8Fp::forward_avx512f(srcptr, dstptr, row, col, ld_src, ld_dst, scales, @@ -561,9 +561,9 @@ class DecompressKBlockS8Fp { template class DecompressKBlockS8S8Fp { public: - template - static inline JBLAS_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, void* tmp, - size_t tmpsize) { + template + static inline BTLA_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, void* tmp, + size_t tmpsize) { if constexpr (utils::isa_base::avx512f) { // TODO Scale type support return avx512f::decompress_kblock_s8_s8fp<_DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst); } @@ -577,9 +577,9 @@ class DecompressKBlockS8S8Fp { template class DecompressKBlockF8FpNoScale { public: - template - static inline JBLAS_CODE forward(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - void* tmp, size_t tmpsize, JBLAS_DTYPE src_f8_t) { + template + static inline BTLA_CODE forward(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + void* tmp, size_t tmpsize, BTLA_DTYPE src_f8_t) { #if CompileAVX512F() if 
constexpr (utils::isa_base::avx512f) { return avx512f::decompress_kblock_f8_fp( @@ -598,10 +598,10 @@ class DecompressKBlockF8FpNoScale { class AlphaBetaF32F32 { public: - template - static JBLAS_CODE forward(const float alpha, const float* srcptr, const int srcstep, const float beta, - const float* src1ptr, const int src1step, float* dstptr, const int dststep, const int M, - const int N) { + template + static BTLA_CODE forward(const float alpha, const float* srcptr, const int srcstep, const float beta, + const float* src1ptr, const int src1step, float* dstptr, const int dststep, const int M, + const int N) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::alphabeta_f32_f32(alpha, srcptr, srcstep, beta, src1ptr, src1step, dstptr, dststep, M, N); @@ -618,9 +618,9 @@ class AlphaBetaF32F32 { class CompFp32BlockScale { public: - template - static JBLAS_CODE forward(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, - const int dststep, const int M, const int N) { + template + static BTLA_CODE forward(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, const int dststep, + const int M, const int N) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::accum_alphaN_f32_f32(alpha, srcptr, srcstep, dstptr, dststep, M, N); @@ -635,9 +635,9 @@ class CompFp32BlockScale { class AccumulateFp32 { public: - template - static JBLAS_CODE forward(const float* srcptr, const int srcstep, float* dstptr, const int dststep, const int M, - const int N) { + template + static BTLA_CODE forward(const float* srcptr, const int srcstep, float* dstptr, const int dststep, const int M, + const int N) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::accum_f32_f32(srcptr, srcstep, dstptr, dststep, M, N); @@ -649,9 +649,9 @@ class AccumulateFp32 { class QuanOutS32U32 { public: - template - static JBLAS_CODE forward(const float alpha, const int32_t* srcptr, const int srcstep, uint8_t* dstptr, - const int dststep, const int M, const int N, float scaleSrc, float scaleDst, int zpDst) { + template + static BTLA_CODE forward(const float alpha, const int32_t* srcptr, const int srcstep, uint8_t* dstptr, + const int dststep, const int M, const int N, float scaleSrc, float scaleDst, int zpDst) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::quanout_s32_u32(alpha, srcptr, srcstep, dstptr, dststep, M, N, scaleSrc, scaleDst, zpDst); @@ -665,9 +665,9 @@ class QuanOutS32U32 { // scaleB per channel(N) class DequanS32Fp32 { public: - template - static JBLAS_CODE forward(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, const int M, - const int N, const float* scaleA, const int ldsa, const SCAB_T* scaleB) { + template + static BTLA_CODE forward(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, const int M, + const int N, const float* scaleA, const int ldsa, const SCAB_T* scaleB) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::dequant_s32_fp32(srcptr, srcstep, dstptr, dststep, M, N, scaleA, ldsa, scaleB); @@ -684,9 +684,9 @@ class DequanS32Fp32 { class MinMaxKBlock { public: - template - static inline JBLAS_CODE forward(const float* srcptr, int row, int col, int ld_src, float* minmaxptr, int ld_minmax, - int fsize_minmax, int blocksize) { + template + static inline BTLA_CODE forward(const float* srcptr, int row, int col, int ld_src, float* minmaxptr, int ld_minmax, + int fsize_minmax, int 
blocksize) { return ref::minmax_f32_kblock(srcptr, row, col, ld_src, minmaxptr, ld_minmax, fsize_minmax, blocksize); } }; @@ -694,9 +694,9 @@ class MinMaxKBlock { template class QuantS8RowReduceSum { public: - template - static inline JBLAS_CODE forward(const int8_t* srcptr, int ldsrc, const float* scales, const int8_t* zero_points, - int row, int col, _RT* reduce) { + template + static inline BTLA_CODE forward(const int8_t* srcptr, int ldsrc, const float* scales, const int8_t* zero_points, + int row, int col, _RT* reduce) { return ref::quant_s8_row_reduce_sum(srcptr, ldsrc, scales, zero_points, row, col, reduce); } }; @@ -704,17 +704,17 @@ class QuantS8RowReduceSum { template class RowReduceSum { public: - template - static inline JBLAS_CODE forward(const float* srcptr, int ldsrc, int row, int col, _RT* reduce) { + template + static inline BTLA_CODE forward(const float* srcptr, int ldsrc, int row, int col, _RT* reduce) { return ref::row_reduce_sum<_RT>(srcptr, ldsrc, row, col, reduce); } }; class ColBlockReduceSum { public: - template - static inline JBLAS_CODE forward(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, float* reduce, - int ldr) { + template + static inline BTLA_CODE forward(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, float* reduce, + int ldr) { if constexpr (utils::isa_base::avx512f) { return avx512f::col_block_reduce_sum(srcptr, ldsrc, row, col, blocksize, reduce, ldr); } @@ -727,9 +727,9 @@ class ColBlockReduceSum { class RemoveZeroPointBias { public: - template - static inline JBLAS_CODE forward_wei(float* accptr, int ldacc, int row, int col, int8_t* zps, float* scales, int lds, - const float* reduce) { + template + static inline BTLA_CODE forward_wei(float* accptr, int ldacc, int row, int col, int8_t* zps, float* scales, int lds, + const float* reduce) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::remove_wei_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); @@ -742,9 +742,9 @@ class RemoveZeroPointBias { #endif return ref::remove_wei_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); } - template - static inline JBLAS_CODE forward_act(float* accptr, int ldacc, int row, int col, uint8_t* zps, float* scales, int lds, - const float* reduce) { + template + static inline BTLA_CODE forward_act(float* accptr, int ldacc, int row, int col, uint8_t* zps, float* scales, int lds, + const float* reduce) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::remove_act_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); @@ -757,10 +757,10 @@ class RemoveZeroPointBias { #endif return ref::remove_act_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); } - template - static inline JBLAS_CODE forward_both(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, - float* scalea, float* scaleb, int lds, int k, const float* reducea, - const float* reduceb) { + template + static inline BTLA_CODE forward_both(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, + float* scalea, float* scaleb, int lds, int k, const float* reducea, + const float* reduceb) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::remove_zeropoint_bias(accptr, ldacc, row, col, zpa, zpb, scalea, scaleb, lds, k, reducea, @@ -778,4 +778,4 @@ class RemoveZeroPointBias { } // namespace wrapper } // namespace kernel -} // namespace jblas +} // namespace bestla diff --git a/bestla/bestla/ut/bestla.cpp 
b/bestla/bestla/ut/bestla.cpp new file mode 100644 index 000000000..dccf567e2 --- /dev/null +++ b/bestla/bestla/ut/bestla.cpp @@ -0,0 +1 @@ +#include "../bestla.h" diff --git a/bestla/jblas/ut/jit_blas_epilogue.cpp b/bestla/bestla/ut/bestla_epilogue.cpp similarity index 68% rename from bestla/jblas/ut/jit_blas_epilogue.cpp rename to bestla/bestla/ut/bestla_epilogue.cpp index b86213e42..3d83035c3 100644 --- a/bestla/jblas/ut/jit_blas_epilogue.cpp +++ b/bestla/bestla/ut/bestla_epilogue.cpp @@ -1,7 +1,7 @@ -#include "../jit_blas_epilogue.h" -#include "jit_blas_ut.h" +#include "bestla_epilogue.h" +#include "bestla_ut.h" -namespace jblas { +namespace bestla { using namespace utils; namespace ut { class UT_AccumulatorWriteBack { @@ -9,67 +9,67 @@ class UT_AccumulatorWriteBack { UT_AccumulatorWriteBack() { UT_START(); CheckISA(AVX2); - fp32ut(127, 255, 0, 0, 127, 255); - fp32ut(101, 237, 10, 63, 30, 33); - fp32ut_with_custom_gelu(15, 15, 0, 0, 15, 15); - fp32ut_with_custom_swish(15, 15, 0, 0, 15, 15); - bf16ut(127, 255, 0, 0, 127, 255); - bf16ut(101, 237, 10, 63, 30, 33); - bf16fp32ut(101, 237, 10, 63, 30, 33); - bf16fp32ut(127, 255, 0, 0, 127, 255); + fp32ut(127, 255, 0, 0, 127, 255); + fp32ut(101, 237, 10, 63, 30, 33); + fp32ut_with_custom_gelu(15, 15, 0, 0, 15, 15); + fp32ut_with_custom_swish(15, 15, 0, 0, 15, 15); + bf16ut(127, 255, 0, 0, 127, 255); + bf16ut(101, 237, 10, 63, 30, 33); + bf16fp32ut(101, 237, 10, 63, 30, 33); + bf16fp32ut(127, 255, 0, 0, 127, 255); CheckISA(AVX512F); - fp32ut(127, 255, 0, 0, 127, 255); - fp32ut(101, 237, 10, 63, 30, 33); - bf16ut(127, 255, 0, 0, 127, 255); - bf16ut(101, 237, 10, 63, 30, 33); - fp32ut_with_custom_gelu(15, 15, 0, 0, 15, 15); - fp32ut_with_custom_swish(15, 15, 0, 0, 15, 15); + fp32ut(127, 255, 0, 0, 127, 255); + fp32ut(101, 237, 10, 63, 30, 33); + bf16ut(127, 255, 0, 0, 127, 255); + bf16ut(101, 237, 10, 63, 30, 33); + fp32ut_with_custom_gelu(15, 15, 0, 0, 15, 15); + fp32ut_with_custom_swish(15, 15, 0, 0, 15, 15); - bf16fp32ut(101, 237, 10, 63, 30, 33); - bf16fp32ut(127, 255, 0, 0, 127, 255); + bf16fp32ut(101, 237, 10, 63, 30, 33); + bf16fp32ut(127, 255, 0, 0, 127, 255); } - template + template void bf16fp32ut(int _M, int _N, int _M_offset, int _N_offset, int _cpy_M, int _cpy_N) { printf("Test Case %s %d %d %d %d %d %d\n", __FUNCTION__, _M, _N, _M_offset, _N_offset, _cpy_M, _cpy_N); std::vector src(_M * _N); for (int i = 0; i < _M * _N; i++) src[i].fromfloat(i); std::vector dstref(_M * _N, 0), dstker(_M * _N, 0); epilogue::gemm::AccumulatorWriteBackBf16Fp32<_RT_ISA_T> ker; - epilogue::gemm::AccumulatorWriteBackBf16Fp32 kerref; + epilogue::gemm::AccumulatorWriteBackBf16Fp32 kerref; kerref.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, {dstref.data(), _N}, cache, CacheSize); ker.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, {dstker.data(), _N}, cache, CacheSize); - jblas::ut::buffer_error(dstref.data(), dstker.data(), dstref.size()); + ut::buffer_error(dstref.data(), dstker.data(), dstref.size()); } - template + template void bf16ut(int _M, int _N, int _M_offset, int _N_offset, int _cpy_M, int _cpy_N) { printf("Test Case %s %d %d %d %d %d %d\n", __FUNCTION__, _M, _N, _M_offset, _N_offset, _cpy_M, _cpy_N); std::vector src(_M * _N); for (int i = 0; i < _M * _N; i++) src[i] = float(i); std::vector dstref(_M * _N, 0), dstker(_M * _N, 0); epilogue::gemm::AccumulatorWriteBackFp32Bf16<_RT_ISA_T> ker; - epilogue::gemm::AccumulatorWriteBackFp32Bf16 kerref; + epilogue::gemm::AccumulatorWriteBackFp32Bf16 kerref; - 
kerref.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, - {reinterpret_cast(dstref.data()), _N}, cache, CacheSize); - ker.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, - {reinterpret_cast(dstker.data()), _N}, cache, CacheSize); - jblas::ut::buffer_error(dstref.data(), dstker.data(), dstref.size()); + kerref.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, {reinterpret_cast(dstref.data()), _N}, + cache, CacheSize); + ker.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, {reinterpret_cast(dstker.data()), _N}, + cache, CacheSize); + ut::buffer_error(dstref.data(), dstker.data(), dstref.size()); } - template + template void fp32ut(int _M, int _N, int _M_offset, int _N_offset, int _cpy_M, int _cpy_N) { printf("Test Case %s %d %d %d %d %d %d\n", __FUNCTION__, _M, _N, _M_offset, _N_offset, _cpy_M, _cpy_N); std::vector src(_M * _N); for (int i = 0; i < _M * _N; i++) src[i] = float(i); std::vector dstref(_M * _N, 0), dstker(_M * _N, 0); epilogue::gemm::AccumulatorWriteBackFp32<_RT_ISA_T> ker; - epilogue::gemm::AccumulatorWriteBackFp32 kerref; + epilogue::gemm::AccumulatorWriteBackFp32 kerref; kerref.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, {dstref.data(), _N}, cache, CacheSize); ker.forward(src.data(), _N, _M_offset, _N_offset, _cpy_M, _cpy_N, {dstker.data(), _N}, cache, CacheSize); - jblas::ut::buffer_error(dstref.data(), dstker.data(), dstref.size()); + ut::buffer_error(dstref.data(), dstker.data(), dstref.size()); } - template + template void fp32ut_with_custom_gelu(int _M, int _N, int _M_offset, int _N_offset, int _cpy_M, int _cpy_N) { printf("Test Case %s %d %d %d %d %d %d\n", __FUNCTION__, _N, _M, _M_offset, _N_offset, _cpy_M, _cpy_N); std::vector src(_M * _N); @@ -81,9 +81,10 @@ class UT_AccumulatorWriteBack { return 0.5f * x * (1.f + tanhf(0.7978845834732056f * (x + 0.044714998453855515f * x * x * x))); }; for (int i = 0; i < _M * _N; i++) src[i] = gelu(src[i]); - jblas::ut::buffer_error(src.data(), dstker.data(), dstker.size(), 0.000001f); + ut::buffer_error(src.data(), dstker.data(), dstker.size(), 0.000001f); } - template + + template void fp32ut_with_custom_swish(int _M, int _N, int _M_offset, int _N_offset, int _cpy_M, int _cpy_N) { printf("Test Case %s %d %d %d %d %d %d\n", __FUNCTION__, _N, _M, _M_offset, _N_offset, _cpy_M, _cpy_N); std::vector src(_M * _N); @@ -95,7 +96,7 @@ class UT_AccumulatorWriteBack { CacheSize); auto swish = [&](float x) { return x / (1 + exp(-x)); }; for (int i = 0; i < _M * _N; i++) src[i] = swish(src[i]); - jblas::ut::buffer_error(src.data(), dstker.data(), dstker.size(), 0.2f); // swish use low lprecision exp + ut::buffer_error(src.data(), dstker.data(), dstker.size(), 0.2f); // swish use low lprecision exp } }; @@ -123,18 +124,18 @@ class UT_AlphaBetaProcessFp32 { for (int i = 0; i < src.size(); i++) { src[i] = float(i); } - epilogue::gemm::AlphaBetaProcessFp32 kernref; - epilogue::gemm::AlphaBetaProcessFp32 kern0; + epilogue::gemm::AlphaBetaProcessFp32 kernref; + epilogue::gemm::AlphaBetaProcessFp32 kern0; kernref.forward(src.data(), _srcstep, 0, 0, _M, _N, {dstref.data(), src1.data(), _dststep, _src1step, alpha, beta}, cache, CacheSize); kern0.forward(src.data(), _srcstep, 0, 0, _M, _N, {dst.data(), src1.data(), _dststep, _src1step, alpha, beta}, cache, CacheSize); - jblas::ut::buffer_error(dstref.data(), dst.data(), dstref.size()); + ut::buffer_error(dstref.data(), dst.data(), dstref.size()); } }; -#ifdef JBLAS_UT_EPILOGUE +#ifdef BTLA_UT_EPILOGUE static UT_AccumulatorWriteBack 
sUT_AccumulatorWriteBack; static UT_AlphaBetaProcessFp32 sUT_AlphaBetaProcessFp32; #endif } // namespace ut -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/ut/jit_blas_gemm.cpp b/bestla/bestla/ut/bestla_gemm.cpp similarity index 99% rename from bestla/jblas/ut/jit_blas_gemm.cpp rename to bestla/bestla/ut/bestla_gemm.cpp index 4b9559f5a..43d86eab0 100644 --- a/bestla/jblas/ut/jit_blas_gemm.cpp +++ b/bestla/bestla/ut/bestla_gemm.cpp @@ -1,9 +1,8 @@ -#include "../jit_blas_gemm.h" +#include "bestla_gemm.h" +#include "bestla_utils.h" +#include "bestla_ut.h" -#include "../jit_blas_utils.h" -#include "jit_blas_ut.h" - -namespace jblas { +namespace bestla { using namespace utils; template @@ -1115,4 +1114,4 @@ class UT_GEMM_AMXINT8 { static UT_GEMM_AMXINT8 sUT_GEMM_AMXINT8; #endif } // namespace ut -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/ut/jit_blas_parallel.cpp b/bestla/bestla/ut/bestla_parallel.cpp similarity index 50% rename from bestla/jblas/ut/jit_blas_parallel.cpp rename to bestla/bestla/ut/bestla_parallel.cpp index b730da9fd..81e4eb899 100644 --- a/bestla/jblas/ut/jit_blas_parallel.cpp +++ b/bestla/bestla/ut/bestla_parallel.cpp @@ -1,10 +1,10 @@ -#include "../jit_blas_utils.h" -#include "../jit_blas_parallel.h" -#include "../jit_blas_device.h" -#include "../jit_blas_gemm.h" -#include "jit_blas_ut.h" -#include "../jit_blas_prologue_a.h" -namespace jblas { +#include "bestla_utils.h" +#include "bestla_parallel.h" +#include "bestla_device.h" +#include "bestla_gemm.h" +#include "bestla_ut.h" +#include "bestla_prologue_a.h" +namespace bestla { using namespace utils; namespace ut { #ifdef _OPENMP @@ -22,13 +22,14 @@ class UT_OMPThreading { avector src(row * col), dst(row * col), ref(row * col); fill_buffer_randn(src.data(), src.size(), -0.5f, 0.5f); int ld_src = col, ld_dst = row; - kernel::wrapper::Transpose2D::template forward(src.data(), ref.data(), row, col, col, row); - jblas::parallel::Scheduler2D _para({threads, row, col, 1, 1}); + kernel::wrapper::Transpose2D::template forward(src.data(), ref.data(), row, col, col, + row); + parallel::Scheduler2D _para({threads, row, col, 1, 1}); DefaultThreading.parallel_for([&](int tidx) { - jblas::parallel::ThreadProblem2D thdp{tidx}; + parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); if (thdp.valid) { - kernel::wrapper::Transpose2D::template forward( + kernel::wrapper::Transpose2D::template forward( src.data() + thdp.loc[0] * ld_src + thdp.loc[1], dst.data() + thdp.loc[0] + thdp.loc[1] * ld_dst, thdp.size[0], thdp.size[1], ld_src, ld_dst); } @@ -55,13 +56,14 @@ class UT_StdThreading { avector src(row * col), dst(row * col), ref(row * col); fill_buffer_randn(src.data(), src.size(), -0.5f, 0.5f); int ld_src = col, ld_dst = row; - kernel::wrapper::Transpose2D::template forward(src.data(), ref.data(), row, col, col, row); - jblas::parallel::Scheduler2D _para({threads, row, col, 1, 1}); + kernel::wrapper::Transpose2D::template forward(src.data(), ref.data(), row, col, col, + row); + parallel::Scheduler2D _para({threads, row, col, 1, 1}); DefaultThreading.parallel_for([&](int tidx) { - jblas::parallel::ThreadProblem2D thdp{tidx}; + parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); if (thdp.valid) { - kernel::wrapper::Transpose2D::template forward( + kernel::wrapper::Transpose2D::template forward( src.data() + thdp.loc[0] * ld_src + thdp.loc[1], dst.data() + thdp.loc[0] + thdp.loc[1] * ld_dst, thdp.size[0], thdp.size[1], ld_src, ld_dst); } @@ -85,10 +87,10 @@ class UT_Scheduler2D { 
void ut(int row, int col, int threads) { printf("%s %d %d %d\n", __FUNCTION__, row, col, threads); - jblas::parallel::Scheduler2D sch; + parallel::Scheduler2D sch; sch.update({threads, row, col, 1, 1}); sch.print(); - jblas::parallel::ThreadProblem2D prb{threads - 1}; + parallel::ThreadProblem2D prb{threads - 1}; sch.getIndex(prb); prb.print(); } @@ -101,22 +103,22 @@ class UT_SchedulerGemmBase { public: UT_SchedulerGemmBase() { UT_START(); - ut>(2048, 4096, 4096, 48, 2048 * 1024, 32 * 1024); - ut>(1, 4096, 4096, 24); - ut>(2048, 4096, 4096, 24); - ut>(4, 4096, 4096, 48); + ut>(2048, 4096, 4096, 48, 2048 * 1024, 32 * 1024); + ut>(1, 4096, 4096, 24); + ut>(2048, 4096, 4096, 24); + ut>(4, 4096, 4096, 48); } template void ut(int m, int n, int k, int threads, size_t l2cache = 0, size_t l1cache = 0) { - printf("%s %d %d %d %d core:%s\n", __FUNCTION__, m, n, k, threads, jblas::gemm::CoreAttr::to_str(GemmCore_T::ID)); - jblas::parallel::gemm::SchedulerBase sch; + printf("%s %d %d %d %d core:%s\n", __FUNCTION__, m, n, k, threads, gemm::CoreAttr::to_str(GemmCore_T::ID)); + parallel::gemm::SchedulerBase sch; GetCPUDevice(); utils::GemmProblem gp(1, m, n, k); sch.update( {threads, gp, l2cache == 0 ? _cd->getL2CacheSize() : l2cache, l1cache == 0 ? _cd->getL1CacheSize() : l1cache}); sch.print(); - jblas::parallel::gemm::ThreadProblemBase prb{sch.valid_theads() - 1}; + parallel::gemm::ThreadProblemBase prb{sch.valid_theads() - 1}; sch.getIndex(prb); prb.print(); } @@ -129,32 +131,32 @@ class UT_SchedulerGemmKBlock { public: UT_SchedulerGemmKBlock() { UT_START(); - ut>(1, 4096, 4096, 32, 24); - ut>(1, 4096, 4096, 64, 22, 32 * 1024); - ut>(1, 4096, 4096, 128, 24); - ut>(1, 4096, 4096, 1024, 24); - ut>(1, 4096, 4096, 64, 24, 32 * 1024); - ut>(2048, 4096, 4096, 64, 24); - ut>(2048, 4096, 4096, 4096, 24); - ut>(2048, 4096, 4096, 4096, 48); - ut>(2048, 4096, 4096, 4096, 56); - ut>(2048, 4096, 4096, 32, 56); - ut>(4, 4096, 4096, 128, 48); - ut>(4, 4096, 3072, 32, 48); - ut>(2048, 4096, 3072, 3072, 48); - ut>(2048, 4096, 3072, 32, 56); + ut>(1, 4096, 4096, 32, 24); + ut>(1, 4096, 4096, 64, 22, 32 * 1024); + ut>(1, 4096, 4096, 128, 24); + ut>(1, 4096, 4096, 1024, 24); + ut>(1, 4096, 4096, 64, 24, 32 * 1024); + ut>(2048, 4096, 4096, 64, 24); + ut>(2048, 4096, 4096, 4096, 24); + ut>(2048, 4096, 4096, 4096, 48); + ut>(2048, 4096, 4096, 4096, 56); + ut>(2048, 4096, 4096, 32, 56); + ut>(4, 4096, 4096, 128, 48); + ut>(4, 4096, 3072, 32, 48); + ut>(2048, 4096, 3072, 3072, 48); + ut>(2048, 4096, 3072, 32, 56); } template void ut(int m, int n, int k, int kblock, int threads, size_t l1cache = 0) { printf("%s %d %d %d %d %d core:%s\n", __FUNCTION__, m, n, k, kblock, threads, - jblas::gemm::CoreAttr::to_str(GemmCore_T::ID)); - jblas::parallel::gemm::SchedulerKBlock sch; + gemm::CoreAttr::to_str(GemmCore_T::ID)); + parallel::gemm::SchedulerKBlock sch; GetCPUDevice(); utils::GemmProblem gp(1, m, n, k, kblock); sch.update({threads, gp, _cd->getL2CacheSize(), l1cache == 0 ? 
_cd->getL1CacheSize() : l1cache}); sch.print(); - jblas::parallel::gemm::ThreadProblemBase prb{sch.valid_theads() - 1}; + parallel::gemm::ThreadProblemBase prb{sch.valid_theads() - 1}; sch.getIndex(prb); prb.print(); } @@ -167,34 +169,34 @@ class UT_SchedulerGemmKBlockNew { public: UT_SchedulerGemmKBlockNew() { UT_START(); - ut>(2011, 32000, 4096, 128, 32); - ut>(2048, 4096, 4096, 4096, 48); - ut>(2048, 4096, 4096, 128, 24); - ut>(2048, 4096, 4096, 4096, 24); - ut>(1, 4096, 4096, 32, 24); - ut>(1, 4096, 4096, 64, 22, 32 * 1024); - ut>(1, 4096, 4096, 128, 24); - ut>(1, 4096, 4096, 1024, 24); - ut>(1, 4096, 4096, 64, 24, 32 * 1024); - ut>(2048, 4096, 4096, 64, 24); - ut>(2048, 4096, 4096, 4096, 56); - ut>(2048, 4096, 4096, 32, 56); - ut>(4, 4096, 4096, 128, 48); - ut>(4, 4096, 3072, 32, 48); - ut>(2048, 4096, 3072, 3072, 48); - ut>(2048, 4096, 3072, 32, 56); + ut>(2011, 32000, 4096, 128, 32); + ut>(2048, 4096, 4096, 4096, 48); + ut>(2048, 4096, 4096, 128, 24); + ut>(2048, 4096, 4096, 4096, 24); + ut>(1, 4096, 4096, 32, 24); + ut>(1, 4096, 4096, 64, 22, 32 * 1024); + ut>(1, 4096, 4096, 128, 24); + ut>(1, 4096, 4096, 1024, 24); + ut>(1, 4096, 4096, 64, 24, 32 * 1024); + ut>(2048, 4096, 4096, 64, 24); + ut>(2048, 4096, 4096, 4096, 56); + ut>(2048, 4096, 4096, 32, 56); + ut>(4, 4096, 4096, 128, 48); + ut>(4, 4096, 3072, 32, 48); + ut>(2048, 4096, 3072, 3072, 48); + ut>(2048, 4096, 3072, 32, 56); } template void ut(int m, int n, int k, int kblock, int threads, size_t l1cache = 0) { printf("%s %d %d %d %d %d core:%s\n", __FUNCTION__, m, n, k, kblock, threads, - jblas::gemm::CoreAttr::to_str(GemmCore_T::ID)); - jblas::parallel::gemm::SchedulerKBlockS sch; + gemm::CoreAttr::to_str(GemmCore_T::ID)); + parallel::gemm::SchedulerKBlockS sch; GetCPUDevice(); utils::GemmProblem gp(1, m, n, k, kblock); sch.update({threads, gp, _cd->getL2CacheSize(), l1cache == 0 ? 
_cd->getL1CacheSize() : l1cache}); sch.print(); - jblas::parallel::gemm::ThreadProblemBase prb{sch.valid_theads() - 1}; + parallel::gemm::ThreadProblemBase prb{sch.valid_theads() - 1}; sch.getIndex(prb); prb.print(); } @@ -203,4 +205,4 @@ class UT_SchedulerGemmKBlockNew { static UT_SchedulerGemmKBlockNew sUT_SchedulerGemmKBlockNew; #endif } // namespace ut -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/ut/jit_blas_prologue_a.cpp b/bestla/bestla/ut/bestla_prologue_a.cpp similarity index 90% rename from bestla/jblas/ut/jit_blas_prologue_a.cpp rename to bestla/bestla/ut/bestla_prologue_a.cpp index 4af84f0ee..7cb4b6379 100644 --- a/bestla/jblas/ut/jit_blas_prologue_a.cpp +++ b/bestla/bestla/ut/bestla_prologue_a.cpp @@ -1,8 +1,8 @@ -#include "../jit_blas_prologue_a.h" -#include "jit_blas_ut.h" -#include "../kernel_avx512f.h" +#include "bestla_prologue_a.h" +#include "bestla_ut.h" +#include "kernel_avx512f.h" -namespace jblas { +namespace bestla { using namespace utils; namespace ut { class UT_ActivationBase { @@ -24,8 +24,8 @@ class UT_ActivationBase { for (int i = 0; i < src.size(); i++) { src[i] = static_cast(i); } - jblas::prologue_a::gemm::ActivationBase<_T, JblasNoSIMD> reorderref; - jblas::prologue_a::gemm::ActivationBase<_T, JblasAVX512F> reorderavx512; + prologue_a::gemm::ActivationBase<_T, BTLA_ISA::NoSIMD> reorderref; + prologue_a::gemm::ActivationBase<_T, BTLA_ISA::AVX512F> reorderavx512; auto dstrefptr = dstref.data(); auto dstptr = dst.data(); int dststride = 0; @@ -37,7 +37,7 @@ class UT_ActivationBase { } } }; -#ifdef JBLAS_UT_PROLOGUE_A +#ifdef BTLA_UT_PROLOGUE_A static UT_ActivationBase sUT_ActivationBase; #endif @@ -63,15 +63,15 @@ class UT_ActivationConverter { for (int i = 0; i < src.size(); i++) { src[i] = static_cast(float(i)); } - jblas::prologue_a::gemm::ActivationConverter<_T, JblasNoSIMD, SRC_T> reorderref; - jblas::prologue_a::gemm::ActivationConverter<_T, JblasAVX512F, SRC_T> reorderavx512; + prologue_a::gemm::ActivationConverter<_T, BTLA_ISA::NoSIMD, SRC_T> reorderref; + prologue_a::gemm::ActivationConverter<_T, BTLA_ISA::AVX512F, SRC_T> reorderavx512; auto dstptr = dstref.data(); int dststride = 0; auto ret = reorderref.getActivation(&dstptr, &dststride, {src.data(), lda}, m, k, 0, 0, cache, CacheSize); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); dstptr = dst.data(); ret = reorderavx512.getActivation(&dstptr, &dststride, {src.data(), lda}, m, k, 0, 0, cache, CacheSize); - assert(ret == JblasSuccess); + assert(ret == BTLA_CODE::Success); ut::buffer_error(dst.data(), dstref.data(), dst.size(), AType{0}); aligned_vector revert(dst.size()); for (size_t i = 0; i < revert.size(); i++) { @@ -84,7 +84,7 @@ class UT_ActivationConverter { buffer_error_2d(src.data(), revert.data(), m, k, lda, kpad); } }; -#ifdef JBLAS_UT_PROLOGUE_A +#ifdef BTLA_UT_PROLOGUE_A static UT_ActivationConverter sUT_ActivationConverter; #endif @@ -114,8 +114,8 @@ class UT_ActivationU8KBlockQuantize { template void ut(int m, int k, int lda, int kblock, bool hasreduce = false) { int kpad = padto(k, _T::KTILE); - printf("Test Case core:%s: %d %d %d %d %d reduce:%d\n", jblas::gemm::CoreAttr::to_str(_T::ID), m, k, lda, kblock, - kpad, hasreduce); + printf("Test Case core:%s: %d %d %d %d %d reduce:%d\n", gemm::CoreAttr::to_str(_T::ID), m, k, lda, kblock, kpad, + hasreduce); int kcount = updiv(kpad, kblock); utils::aligned_vector raw(m * lda), scales(m * kcount); ut::fill_buffer_randn(raw.data(), raw.size(), -0.5f, 0.5f); @@ -126,7 +126,7 @@ class 
UT_ActivationU8KBlockQuantize { kernel::ref::quantize_fp_u8_colblock(m, k, raw.data(), lda, q.data(), lda, scales.data(), kcount, zp.data(), kblock, hasreduce ? reduce.data() : nullptr); - jblas::prologue_a::gemm::ActivationF32KBlockQuantize<_T, _T::ISA> actA; + prologue_a::gemm::ActivationF32KBlockQuantize<_T, _T::ISA> actA; auto quanAct = actA.createStorage(m, k, kblock, hasreduce); avector bufA(quanAct.mSize); quanAct.assign(bufA.data()); @@ -149,7 +149,7 @@ class UT_ActivationU8KBlockQuantize { } } }; -#ifdef JBLAS_UT_PROLOGUE_A +#ifdef BTLA_UT_PROLOGUE_A static UT_ActivationU8KBlockQuantize sUT_ActivationU8KBlockQuantize; #endif @@ -181,7 +181,7 @@ class UT_ActivationS8KBlockQuantize { q.resize(m * lda); kernel::ref::quantize_fp_s8_colblock(m, k, raw.data(), k, q.data(), lda, scales.data(), kcount, kblock, hasreduce ? reduce.data() : nullptr); - jblas::prologue_a::gemm::ActivationF32KBlockQuantize<_T, JblasAVX512F> actA; + prologue_a::gemm::ActivationF32KBlockQuantize<_T, BTLA_ISA::AVX512F> actA; auto quanAct = actA.createStorage(m, k, kblock, hasreduce); avector bufA(quanAct.mSize); quanAct.assign(bufA.data()); @@ -201,7 +201,7 @@ class UT_ActivationS8KBlockQuantize { } } }; -#ifdef JBLAS_UT_PROLOGUE_A +#ifdef BTLA_UT_PROLOGUE_A static UT_ActivationS8KBlockQuantize sUT_ActivationS8KBlockQuantize; #endif @@ -227,7 +227,7 @@ class UT_ShuffleActivationKblock { indices[i] = i % 2 == 0 ? (i + 1) == indices.size() ? i : i + 1 : i - 1; } for (int i = 0; i < src.size(); i++) src[i] = static_cast<_SRC_T>(i); - jblas::prologue_a::gemm::ShuffleActivationKBlockBase kernel; + prologue_a::gemm::ShuffleActivationKBlockBase kernel; auto dstrefptr = dstref.data(); auto dstptr = dst.data(); int dststride = 0; @@ -236,7 +236,7 @@ class UT_ShuffleActivationKblock { reordA.assign(bufA.data()); kernel.preprocess({src.data(), k, nullptr, indices.data(), &reordA}, m, k, 32, &DefaultThreading); - kernel.getActivation(&dstptr, &dststride, {src.data(), k, nullptr, indices.data(), &reordA}, m, kpad, 0, 0, cache, + kernel.getActivation(&dstptr, &dststride, {src.data(), k, nullptr, indices.data(), &reordA}, m, kpad, 0, 0, cache, CacheSize); for (int i = 0; i < m; i++) { int j = 0; @@ -265,7 +265,7 @@ class UT_ShuffleActivationKblock { kernel::ref::shuffle_activation(raw_cp.data(), raw.data(), m, k, 0, 0, indices.data(), k, k); kernel::ref::quantize_fp_s8_colblock(m, k, raw.data(), k, q.data(), lda, scales.data(), kcount, kblock, hasreduce ? 
reduce.data() : nullptr); - jblas::prologue_a::gemm::ShuffleActivationKBlockQuantize actA; + prologue_a::gemm::ShuffleActivationKBlockQuantize actA; auto quanAct = actA.createQuantStorage(m, k, kblock, hasreduce); auto reordAct = actA.createReorderStorage(m, k, kblock); avector bufA(quanAct.mSize + reordAct.mSize); @@ -287,8 +287,8 @@ class UT_ShuffleActivationKblock { } } }; -#ifdef JBLAS_UT_PROLOGUE_A +#ifdef BTLA_UT_PROLOGUE_A static UT_ShuffleActivationKblock sUT_ShuffleActivationKblock; #endif } // namespace ut -} // namespace jblas +} // namespace bestla diff --git a/bestla/jblas/ut/jit_blas_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp similarity index 54% rename from bestla/jblas/ut/jit_blas_prologue_b.cpp rename to bestla/bestla/ut/bestla_prologue_b.cpp index 8d1863165..3280e4e4f 100644 --- a/bestla/jblas/ut/jit_blas_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -1,13 +1,12 @@ -#include "../jit_blas_gemm.h" -#include "../jit_blas_prologue_b.h" -#include "../jit_blas_parallel.h" -#include "../jit_blas_device.h" -#include "../jit_blas_wrapper.h" -#include "jit_blas_ut.h" - -namespace jblas { +#include "bestla_gemm.h" +#include "bestla_prologue_b.h" +#include "bestla_parallel.h" +#include "bestla_device.h" +#include "bestla_wrapper.h" +#include "bestla_ut.h" + +namespace bestla { using namespace utils; -using namespace parallel; namespace ut { class UT_BlockQunatize_INT8 { public: @@ -63,10 +62,10 @@ class UT_BlockQunatize_INT8 { } } - auto constexpr RuntimeISA = JblasAVX512F; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNInteger, RuntimeISA>; + auto constexpr RuntimeISA = BTLA_ISA::AVX512F; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger, RuntimeISA>; PrologueB kernel; - auto ptr = kernel.createStorage(n, k, blocksize, JBLAS_DTYPE::S8, jblas_dtype, jblas_dtype, asym); + auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S8, bestla_dtype, bestla_dtype, asym); avector buffer(ptr.mSize); ptr.assign(buffer.data()); kernel.packWeight(n, k, dequanRef.data(), ldb, &ptr, &DefaultThreading); @@ -113,10 +112,10 @@ class UT_BlockQunatize_INT8 { } } - auto constexpr RuntimeISA = JblasAVX512F; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNInteger, RuntimeISA>; + auto constexpr RuntimeISA = BTLA_ISA::AVX512F; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger, RuntimeISA>; PrologueB kernel; - auto ptr = kernel.createStorage(n, k, blocksize, JBLAS_DTYPE::S8, jblas_dtype, jblas_dtype, asym); + auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S8, bestla_dtype, bestla_dtype, asym); avector buffer(ptr.mSize); ptr.assign(buffer.data()); kernel.packTransposeWeight(n, k, dequanT.data(), k, &ptr, &DefaultThreading); @@ -130,7 +129,7 @@ class UT_BlockQunatize_INT8 { ut::buffer_error(dequanRef.data(), dequant.data(), dequanRef.size(), 0.01f); } }; -#ifdef JBLAS_UT_PROLOGUE_B +#ifdef BTLA_UT_PROLOGUE_B static UT_BlockQunatize_INT8 sUT_BlockQunatize_INT8; #endif @@ -139,23 +138,23 @@ class UT_BlockQunatize_F8 { UT_BlockQunatize_F8() { UT_START(); CheckISA(AVX512F); - ut(127, 1023, 32, JBLAS_DTYPE::F8_E4M3); - ut(127, 1023, 32, JBLAS_DTYPE::F8_E5M2); + ut(127, 1023, 32, BTLA_DTYPE::F8_E4M3); + ut(127, 1023, 32, BTLA_DTYPE::F8_E5M2); } - void ut(int n, int k, int blocksize, JBLAS_DTYPE QUANT_T) { + void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); int ldb = n; utils::aligned_vector raw(n * k); ut::fill_buffer_randn(raw.data(), raw.size(), -3.f, 3.f); - 
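The shuffle tests above validate `kernel::ref::shuffle_activation`, which gathers activation columns through an index table before the usual reorder or quantize path (the test builds pair-swapped indices). A minimal standalone sketch of that gather; the signature is illustrative only:

```cpp
// Gather activation columns by index before packing/quantizing.
// e.g. indices = {1, 0, 3, 2, ...} reproduces the pair-swap pattern in the test above.
void shuffle_columns(const float* src, float* dst, int m, int k, const int* indices, int ld) {
  for (int i = 0; i < m; i++)
    for (int j = 0; j < k; j++)
      dst[i * ld + j] = src[i * ld + indices[j]];
}
```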
auto constexpr RuntimeISA = JblasAVX512F; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNFloat, RuntimeISA>; - using refPorB = jblas::prologue_b::gemm::WeightKBlockNFloat, JblasNoSIMD>; + auto constexpr RuntimeISA = BTLA_ISA::AVX512F; + using PrologueB = prologue_b::gemm::WeightKBlockNFloat, RuntimeISA>; + using refPorB = prologue_b::gemm::WeightKBlockNFloat, BTLA_ISA::NoSIMD>; PrologueB kernel; refPorB ref_ker; - auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, JBLAS_DTYPE::F8_E8M0); - auto ref_ptr = kernel.createStorage(n, k, blocksize, QUANT_T, JBLAS_DTYPE::F8_E8M0); + auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::F8_E8M0); + auto ref_ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::F8_E8M0); avector buffer(ptr.mSize); avector ref_buffer(ptr.mSize); ptr.assign(buffer.data()); @@ -169,7 +168,7 @@ class UT_BlockQunatize_F8 { ut::buffer_error(ref_dequant.data(), dequant.data(), dequant.size(), 0.01f); } }; -#ifdef JBLAS_UT_PROLOGUE_B +#ifdef BTLA_UT_PROLOGUE_B static UT_BlockQunatize_F8 sUT_BlockQunatize_F8; #endif @@ -178,17 +177,17 @@ class UT_TransposeBlockQuantize_F4 { UT_TransposeBlockQuantize_F4() { UT_START(); CheckISA(AVX512F); - ut(4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - ut(1024, 4096, 32, JBLAS_DTYPE::F4_BNB); - ut(4096, 1024, 32, JBLAS_DTYPE::F4_BNB); - ut(48, 32, 32, JBLAS_DTYPE::F4_BNB); - ut(32, 32, 32, JBLAS_DTYPE::F4_BNB); - ut(48, 32, 32, JBLAS_DTYPE::F4_BNB); - ut(48, 32, 32, JBLAS_DTYPE::F4_NF4); - ut(48, 32, 32, JBLAS_DTYPE::F4_E2M1); + ut(4096, 4096, 32, BTLA_DTYPE::F4_BNB); + ut(1024, 4096, 32, BTLA_DTYPE::F4_BNB); + ut(4096, 1024, 32, BTLA_DTYPE::F4_BNB); + ut(48, 32, 32, BTLA_DTYPE::F4_BNB); + ut(32, 32, 32, BTLA_DTYPE::F4_BNB); + ut(48, 32, 32, BTLA_DTYPE::F4_BNB); + ut(48, 32, 32, BTLA_DTYPE::F4_NF4); + ut(48, 32, 32, BTLA_DTYPE::F4_E2M1); } - void ut(int n, int k, int blocksize, JBLAS_DTYPE F4_T) { + void ut(int n, int k, int blocksize, BTLA_DTYPE F4_T) { printf("Test Case: %d %d %d\n", n, k, blocksize); int ldb = n; utils::aligned_vector dequanRef(n * k); @@ -202,13 +201,13 @@ class UT_TransposeBlockQuantize_F4 { for (int j = 0; j < n; j++) { if (i % blocksize == 0) { switch (F4_T) { - case JBLAS_DTYPE::F4_E2M1: + case BTLA_DTYPE::F4_E2M1: quanW.data()[i * n + j] = 7; // make sure each block has maximum fp4e2m1 value(0b111) to quantize break; - case JBLAS_DTYPE::F4_BNB: + case BTLA_DTYPE::F4_BNB: quanW.data()[i * n + j] = 3; // make sure each block has maximum fp4bnb value(0b011) to quantize break; - case JBLAS_DTYPE::F4_NF4: + case BTLA_DTYPE::F4_NF4: quanW.data()[i * n + j] = 15; // make sure each block has maximum nf4 value(0b1111) to quantize break; default: @@ -220,23 +219,23 @@ class UT_TransposeBlockQuantize_F4 { for (int j = 0; j < k; j++) { for (int i = 0; i < n; i++) { switch (F4_T) { - case JBLAS_DTYPE::F4_E2M1: - dequanRef[j + i * k] = jblas::kernel::ref::f4_dequantize( - quanW.data()[j * ldb + i], scales[j / blocksize * n + i]); - quanW.data()[j * ldb + i] = jblas::kernel::ref::f4_quantize( - dequanRef[j + i * k] / scales[j / blocksize * n + i]); + case BTLA_DTYPE::F4_E2M1: + dequanRef[j + i * k] = kernel::ref::f4_dequantize(quanW.data()[j * ldb + i], + scales[j / blocksize * n + i]); + quanW.data()[j * ldb + i] = + kernel::ref::f4_quantize(dequanRef[j + i * k] / scales[j / blocksize * n + i]); break; - case JBLAS_DTYPE::F4_BNB: - dequanRef[j + i * k] = jblas::kernel::ref::f4_dequantize( - quanW.data()[j * ldb + i], scales[j / blocksize * n + i]); - quanW.data()[j * ldb + i] = 
jblas::kernel::ref::f4_quantize( - dequanRef[j + i * k] / scales[j / blocksize * n + i]); + case BTLA_DTYPE::F4_BNB: + dequanRef[j + i * k] = kernel::ref::f4_dequantize(quanW.data()[j * ldb + i], + scales[j / blocksize * n + i]); + quanW.data()[j * ldb + i] = + kernel::ref::f4_quantize(dequanRef[j + i * k] / scales[j / blocksize * n + i]); break; - case JBLAS_DTYPE::F4_NF4: - dequanRef[j + i * k] = jblas::kernel::ref::f4_dequantize( - quanW.data()[j * ldb + i], scales[j / blocksize * n + i]); - quanW.data()[j * ldb + i] = jblas::kernel::ref::f4_quantize( - dequanRef[j + i * k] / scales[j / blocksize * n + i]); + case BTLA_DTYPE::F4_NF4: + dequanRef[j + i * k] = kernel::ref::f4_dequantize(quanW.data()[j * ldb + i], + scales[j / blocksize * n + i]); + quanW.data()[j * ldb + i] = + kernel::ref::f4_quantize(dequanRef[j + i * k] / scales[j / blocksize * n + i]); break; default: break; @@ -244,11 +243,11 @@ class UT_TransposeBlockQuantize_F4 { } } - auto constexpr RuntimeISA = JblasAVX512F; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNFloat, RuntimeISA>; + auto constexpr RuntimeISA = BTLA_ISA::AVX512F; + using PrologueB = prologue_b::gemm::WeightKBlockNFloat, RuntimeISA>; PrologueB kernel; - auto packedW = kernel.createStorage(n, k, blocksize, F4_T, jblas_dtype); - auto packedW1 = kernel.createStorage(n, k, blocksize, F4_T, jblas_dtype); + auto packedW = kernel.createStorage(n, k, blocksize, F4_T, bestla_dtype); + auto packedW1 = kernel.createStorage(n, k, blocksize, F4_T, bestla_dtype); avector buf(packedW.mSize), buf1(packedW1.mSize); packedW.assign(buf.data()); packedW1.assign(buf1.data()); @@ -261,7 +260,7 @@ class UT_TransposeBlockQuantize_F4 { ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size()); } }; -#ifdef JBLAS_UT_PROLOGUE_B +#ifdef BTLA_UT_PROLOGUE_B static UT_TransposeBlockQuantize_F4 sUT_TransposeBlockQuantize_F4; #endif @@ -271,14 +270,14 @@ class UT_BlockQuantize_INT4 { UT_START(); CheckISA(AVX2); CheckISA(AVX512F); - ut_2(4096, 4096, 128, JBLAS_DTYPE::S4_CLIP, false); - ut_2(4096, 4096, 128, JBLAS_DTYPE::S4_FULLRANGE, false); + ut_2(4096, 4096, 128, BTLA_DTYPE::S4_CLIP, false); + ut_2(4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, false); CheckISA(AVX512F); - ut_512vnni(4096, 4096, 128, JBLAS_DTYPE::S4_CLIP, false); - ut_512vnni(4096, 4096, 128, JBLAS_DTYPE::S4_CLIP, true); - ut_512vnni(4096, 4096, 128, JBLAS_DTYPE::S4_FULLRANGE, false); + ut_512vnni(4096, 4096, 128, BTLA_DTYPE::S4_CLIP, false); + ut_512vnni(4096, 4096, 128, BTLA_DTYPE::S4_CLIP, true); + ut_512vnni(4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, false); } - void ut_2(int n, int k, int blocksize, JBLAS_DTYPE qtype, bool asym = false) { + void ut_2(int n, int k, int blocksize, BTLA_DTYPE qtype, bool asym = false) { printf("Test Case: %d %d %d %s\n", n, k, blocksize, asym ? 
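The switch above seeds each block with the largest encodable code of the chosen 4-bit float family, so quantization is exercised at full scale. The magic numbers restated in one place (the helper is ours; the codes come straight from the comments in the hunk):

```cpp
#include <cstdint>

enum class F4Family { E2M1, BNB, NF4 };

uint8_t f4_max_code(F4Family f) {
  switch (f) {
    case F4Family::E2M1: return 7;   // 0b111:  max fp4 e2m1 code
    case F4Family::BNB:  return 3;   // 0b011:  max fp4 bnb code
    case F4Family::NF4:  return 15;  // 0b1111: max nf4 code
  }
  return 0;
}
```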
"asym" : "sym"); int ldb = n; int kblk_num = utils::updiv(k, blocksize); @@ -304,13 +303,13 @@ class UT_BlockQuantize_INT4 { } } - auto constexpr RuntimeISA = JblasAVX2; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNInteger, JblasAVX2>; - using PrologueB512 = jblas::prologue_b::gemm::WeightKBlockNInteger, JblasAVX512F>; + auto constexpr RuntimeISA = BTLA_ISA::AVX2; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger, BTLA_ISA::AVX2>; + using PrologueB512 = prologue_b::gemm::WeightKBlockNInteger, BTLA_ISA::AVX512F>; PrologueB kernel; PrologueB512 kernel512; utils::aligned_vector retW(n * k); - auto packedW = kernel.createStorage(n, k, blocksize, qtype, jblas_dtype, jblas_dtype, asym); + auto packedW = kernel.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, asym); avector buffer(packedW.mSize); packedW.assign(buffer.data()); kernel.packWeight(n, k, dequant.data(), ldb, &packedW, &DefaultThreading); @@ -320,7 +319,7 @@ class UT_BlockQuantize_INT4 { kernel512.unpackWeight(n, k, &packedW, unpack512f32.data(), n, &DefaultThreading); ut::buffer_error(unpackf32.data(), unpack512f32.data(), unpackf32.size(), 0.01f); } - void ut_512vnni(int n, int k, int blocksize, JBLAS_DTYPE qtype, bool asym = false) { + void ut_512vnni(int n, int k, int blocksize, BTLA_DTYPE qtype, bool asym = false) { printf("Test Case: %d %d %d %s\n", n, k, blocksize, asym ? "asym" : "sym"); int ldb = n; int kblk_num = utils::updiv(k, blocksize); @@ -347,13 +346,12 @@ class UT_BlockQuantize_INT4 { } } - auto constexpr RuntimeISA = JblasAVX512_VNNI; - using PrologueB = - jblas::prologue_b::gemm::WeightKBlockNInteger, RuntimeISA>; + auto constexpr RuntimeISA = BTLA_ISA::AVX512_VNNI; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger, RuntimeISA>; PrologueB kernel; utils::aligned_vector retW(n * k); - auto packedW = kernel.createStorage(n, k, blocksize, qtype, jblas_dtype, jblas_dtype, asym); + auto packedW = kernel.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, asym); avector buffer(packedW.mSize); packedW.assign(buffer.data()); kernel.packWeight(n, k, dequant.data(), ldb, &packedW, &DefaultThreading); @@ -364,7 +362,7 @@ class UT_BlockQuantize_INT4 { ut::buffer_error(dequant.data(), unpackf32.data(), dequant.size(), err_thres); } }; -#ifdef JBLAS_UT_PROLOGUE_B +#ifdef BTLA_UT_PROLOGUE_B static UT_BlockQuantize_INT4 sUT_BlockQuantize_INT4; #endif @@ -373,21 +371,21 @@ class UT_StorageMemCheck { UT_StorageMemCheck() { UT_START(); CheckISA(AVX512F); - ut_s4(4096, 4096, 128, JBLAS_DTYPE::S4_CLIP); - ut_s4(4096, 4096, 32, JBLAS_DTYPE::S4_FULLRANGE, true); - ut_f4(4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - ut_f4(4096, 4096, 32, JBLAS_DTYPE::F4_E2M1); + ut_s4(4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + ut_s4(4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, true); + ut_f4(4096, 4096, 32, BTLA_DTYPE::F4_BNB); + ut_f4(4096, 4096, 32, BTLA_DTYPE::F4_E2M1); } - void ut_s4(int n, int k, int blocksize, JBLAS_DTYPE qtype, bool asym = false) { + void ut_s4(int n, int k, int blocksize, BTLA_DTYPE qtype, bool asym = false) { printf("Test C type Case: %d %d %d %s\n", n, k, blocksize, asym ? 
"asym" : "sym"); int ldb = n; int kblk_num = utils::updiv(k, blocksize); - using GemmCore = jblas::gemm::SCoreRowNAvx512f<48, 8>; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNInteger; + using GemmCore = gemm::SCoreRowNAvx512f<48, 8>; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; PrologueB ProWei; - auto packedW = ProWei.createStorage(n, k, blocksize, qtype, jblas_dtype, jblas_dtype, asym); + auto packedW = ProWei.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, asym); avector buf0(packedW.mSize), buf1(packedW.mSize); packedW.assign(buf0.data()); storage::gemm::StorageWeightKBlockNInteger tmp(GemmCore::ID); @@ -396,15 +394,15 @@ class UT_StorageMemCheck { buffer_error(buf0.data(), buf1.data(), buf0.size()); } - void ut_f4(int n, int k, int blocksize, JBLAS_DTYPE qtype) { + void ut_f4(int n, int k, int blocksize, BTLA_DTYPE qtype) { printf("Test C type Case: %d %d %d\n", n, k, blocksize); int ldb = n; int kblk_num = utils::updiv(k, blocksize); - using GemmCore = jblas::gemm::HCoreRowNAmxbf16<64, 16>; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNFloat; + using GemmCore = gemm::HCoreRowNAmxbf16<64, 16>; + using PrologueB = prologue_b::gemm::WeightKBlockNFloat; PrologueB ProWei; - auto packedW = ProWei.createStorage(n, k, blocksize, qtype, jblas_dtype); + auto packedW = ProWei.createStorage(n, k, blocksize, qtype, bestla_dtype); avector buf0(packedW.mSize), buf1(packedW.mSize); packedW.assign(buf0.data()); storage::gemm::StorageWeightKBlockNFloat tmp(GemmCore::ID); @@ -413,7 +411,7 @@ class UT_StorageMemCheck { buffer_error(buf0.data(), buf1.data(), buf0.size()); } }; -#ifdef JBLAS_UT_PROLOGUE_B +#ifdef BTLA_UT_PROLOGUE_B static UT_StorageMemCheck sUT_StorageMemCheck; #endif @@ -423,18 +421,18 @@ class UT_ShuffleIndices { UT_START(); CheckISA(AVX2); // ut_file(); - ut_s4(4096, 4096, 32, JBLAS_DTYPE::S4_FULLRANGE, true); - ut_s4(4096, 4096, 128, JBLAS_DTYPE::S4_CLIP); + ut_s4(4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, true); + ut_s4(4096, 4096, 128, BTLA_DTYPE::S4_CLIP); } - void ut_s4(int n, int k, int blocksize, JBLAS_DTYPE qtype, bool asym = false) { + void ut_s4(int n, int k, int blocksize, BTLA_DTYPE qtype, bool asym = false) { printf("Test C type Case: %d %d %d %s\n", n, k, blocksize, asym ? 
"asym" : "sym"); int ldb = n; int kblk_num = utils::updiv(k, blocksize); - using GemmCore = jblas::gemm::SCoreRowNAvx2<24, 4>; - using PrologueB = jblas::prologue_b::gemm::WeightKBlockNInteger; + using GemmCore = gemm::SCoreRowNAvx2<24, 4>; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; PrologueB ProWei; - auto packedW = ProWei.createStorage(n, k, blocksize, qtype, jblas_dtype, jblas_dtype, asym); + auto packedW = ProWei.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, asym); ProWei.enableShuffle(&packedW); avector groupindices(k, 0); auto groupsize = utils::updiv(k, blocksize); @@ -461,22 +459,21 @@ class UT_ShuffleIndices { int k = 4096; int blocksize = 32; bool constexpr blauncher = false; - auto qtype = JBLAS_DTYPE::S4_CLIP; + auto qtype = BTLA_DTYPE::S4_CLIP; bool asym = true; auto warray = ut::readFile2Buffer("src0_data.bin"); auto aarray = ut::readFile2Buffer("src1_data.bin"); auto oarray = ut::readFile2Buffer("tensor_data.bin"); auto refoarray = ut::readFile2Buffer("tensor_data_ref.bin"); auto wptr = storage::gemm::PackedWeightParser::deserialBuffer(warray.data()); - using GemmCore = jblas::gemm::SCoreRowNAvx512f<48, 8>; + using GemmCore = gemm::SCoreRowNAvx512f<48, 8>; auto wptr_ = reinterpret_cast(wptr); utils::GemmProblem gp(1, m, n, k, blocksize); avector output(m * n); if constexpr (blauncher) { using Launcher = - wrapper::gemm::LauncherBase; + wrapper::gemm::LauncherBase; static Launcher kernel; auto rordA = kernel.mProA.createReorderStorage(m, k, blocksize); avector bufA(rordA.mSize); @@ -486,10 +483,10 @@ class UT_ShuffleIndices { parallel::GemmRunWithA>(kernel, args, &DefaultThreading); } else { - using Launcher = wrapper::gemm::LauncherKBlock< - GemmCore::ISA, GemmCore, jblas::prologue_a::gemm::ShuffleActivationKBlockBaseF32, - jblas::prologue_b::gemm::WeightKBlockNInteger, jblas::epilogue::gemm::CompFp32BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + using Launcher = + wrapper::gemm::LauncherKBlock; static Launcher kernel; auto rordA = kernel.mProA.createReorderStorage(m, k, blocksize); auto redA = kernel.mProA.createReduceStorage(m, k, blocksize); @@ -510,7 +507,7 @@ class UT_ShuffleIndices { delete wptr; } }; -#ifdef JBLAS_UT_PROLOGUE_B +#ifdef BTLA_UT_PROLOGUE_B static UT_ShuffleIndices sUT_ShuffleIndices; #endif @@ -526,119 +523,112 @@ class UT_CompFp32 { void ut_f8() { CheckISA(AVX2); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E5M2); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E5M2); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); CheckISA(AVX512F); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E5M2); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F8_E5M2); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); } void ut_s4() { CheckISA(AVX2); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 128, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, -1, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 128, 
JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, -1, JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::BF16, false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::BF16, false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::BF16, + false); CheckISA(AVX512F); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 128, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, -1, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 128, JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, -1, JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_CLIP, - JBLAS_DTYPE::BF16, false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S4_FULLRANGE, - JBLAS_DTYPE::BF16, false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, + BTLA_DTYPE::F32, false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, + BTLA_DTYPE::F32, false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE, + BTLA_DTYPE::F32, false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, + BTLA_DTYPE::BF16, false); } void ut_s8() { CheckISA(AVX2); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S8, JBLAS_DTYPE::BF16, - false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S8, JBLAS_DTYPE::F32, - false); - ut_int(2, 4096, 4096, 128, JBLAS_DTYPE::S8, JBLAS_DTYPE::F32, - false); - ut_int(2, 4096, 4096, -1, JBLAS_DTYPE::S8, JBLAS_DTYPE::F32, - false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S8, BTLA_DTYPE::BF16, false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S8, BTLA_DTYPE::F32, false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S8, BTLA_DTYPE::F32, false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S8, BTLA_DTYPE::F32, false); CheckISA(AVX512F); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S8, - JBLAS_DTYPE::BF16, false); - ut_int(2, 4096, 4096, 32, JBLAS_DTYPE::S8, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, 128, JBLAS_DTYPE::S8, - JBLAS_DTYPE::F32, false); - ut_int(2, 4096, 4096, -1, JBLAS_DTYPE::S8, - JBLAS_DTYPE::F32, false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S8, BTLA_DTYPE::BF16, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S8, BTLA_DTYPE::F32, false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S8, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S8, BTLA_DTYPE::F32, false); } void ut_f4() { CheckISA(AVX2); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - ut(2, 4096, 4096, -1, 
JBLAS_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_E2M1); - ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_NF4); - ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_NF4); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_NF4); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); CheckISA(AVX512F); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_E2M1); - ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_NF4); - ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_NF4); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_NF4); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); } - template class Wei> - void ut_int(int m, int n, int k, int blocksize, JBLAS_DTYPE qtype, JBLAS_DTYPE stype, bool isAsym) { + template class Wei> + void ut_int(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE stype, bool isAsym) { printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s\n", __FUNCTION__, m, n, k, blocksize, - jblas_dtype_str(qtype), jblas::gemm::CoreAttr::to_str(GemmCore_T::ID), jblas_dtype_str(stype)); + bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), bestla_dtype_str(stype)); auto constexpr ISA = GemmCore_T::ISA; - using Launcher = jblas::wrapper::gemm::LauncherKBlock< - ISA, GemmCore_T, prologue_a::gemm::ActivationKBlockBaseF32, jblas::prologue_b::gemm::WeightKBlockNInteger, - jblas::epilogue::gemm::CompFp32BlockEpilogue, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - using Parallel = jblas::parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; Launcher launcher; blocksize = blocksize == -1 ? 
k : blocksize; using WType = typename Wei::StorageWeight; WType packedw(0); - if constexpr (std::is_same_v, - jblas::prologue_b::gemm::WeightKBlockNInteger>) { - packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype, jblas_dtype, isAsym); - } else if constexpr (std::is_same_v, - jblas::prologue_b::gemm::WeightKBlockNFloat>) { - packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype, JBLAS_DTYPE::EleBitsUndef, false); + if constexpr (std::is_same_v, prologue_b::gemm::WeightKBlockNInteger>) { + packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype, bestla_dtype, isAsym); + } else if constexpr (std::is_same_v, prologue_b::gemm::WeightKBlockNFloat>) { + packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype); } utils::avector buffer(packedw.mSize); @@ -656,26 +646,38 @@ class UT_CompFp32 { {&packedw}, {packedw.template SPtr(), packedw.SDtype(), packedw.CStep()}, {matC.data(), n}}; - GemmRun(launcher, args, &DefaultThreading); - auto err = get_ut_err(qtype); + parallel::GemmRun(launcher, args, &DefaultThreading); + auto err = INT8_ERR; + auto dbits = bestla_dtype_bits(qtype); + auto type = bestla_dtype_type(qtype); + auto constexpr dtype_int = bestla_dtype_type(BTLA_DTYPE::TypeInt); + if (type == dtype_int) { + if (dbits == 8) { + err = INT8_ERR; + } else { + err = INT4_ERR; + } + } else { + err = FP4_ERR; + } buffer_error(refC.data(), matC.data(), refC.size(), err); buffer_error(refCupk.data(), matC.data(), refCupk.size(), 0.001f); } - template class Wei, typename Scale_T> - void ut(int m, int n, int k, int blocksize, JBLAS_DTYPE qtype) { + template class Wei, typename Scale_T> + void ut(int m, int n, int k, int blocksize, BTLA_DTYPE qtype) { printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s\n", __FUNCTION__, m, n, k, blocksize, - jblas_dtype_str(qtype), jblas::gemm::CoreAttr::to_str(GemmCore_T::ID), type_str); + bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), type_str); auto constexpr ISA = GemmCore_T::ISA; - using Launcher = jblas::wrapper::gemm::LauncherKBlock; - using Parallel = jblas::parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; Launcher launcher; blocksize = blocksize == -1 ? 
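The error-threshold selection that replaces `get_ut_err` above reduces to: integer quant types pick the INT8 or INT4 tolerance by bit width, everything else falls back to the FP4 tolerance. Restated as a helper (the function name is ours; the `*_ERR` constants come from the ut header):

```cpp
float ut_err_for(BTLA_DTYPE qtype) {
  if (bestla_dtype_type(qtype) == bestla_dtype_type(BTLA_DTYPE::TypeInt))
    return bestla_dtype_bits(qtype) == 8 ? INT8_ERR : INT4_ERR;
  return FP4_ERR;  // 4-bit float families
}
```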
k : blocksize; using WType = typename Wei::StorageWeight; WType packedw(0); - packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, jblas_dtype); + packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); utils::avector buffer(packedw.mSize); packedw.assign(buffer.data()); @@ -692,16 +694,16 @@ class UT_CompFp32 { {&packedw}, {packedw.template SPtr(), packedw.SDtype(), packedw.CStep()}, {matC.data(), n}}; - GemmRun(launcher, args, &DefaultThreading); + parallel::GemmRun(launcher, args, &DefaultThreading); auto err = FP4_ERR; - if (qtype == JBLAS_DTYPE::F8_E5M2 || qtype == JBLAS_DTYPE::F8_E4M3) err = F8_ERR; + if (qtype == BTLA_DTYPE::F8_E5M2 || qtype == BTLA_DTYPE::F8_E4M3) err = F8_ERR; buffer_error(refC.data(), matC.data(), refC.size(), err); buffer_error(refCupk.data(), matC.data(), refCupk.size(), 0.001f); } }; -#ifdef JBLAS_UT_PROLOGUE_B +#ifdef BTLA_UT_PROLOGUE_B static UT_CompFp32 sUT_CompFp32; #endif @@ -716,56 +718,57 @@ class UTBenchmark_CompFp32 { } void ut_s4() { - benchmark_all(1, 4096, 4096, 128, JBLAS_DTYPE::S4_CLIP); - benchmark_all(1, 4096, 4096, 128, JBLAS_DTYPE::S4_CLIP); - // benchmark_all(2048, 4096, 4096, 128, JBLAS_DTYPE::S4_CLIP); - // benchmark_all(4096, 4096, 11008, 128, JBLAS_DTYPE::S4_CLIP); - // benchmark_all(2, 4096, 4096, 32, JBLAS_DTYPE::S4_FULLRANGE); - // benchmark_all(2, 4096, 4096, 128, JBLAS_DTYPE::S4_FULLRANGE); - // benchmark_all(2, 4096, 4096, -1, JBLAS_DTYPE::S4_FULLRANGE); - // benchmark_all(2, 4096, 4096, 32, JBLAS_DTYPE::S4_CLIP); - // benchmark_all(2, 4096, 4096, 32, - // JBLAS_DTYPE::S4_FULLRANGE); + benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); + // benchmark_all(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); + // benchmark_all(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); + // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2, 4096, 4096, 32, + // BTLA_DTYPE::S4_FULLRANGE); } // void ut_s8() { - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::S8); - // ut(2, 4096, 4096, 128, JBLAS_DTYPE::S8); - // ut(2, 4096, 4096, -1, JBLAS_DTYPE::S8); - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::S8); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::S8); + // ut(2, 4096, 4096, 128, BTLA_DTYPE::S8); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::S8); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::S8); // } // void ut_f4() { - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - // ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_BNB); - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_E2M1); - // ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_E2M1); - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_NF4); - // ut(2, 4096, 4096, -1, JBLAS_DTYPE::F4_NF4); - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_BNB); - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_E2M1); - // ut(2, 4096, 4096, 32, JBLAS_DTYPE::F4_NF4); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); + // ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); + // ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); // } - template class Wei, typename Scale_T> + template class Wei, typename Scale_T> void benchmark(int 
m, int n, int k, int blocksize, int batch, float* A, float* B, float* C, float timems, int threads, - JBLAS_DTYPE qtype) { + BTLA_DTYPE qtype) { LOG_T log; - using Parallel = jblas::parallel::gemm::SchedulerBase; - using Launcher = jblas::wrapper::gemm::LauncherBase; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = wrapper::gemm::LauncherBase; Launcher kernel; DefaultThreading.set_threads(threads); - auto corestr = jblas::gemm::CoreAttr::to_str(Core_T::ID); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; using WType = typename Wei::StorageWeight; WType tmpB(0); if constexpr (std::is_same_v, - jblas::prologue_b::gemm::WeightKBlockNInteger>) { - tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, jblas_dtype, jblas_dtype, false); + prologue_b::gemm::WeightKBlockNInteger>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); + } else if constexpr (std::is_same_v, - jblas::prologue_b::gemm::WeightKBlockNFloat>) { - tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, jblas_dtype, JBLAS_DTYPE::EleBitsUndef, false); + prologue_b::gemm::WeightKBlockNFloat>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); } std::vector packBs(batch, 0); std::vector bufB(tmpB.mSize * batch); @@ -786,7 +789,7 @@ class UTBenchmark_CompFp32 { log.start(); GemmProblem gp(1, m, n, k); typename Launcher::Param args{gp, {A + i * m * k, k}, {&packBs[i]}, {C + i * m * n, n}}; - GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, &DefaultThreading); if (log.stop()) { double flops = double(psize) / log.avg_val / 1e6; double band = double(memsize) / log.avg_val / 1e6; @@ -797,26 +800,27 @@ class UTBenchmark_CompFp32 { } } - template class Wei, typename Scale_T> + template class Wei, typename Scale_T> void benchmark_mem(int m, int n, int k, int blocksize, int batch, float* A, float* B, float* C, float timems, - int threads, JBLAS_DTYPE qtype) { + int threads, BTLA_DTYPE qtype) { LOG_T log; - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = jblas::wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; Launcher kernel; DefaultThreading.set_threads(threads); - auto corestr = jblas::gemm::CoreAttr::to_str(Core_T::ID); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); utils::timer tm; using WType = typename Wei::StorageWeight; WType tmpB(0); if constexpr (std::is_same_v, - jblas::prologue_b::gemm::WeightKBlockNInteger>) { - tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, jblas_dtype, jblas_dtype, false); + prologue_b::gemm::WeightKBlockNInteger>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); + } else if constexpr (std::is_same_v, - jblas::prologue_b::gemm::WeightKBlockNFloat>) { - tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, jblas_dtype, JBLAS_DTYPE::EleBitsUndef, false); + prologue_b::gemm::WeightKBlockNFloat>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); } std::vector packBs(batch, 0); std::vector bufB(tmpB.mSize * batch); @@ -841,7 +845,7 @@ class UTBenchmark_CompFp32 { {&packBs[i]}, {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, {C + i * m * n, n}}; - GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRun(kernel, args, &DefaultThreading); } if (log.stop()) { double t = log.avg_val / batch; @@ -853,8 +857,8 @@ class UTBenchmark_CompFp32 { } 
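The reporting math in the benchmark loops: with `log.avg_val` in milliseconds, dividing an op count by (ms × 1e6) yields GFLOPS, and dividing bytes touched by the same factor yields GB/s. A standalone restatement under the usual GEMM conventions (the `psize`/`memsize` definitions are assumptions about setup code elided from the hunk):

```cpp
#include <cstddef>

double gemm_gflops(int m, int n, int k, double avg_ms) {
  double psize = 2.0 * m * n * k;   // one multiply + one add per MAC
  return psize / avg_ms / 1e6;      // ops / (ms * 1e6) == GFLOPS
}

double gemm_gbps(size_t bytes_touched, double avg_ms) {
  return double(bytes_touched) / avg_ms / 1e6;  // bytes / (ms * 1e6) == GB/s
}
```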
} - template
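Finally, for orientation: every test and benchmark in this file funnels into the same call sequence once the storage is packed. A sketch of that sequence with the renamed entry points; the `<Parallel>` template argument on `GemmRun` and the scheduler's core parameter are our reconstruction of brackets lost in the patch text, and the `ActivationBase` prologue slot is assumed from the neighboring hunks:

```cpp
#include "bestla_wrapper.h"
#include "bestla_ut.h"

using namespace bestla;
using Core = gemm::SCoreRowNAvx512f<48, 8>;
using Launcher = wrapper::gemm::LauncherBase<Core::ISA, Core,
                                             prologue_a::gemm::ActivationBase,
                                             prologue_b::gemm::WeightKBlockNInteger,
                                             epilogue::gemm::AccumulatorWriteBackFp32>;
using Parallel = parallel::gemm::SchedulerBase<Core>;

void run_packed_gemm(int m, int n, int k, const float* A,
                     storage::gemm::StorageWeightKBlockNInteger* packedW, float* C) {
  static Launcher launcher;
  utils::GemmProblem gp(1, m, n, k);
  typename Launcher::Param args{gp, {A, k}, {packedW}, {C, n}};
  parallel::GemmRun<Parallel>(launcher, args, &DefaultThreading);  // DefaultThreading from bestla_ut.h
}
```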