diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3ace8bf36..f0fc2c339 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,6 +61,7 @@ option(NS_AVX512_VNNI "neural_speed: enable AVX512-VNNI"
 option(NS_FMA "neural_speed: enable FMA" ON)
 option(NS_AMX "neural_speed: enable AMX" OFF)
 option(NS_USE_OMP "neural_speed: use OpenMP thread pool." ON)
+option(NS_SYCL "neural_speed: enable SYCL for GPUs." OFF)
 option(NS_BUILD_TESTS "neural_speed: build tests" ${NS_STANDALONE})
 option(NS_BUILD_EXAMPLES "neural_speed: build examples" ${NS_STANDALONE})
@@ -143,6 +144,11 @@ if(NS_USE_OMP)
   add_compile_definitions(NS_USE_OMP)
 endif()
+if(NS_SYCL)
+  set(BTLA_SYCL ON CACHE BOOL "BesTLA with SYCL")
+  add_compile_definitions(NS_SYCL)
+endif()
+
 add_subdirectory(bestla)
 add_subdirectory(neural_speed)
diff --git a/CMakePresets.json b/CMakePresets.json
index 76d74df0f..d8b7ea609 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -138,8 +138,10 @@
         "CMAKE_BUILD_TYPE": "Debug",
         "BTLA_UT_DEBUG": "ON",
         "BTLA_UT_ALL": "OFF",
-        "BTLA_SYCL": "ON",
+        "NS_SYCL": "ON",
         "BTLA_UT_BENCHMARK": "ON",
+        "BTLA_UT_OPENMP": "ON",
+        "BTLA_ENABLE_OPENMP": "ON",
         "CMAKE_CXX_COMPILER": "icx",
         "CMAKE_C_COMPILER": "icx"
       }
diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt
index 0e75ef009..40d2acbb0 100644
--- a/bestla/CMakeLists.txt
+++ b/bestla/CMakeLists.txt
@@ -1,6 +1,9 @@
 cmake_minimum_required(VERSION 3.12)
 project(bestla LANGUAGES CXX VERSION 0.1.0)
+if(BTLA_SYCL)
+  include(cmake/sycl.cmake)
+endif()
 include(cmake/FindSIMD.cmake)
 
 file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
@@ -55,11 +58,11 @@ endforeach()
 set(sycl_headers)
 set(sycl_libs)
 if(BTLA_SYCL)
-  include(cmake/sycl.cmake)
   file(GLOB sycl_headers ${PROJECT_NAME}/sycl/*.h ${PROJECT_NAME}/sycl/*.hpp)
   target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_SYCL)
   list(APPEND sycl_libs IntelSYCL::SYCL_CXX)
-  add_compile_options(-march=native)
+  target_compile_options(${PROJECT_NAME} INTERFACE -march=native)
+  target_link_libraries(${PROJECT_NAME} INTERFACE ${sycl_libs})
   #add_link_options(-fsycl-targets=spir64 -Xsycl-target-backend "-options -ze-opt-large-register-file")
 endif(BTLA_SYCL)
@@ -103,7 +106,7 @@ function(add_ut_flag UT_OPTION)
 endfunction()
 
 set(benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp)
-# list(APPEND benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/sycl_benchmark.cpp)
+list(APPEND benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/sycl_benchmark.cpp)
 
 
 if(UT_BUILD)
@@ -150,6 +153,9 @@ endif(UT_BUILD)
 if(BTLA_UT_BENCHMARK)
   file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h)
   include_directories(${PROJECT_NAME})
+  if(NOT BTLA_SYCL)
+    list(REMOVE_ITEM benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/sycl_benchmark.cpp)
+  endif()
   add_executable(${PROJECT_NAME}_benchmark ${benchmark_srcs} ${headers} ${ut_headers})
   if(BTLA_UT_OPENMP)
     include(FindOpenMP)
diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h
index 1743189bf..7b0b75b5c 100644
--- a/bestla/bestla/bestla_prologue_b.h
+++ b/bestla/bestla/bestla_prologue_b.h
@@ -126,6 +126,28 @@ class WeightKBlockNInteger {
     return tmp;
   }
 
+  AUTOCALL void convertTransStorage(StorageWeight& srcstor, StorageWeight& dststor, parallel::IThreading* threading) {
+    auto s8buf = utils::amalloc((size_t)srcstor.mK * srcstor.mN);
+    auto s8transbuf = utils::amalloc((size_t)srcstor.mKPad * srcstor.mNPad);
+    unpackWeight(srcstor.mN, srcstor.mK, &srcstor, s8buf, srcstor.mN, threading);
+    transposeWeight(srcstor.mK, srcstor.mN, s8buf, srcstor.mN, s8transbuf, srcstor.mKPad, threading);
+    compressWeight(srcstor.mKPad, srcstor.mNPad, s8transbuf, srcstor.mKPad, dststor.WPtr(), srcstor.mDType,
+                   threading);
+    if (s8buf) {
+      utils::afree(s8buf);
+    }
+    if (s8transbuf) {
+      utils::afree(s8transbuf);
+    }
+    int nk_scale = utils::updiv(srcstor.mKPad, srcstor.mBlockSize);
+    if (srcstor.mCorrection.mScaEleSize == 4) {
+      transposeWeight(nk_scale, srcstor.mNPad, srcstor.template SPtr(), srcstor.mNPad,
+                      dststor.template SPtr(), dststor.CStep(), threading);
+    } else if (srcstor.mCorrection.mScaEleSize == 2) {
+      transposeWeight(nk_scale, srcstor.mNPad, srcstor.template SPtr(), srcstor.mNPad,
+                      dststor.template SPtr(), dststor.CStep(), threading);
+    }
+  }
   AUTOCALL void doubleQuantScale(float* scale, size_t scale_size, int dq_blocksize, BTLA_DTYPE qtype,
                                  utils::aligned_vector* dq_buf) {
     if (qtype == BTLA_DTYPE::DQ8_BNB) {
diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h
index e5c441087..9ce19bdc1 100644
--- a/bestla/bestla/bestla_storage.h
+++ b/bestla/bestla/bestla_storage.h
@@ -706,6 +706,22 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase {
     mPrologueID = BTLA_PROLOGUEB_IDS::WeightKBlockNInteger;
   }
 
+  StorageWeightKBlockNInteger toTrans() {
+    StorageWeightKBlockNInteger trans(-1);
+    trans.mK = mK;
+    trans.mN = mN;
+    trans.mNPad = mNPad;
+    trans.mKPad = mKPad;
+    trans.mBlockSize = mBlockSize;
+    trans.mDType = mDType;
+    trans.mQBuf.resize(mQBuf.size());
+    int nk_scale = utils::updiv(mKPad, mBlockSize);
+    trans.mCorrection.resize(mNPad, nk_scale, mCorrection.mScaT, mCorrection.mZpT, mCorrection.mRedT,
+                             mCorrection.mZpBuf.size() > 0, mCorrection.mRedBuf.size() > 0);
+    trans.update_size();
+    return trans;
+  }
+
   size_t resize(int NPad, int KPad, int Block, int N, int K, BTLA_DTYPE qtype, BTLA_DTYPE scalet, BTLA_DTYPE redt,
                 bool IsAsym) {
     BTLA_DTYPE zpt = BTLA_DTYPE::S8;
diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h
index d15fdde39..88f14b564 100644
--- a/bestla/bestla/bestla_utils.h
+++ b/bestla/bestla/bestla_utils.h
@@ -62,6 +62,17 @@
 // As long as the compiler supports the ISA, we will enable it.
 // Only the ISA you use in your project will be compiled.
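The Compile*() helpers touched in the next hunk are consumed as preprocessor guards around ISA-specific code elsewhere in BesTLA. A minimal usage sketch follows, assuming only that a guarded function needs the corresponding ISA; the helper name and body are illustrative and not part of this patch:

    #include <immintrin.h>
    // Assumes the Compile*() macros from bestla_utils.h are in scope.
    // Built only when the toolchain reports AVX2/FMA/F16C support, which is what
    // CompileAVX2() expands to under the MSVC-hosted icx branch added below.
    #if CompileAVX2()
    static inline void scale_f32_avx2(float* dst, const float* src, float s, int n) {
      const __m256 vs = _mm256_set1_ps(s);
      int i = 0;
      for (; i + 8 <= n; i += 8) {
        _mm256_storeu_ps(dst + i, _mm256_mul_ps(_mm256_loadu_ps(src + i), vs));
      }
      for (; i < n; ++i) dst[i] = src[i] * s;  // scalar tail for n not divisible by 8
    }
    #endif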
+#if defined(_MSC_VER) && defined(__INTEL_LLVM_COMPILER)
+#define CompileAVX512F() defined(__AVX512F__)
+#define CompileAVX512VNNI() defined(__AVX512VNNI__)
+#define CompileAVX2() defined(__AVX2__) && defined(__F16C__) && defined(__FMA__)
+#define CompileAVXVNNI() defined(__AVXVNNI__)
+#define CompileAMX() defined(__AMX_TILE__)
+#define CompileBF16() defined(__AVX512BF16__)
+#define CompileFP16() defined(__AVX512FP16__)
+#define CompileAMXBF16() (CompileAMX())
+#define CompileAMXINT8() (CompileAMX())
+#else
 #define CompileAVX512F() BTLA_AVX512_FOUND
 #define CompileAVX512VNNI() BTLA_AVX512_VNNI_FOUND
 #define CompileAVX2() BTLA_AVX2_FOUND
@@ -72,6 +83,7 @@
 #define CompileAMXFP16() BTLA_AMX_FP16_FOUND
 #define CompileAMXINT8() BTLA_AMX_INT8_FOUND
 #define CompileAMX() BTLA_AMX_BF16_FOUND
+#endif
 
 // called by launcher, time critical functions
 #define TLACALL \
@@ -475,6 +487,8 @@ class isa_base {
 
 static inline int padto_le(int src, int padding) { return src / padding * padding; }
 
+static inline int64_t padto_le(int64_t src, int64_t padding) { return src / padding * padding; }
+
 static inline size_t padto_le(size_t src, int padding) { return src / size_t(padding) * size_t(padding); }
 
 static inline int updiv(int a, int b) { return (a + b - 1) / b; }
diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h
index 787f4a661..a6f7ad800 100644
--- a/bestla/bestla/kernel_avx2.h
+++ b/bestla/bestla/kernel_avx2.h
@@ -25,7 +25,7 @@ namespace avx2 {
 #pragma GCC push_options
 #pragma GCC target("avx2", "fma", "f16c")
 #elif defined(ICX)
-#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function)
+//#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function)
 #endif
 
 static inline void zero_reg() { _mm256_zeroupper(); }
diff --git a/bestla/bestla/sycl/sycl_device.h b/bestla/bestla/sycl/sycl_device.h
index c23d241c1..96f819fd7 100644
--- a/bestla/bestla/sycl/sycl_device.h
+++ b/bestla/bestla/sycl/sycl_device.h
@@ -23,7 +23,7 @@ namespace sycl_device {
 class SyclDevice {
  public:
-  SyclDevice(bool profile) {
+  SyclDevice(int gpu_idx, bool profile) {
     // Create an exception handler for asynchronous SYCL exceptions
     static auto exception_handler = [](sycl::exception_list e_list) {
       for (std::exception_ptr const& e : e_list) {
@@ -37,12 +37,38 @@ class SyclDevice {
         }
       }
     };
+    auto devices = sycl::device::get_devices(sycl::info::device_type::gpu);
+    assert(gpu_idx < devices.size());
+
+    if (profile) {
+      sycl::property_list prop = {sycl::property::queue::enable_profiling(), sycl::property::queue::in_order()};
+      mQueue = sycl::queue(devices[gpu_idx], exception_handler, prop);
+    } else {
+      sycl::property_list prop = {sycl::property::queue::in_order()};
+      mQueue = sycl::queue(devices[gpu_idx], exception_handler);
+    }
+  }
+  SyclDevice(bool profile) {
+    // Create an exception handler for asynchronous SYCL exceptions
+    static auto exception_handler = [](sycl::exception_list e_list) {
+      for (std::exception_ptr const& e : e_list) {
+        try {
+          std::rethrow_exception(e);
+        } catch (std::exception const& e) {
+#if _DEBUG
+          std::cout << "Failure" << std::endl;
+#endif
+          std::terminate();
+        }
+      }
+    };
     auto d_selector{sycl::default_selector_v};
     if (profile) {
-      sycl::property_list prop = {sycl::property::queue::enable_profiling()};
+      sycl::property_list prop = {sycl::property::queue::enable_profiling(), sycl::property::queue::in_order()};
       mQueue = sycl::queue(d_selector, exception_handler, prop);
     } else {
+      sycl::property_list prop =
{sycl::property::queue::in_order()}; mQueue = sycl::queue(d_selector, exception_handler); } } @@ -51,20 +77,44 @@ class SyclDevice { inline std::string getName() { return mQueue.get_device().get_info(); }; + size_t getGlobalMemSize() { return mQueue.get_device().get_info(); } + size_t getMaxMemAllocSize() { return mQueue.get_device().get_info(); } + + double getGlobalMemSizeGB() { return double(getGlobalMemSize()) / 1e9; } + double getMaxMemAllocSizeMB() { return double(getGlobalMemSize()) / 1e6; } + + static inline bool is_cpu(const sycl::device& dev) { + return dev.get_info() == sycl::info::device_type::cpu; + } + + static inline bool is_gpu(const sycl::device& dev) { + return dev.get_info() == sycl::info::device_type::gpu; + } + + static inline bool is_cpu(sycl::queue* q) { + return q->get_device().get_info() == sycl::info::device_type::cpu; + } + + static inline bool is_gpu(sycl::queue* q) { + return q->get_device().get_info() == sycl::info::device_type::gpu; + } + void print() { std::cout << "Running on device: " << mQueue.get_device().get_info() << "\n"; - std::cout << "EU count:" << mQueue.get_device().get_info() - << "\n"; // 448 - std::cout << "EU count per subslice:" - << mQueue.get_device().get_info() << "\n"; // 8 - std::cout << "EU SIMD width:" << mQueue.get_device().get_info() - << "\n"; // 8 - std::cout << "HW threads per EU:" - << mQueue.get_device().get_info() << "\n"; // 8 - std::cout << "GPU slices:" << mQueue.get_device().get_info() - << "\n"; // 7 - std::cout << "Subslice per slice:" - << mQueue.get_device().get_info() << "\n"; // 8 + if (is_gpu(mQueue.get_device())) { + std::cout << "EU count:" << mQueue.get_device().get_info() << "\n"; + std::cout << "EU count per subslice:" + << mQueue.get_device().get_info() << "\n"; + std::cout << "EU SIMD width:" << mQueue.get_device().get_info() + << "\n"; + std::cout << "HW threads per EU:" + << mQueue.get_device().get_info() << "\n"; + std::cout << "GPU slices:" << mQueue.get_device().get_info() << "\n"; + std::cout << "Subslice per slice:" + << mQueue.get_device().get_info() << "\n"; + } + std::cout << "Global Memory size: " << getGlobalMemSizeGB() << "\n"; + std::cout << "Global Memory size: " << getMaxMemAllocSize() << "\n"; } sycl::queue mQueue; }; diff --git a/bestla/bestla/sycl/sycl_gemm.h b/bestla/bestla/sycl/sycl_gemm.h index 7ba1e7963..ed7e3c391 100644 --- a/bestla/bestla/sycl/sycl_gemm.h +++ b/bestla/bestla/sycl/sycl_gemm.h @@ -16,7 +16,7 @@ #ifdef BTLA_SYCL #include -#include "bestla_utils.h" +#include "bestla/bestla_utils.h" #include namespace bestla { @@ -64,6 +64,17 @@ class SGemmCoreSharedB { using SLM_B_Acc = sycl::local_accessor; + using AType = TA; + using BType = TB; + using CType = TC; + static auto constexpr NTILE = WgNEle; + static auto constexpr MTILE = WgMEle; + static auto constexpr KTILE = TileK; + static auto constexpr PACK_ROW = 1; + static int constexpr PREFERRED_N = NTILE; + static auto constexpr ISA = BTLA_ISA::ISA_COUNT; + static auto constexpr ID = 0; + static inline void compute(const TA* aptr, int lda, const SLM_B_Acc& bacc, TACC* accptr, const sycl_utils::nd_item_helper>& helper) { #pragma unroll(1) diff --git a/bestla/bestla/sycl/sycl_prologue_a.h b/bestla/bestla/sycl/sycl_prologue_a.h index 28350f276..9cf9d4d23 100644 --- a/bestla/bestla/sycl/sycl_prologue_a.h +++ b/bestla/bestla/sycl/sycl_prologue_a.h @@ -16,7 +16,7 @@ #ifdef BTLA_SYCL #include -#include "bestla_utils.h" +#include "bestla/bestla_utils.h" #include namespace bestla { diff --git a/bestla/bestla/sycl/sycl_prologue_b.h 
b/bestla/bestla/sycl/sycl_prologue_b.h index 089a81dd5..b6f845aa0 100644 --- a/bestla/bestla/sycl/sycl_prologue_b.h +++ b/bestla/bestla/sycl/sycl_prologue_b.h @@ -16,7 +16,7 @@ #ifdef BTLA_SYCL #include -#include "bestla_utils.h" +#include "bestla/bestla_utils.h" #include namespace bestla { @@ -85,10 +85,10 @@ class WeightS4 { .B[(helper.item_g_n() + in + (koffset + helper.sg_idx_m() + icp * GemmCoreT::WgM) * _param.ldb) / 2]; dstptr[(helper.sg_idx_m() + icp * GemmCoreT::WgM) * GemmCoreT::WgNEle + (helper.sg_idx_n() * GemmCoreT::SgSize + helper.sg_id()) * GemmCoreT::TileN + in] = - static_cast((tmps8 & 0x0f) << 4) * scale[in]; + static_cast((tmps8 & 0x0f) - 8) * scale[in]; dstptr[(helper.sg_idx_m() + icp * GemmCoreT::WgM) * GemmCoreT::WgNEle + (helper.sg_idx_n() * GemmCoreT::SgSize + helper.sg_id()) * GemmCoreT::TileN + in + 1] = - static_cast((tmps8 & 0xf0)) * scale[in + 1]; + static_cast((tmps8 >> 4) - 8) * scale[in + 1]; } } } @@ -107,7 +107,7 @@ class WeightS4 { int nsg_k = k / GroupK; int nsg_n = n / GroupN; sycl::range<1> group{SgSize}; - sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + sycl::range<1> problem{static_cast(nsg_n) * nsg_k * SgSize}; auto B_d = in.B; auto S_d = in.scale; int ldb = in.ldb; @@ -132,8 +132,8 @@ class WeightS4 { for (int ik = 0; ik < TileK; ik += 1) { for (int in = 0; in < TileN; in += 2) { uint8_t srcu8 = *(bptr + (ik * ldb + sg_id * TileN + in) / 2); - tmp[ik * TileN + in] = static_cast((srcu8 & 0x0f) << 4) * scale[in]; - tmp[ik * TileN + in + 1] = static_cast((srcu8 & 0xf0)) * scale[in + 1]; + tmp[ik * TileN + in] = static_cast((srcu8 & 0x0f) - 8) * scale[in]; + tmp[ik * TileN + in + 1] = static_cast((srcu8 >> 4) - 8) * scale[in + 1]; } } for (int ik = 0; ik < TileK; ik += 1) { @@ -176,15 +176,15 @@ class WeightS4Trans { auto scale = _param.scale[(sgn + icp * GemmCoreT::SgCount) * _param.ldb + koffset / blocksize]; auto tmps8 = _param.B[((sgn + icp * GemmCoreT::SgCount) * wldb + (koffset + helper.sg_id() * LoadTileK)) / 2]; if constexpr (std::is_same_v) { - sycl::half2 tmpBf = {static_cast((tmps8 & 0x0f) << 4), static_cast((tmps8 & 0xf0))}; + sycl::half2 tmpBf = {static_cast((tmps8 & 0x0f) - 8), static_cast((tmps8 >> 4) - 8)}; tmpBf *= scale; dstptr[sg_off + helper.sg_group_id() + icp * GemmCoreT::SgCount] = tmpBf[0]; dstptr[sg_off + GemmCoreT::WgNEle + helper.sg_group_id() + icp * GemmCoreT::SgCount] = tmpBf[1]; } else { dstptr[sg_off + helper.sg_group_id() + icp * GemmCoreT::SgCount] = - static_cast((tmps8 & 0x0f) << 4) * scale; + static_cast((tmps8 & 0x0f) - 8) * scale; dstptr[sg_off + GemmCoreT::WgNEle + helper.sg_group_id() + icp * GemmCoreT::SgCount] = - static_cast((tmps8 & 0xf0)) * scale; + static_cast((tmps8 >> 4) - 8) * scale; } } } @@ -204,7 +204,7 @@ class WeightS4Trans { int nsg_k = k / GroupK; int nsg_n = n / GroupN; sycl::range<1> group{SgSize}; - sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + sycl::range<1> problem{static_cast(nsg_n) * nsg_k * SgSize}; auto B_d = in.B; auto S_d = in.scale; int ldb = in.ldb; @@ -224,7 +224,6 @@ class WeightS4Trans { auto sptr = S_d + sg_k / blocksize + g_n * ldb; auto bptr = B_d + (sg_k + g_n * ldbn) / 2; auto dbptr = outptr + sg_k + g_n * k; - float tmp[TileK]; int constexpr Unroll = 4; #pragma unroll for (int ik = 0; ik < TileK; ik += Unroll) { @@ -232,8 +231,8 @@ class WeightS4Trans { float scale = sptr[(ik * SgSize + sg_id * Unroll) / blocksize]; for (int ir = 0; ir < Unroll; ir += 2) { uint8_t srcu8 = *(bptr + (ik * SgSize + sg_id * Unroll + ir) / 2); - dst[ir] = static_cast((srcu8 & 0x0f) << 4) * 
scale; - dst[ir + 1] = static_cast((srcu8 & 0xf0)) * scale; + dst[ir] = static_cast((srcu8 & 0x0f) - 8) * scale; + dst[ir + 1] = static_cast((srcu8 >> 4) - 8) * scale; } *(sycl::vec*)&dbptr[ik * SgSize + sg_id * Unroll] = *(sycl::vec*)dst; } @@ -256,7 +255,7 @@ class WeightS4Trans { int nsg_k = k / GroupK; int nsg_n = n / GroupN; sycl::range<1> group{SgSize}; - sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + sycl::range<1> problem{static_cast(nsg_n) * nsg_k * SgSize}; auto B_d = in.B; auto S_d = in.scale; int ldb = in.ldb; @@ -319,7 +318,7 @@ class WeightS4Trans { int nsg_k = k / GroupK; int nsg_n = n / GroupN; sycl::range<1> group{SgSize}; - sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + sycl::range<1> problem{static_cast(nsg_n) * nsg_k * SgSize}; auto B_d = in.B; auto S_d = in.scale; int ldb = in.ldb; @@ -342,8 +341,8 @@ class WeightS4Trans { for (int in = 0; in < TileN; in++) { float scale = sptr[sg_id * TileK / blocksize + in * ldb]; uint8_t srcu8 = *(bptr + (sg_id * TileK + in * ldbn) / 2); - tmp[in] = high4 ? static_cast((srcu8 & 0xf0)) * scale - : static_cast((srcu8 & 0x0f) << 4) * scale; + tmp[in] = high4 ? static_cast((srcu8 >> 4) - 8) * scale + : static_cast((srcu8 & 0x0f) - 8) * scale; } float tmpT[TileN]; @@ -369,85 +368,235 @@ class WeightS4Trans { auto B = paramB.B; auto B_scale = paramB.scale; int ldb = paramB.ldb; + int constexpr Unroll = 2; int constexpr SgSize = 16; - int constexpr TileK = 32; - int constexpr GroupK = SgSize * TileK; sycl::range<1> group{SgSize}; - sycl::range<1> problem{n * SgSize}; - - auto ev = q->submit([&](sycl::handler& cgh) { - cgh.parallel_for( - sycl::nd_range<1>(problem, group), - [=](sycl::nd_item<1> it) [[cl::reqd_work_group_size( - 1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { - int g_idx = it.get_group(0); - auto sg = it.get_sub_group(); - int sg_id = sg.get_local_id()[0]; - int g_n = g_idx; - auto sptr = B_scale + g_n * ldb; - auto bptr = B + g_n * k / 2; - auto aptr = A; - auto cptr = C + g_n; - if constexpr (std::is_same_v) { - sycl::half2 tmpAcc = {0.f, 0.f}; - int constexpr Unroll = 2; - for (int i = 0; i < k; i += GroupK * Unroll) { + sycl::range<1> problem{static_cast(n) * SgSize}; + if (k % (SgSize * 32 * Unroll) == 0) { + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + auto ev = q->submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<1>(problem, group), + [=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size( + 1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = B_scale + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { + sycl::half2 tmpAcc = {0.f, 0.f}; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8), + static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum 
+= sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { + CType tmpAcc = 0.f; + int constexpr Unroll = 2; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll - for (int iu = 0; iu < Unroll; iu++) { - uint8_t tmps8[TileK / 2]; - *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); - CType scale = *(sptr + sg_id * TileK / blocksize); + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += CType(aptr[sg_id * TileK + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; + } + } + }); + }); + return ev; + } else { + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + int k_body = utils::padto_le(k, GroupK * Unroll); + int constexpr TileK2 = 8; + int constexpr GroupK2 = SgSize * TileK2; + int k_body2 = utils::padto_le(k, GroupK2 * Unroll); + auto ev = q->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<1>(problem, group), + [=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size( + 1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = B_scale + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { + sycl::half2 tmpAcc = {0.f, 0.f}; + int i = 0; + for (; i < k_body; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8), + static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + if (i + GroupK2 * Unroll < k_body2) { + for (; i < k_body2; i += GroupK2 * Unroll) { #pragma unroll - for (int ikk = 0; ikk < TileK; ikk += 2) { - sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; - sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) << 4), - static_cast((tmps8[ikk / 2] & 0xf0))}; + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK2 / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK2 / 2); + CType scale = *(sptr + sg_id * TileK2 / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK2; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK2 + ikk]; + sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8), + static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK2 / blocksize; + aptr += GroupK2; + bptr += GroupK2 / 2; + } + } + } + if (i + SgSize * 2 < k) { + for (; i < k; i += SgSize * 2) { + uint8_t tmps8 = *(bptr + sg_id); + CType scale = *(sptr + sg_id * 2 / blocksize); + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * 
2]; + sycl::half2 tmpB = {static_cast((tmps8 & 0x0f) - 8), static_cast((tmps8 >> 4) - 8)}; tmpAcc += tmpA * tmpB * scale; + sptr += SgSize * 2 / blocksize; + aptr += SgSize * 2; + bptr += SgSize * 2 / 2; } - sptr += GroupK / blocksize; - aptr += GroupK; - bptr += GroupK / 2; } - } - sycl::half2 sum = {0.f, 0.f}; - for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmpAcc, i); - } - if (sg_id == 0) { - *cptr = sum[0] + sum[1]; - } - } else { - CType tmpAcc = 0.f; - int constexpr Unroll = 2; - for (int i = 0; i < k; i += GroupK * Unroll) { + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { + CType tmpAcc = 0.f; + int constexpr Unroll = 2; + int i = 0; + for (; i < k_body; i += GroupK * Unroll) { #pragma unroll - for (int iu = 0; iu < Unroll; iu++) { - uint8_t tmps8[TileK / 2]; - *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); - CType scale = *(sptr + sg_id * TileK / blocksize); + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll - for (int ikk = 0; ikk < TileK; ikk += 2) { - tmpAcc += - CType(aptr[sg_id * TileK + ikk]) * static_cast((tmps8[ikk / 2] & 0x0f) << 4) * scale; - tmpAcc += - CType(aptr[sg_id * TileK + ikk + 1]) * static_cast((tmps8[ikk / 2] & 0xf0)) * scale; + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += + CType(aptr[sg_id * TileK + ikk]) * static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += + CType(aptr[sg_id * TileK + ikk + 1]) * static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; } - sptr += GroupK / blocksize; - aptr += GroupK; - bptr += GroupK / 2; + } + if (i + GroupK2 * Unroll < k_body2) { + for (; i < k_body2; i += GroupK2 * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK2 / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK2 / 2); + CType scale = *(sptr + sg_id * TileK2 / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK2; ikk += 2) { + tmpAcc += CType(aptr[sg_id * TileK2 + ikk]) * static_cast((tmps8[ikk / 2] & 0x0f) - 8) * + scale; + tmpAcc += CType(aptr[sg_id * TileK2 + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK2 / blocksize; + aptr += GroupK2; + bptr += GroupK2 / 2; + } + } + } + if (i + SgSize * Unroll < k) { + for (; i < k; i += SgSize) { + uint8_t tmps8 = *(bptr + sg_id / 2); + CType scale = *(sptr + sg_id / blocksize); + tmpAcc += CType(aptr[sg_id]) * static_cast((tmps8 & 0x0f) - 8) * scale; + tmpAcc += CType(aptr[sg_id]) * static_cast((tmps8 >> 4) - 8) * scale; + sptr += SgSize / blocksize; + aptr += SgSize; + bptr += SgSize / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; } } - float sum = 0.f; - for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmpAcc, i); - } - if (sg_id == 0) { - *cptr = sum; - } - } - }); - }); - return ev; + }); + }); + return ev; + } } }; } // namespace sycl_prologue_b diff --git a/bestla/bestla/sycl/sycl_storage.h b/bestla/bestla/sycl/sycl_storage.h new file mode 100644 index 000000000..db446b071 --- /dev/null +++ b/bestla/bestla/sycl/sycl_storage.h @@ -0,0 +1,125 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "sycl_utils.h" +#include "bestla/bestla_storage.h" + +namespace bestla { +namespace sycl_storage { +class StorageWeightKBlockNInteger { + public: + BTLA_PROLOGUEB_IDS mPrologueID = BTLA_PROLOGUEB_IDS::Undef; + uint64_t mCoreId = 0; + BTLA_DTYPE mDType = BTLA_DTYPE::F32; + BTLA_DTYPE mScaT = BTLA_DTYPE::F32, mZpT = BTLA_DTYPE::F32, mRedT = BTLA_DTYPE::F32; + int mNPad = 0, mKPad = 0; + int mN = 0, mK = 0; + int mBlockSize = 1; + int mDqBlockSize = 0; + int mCStep = 0; + int8_t* mQBuf = nullptr; + size_t mWSize = 0; + int8_t* mSBuf = nullptr; + size_t mCSize = 0; + int8_t *mZpBuf = nullptr, *mRedBuf = nullptr; + size_t mZpSize = 0, mRedSize = 0; + int8_t* mDQCorrectionBuf = nullptr; + int8_t* mShuffleIndices = nullptr; + size_t mDQCorSize = 0, mShufSize = 0; + + StorageWeightKBlockNInteger(bestla::storage::gemm::StorageWeightKBlockNInteger& _hoststor) { + mPrologueID = BTLA_PROLOGUEB_IDS::WeightKBlockNInteger; + mCoreId = 0; + mDType = _hoststor.mDType; + mScaT = _hoststor.SDtype(); + mZpT = _hoststor.ZDtype(); + mRedT = _hoststor.RDtype(); + mNPad = _hoststor.mNPad; + mKPad = _hoststor.mKPad; + mN = _hoststor.mN; + mK = _hoststor.mK; + mBlockSize = _hoststor.mBlockSize; + mDqBlockSize = _hoststor.mDqBlockSize; + mWSize = _hoststor.template WSize(); + mCSize = _hoststor.CSize() * utils::bestla_dtype_size(mScaT); + mCStep = _hoststor.CStep(); + + if (_hoststor.template ZPtr()) { + mZpSize = _hoststor.CSize() * utils::bestla_dtype_size(mZpT); + } + if (_hoststor.template RPtr()) { + mRedSize = _hoststor.CSize() * utils::bestla_dtype_size(mRedT); + } + // TODO DQ,shuffle support + } + + size_t getDeviceSize() { return mWSize + mCSize + mZpSize + mRedSize + mDQCorSize + mShufSize; } + + void assign(int8_t* dptr) { + mQBuf = dptr; + dptr += mWSize; + mSBuf = dptr; + dptr += mCSize; + if (mZpSize) { + mZpBuf = dptr; + dptr += mZpSize; + } + if (mRedSize) { + mRedBuf = dptr; + dptr += mRedSize; + } + if (mDQCorSize) { + mDQCorrectionBuf = dptr; + dptr += mDQCorSize; + } + if (mShuffleIndices) { + mDQCorrectionBuf = dptr; + dptr += mShufSize; + } + } + + void fromHost(bestla::storage::gemm::StorageWeightKBlockNInteger& _hoststor, sycl::queue* queue) { + if (_hoststor.template WPtr() && mQBuf) { + queue->memcpy(mQBuf, _hoststor.template WPtr(), mWSize); + } + if (_hoststor.template SPtr() && mSBuf) { + queue->memcpy(mSBuf, _hoststor.template SPtr(), mCSize); + } + if (_hoststor.template ZPtr() && mZpBuf) { + queue->memcpy(mZpBuf, _hoststor.template ZPtr(), mZpSize); + } + if (_hoststor.template RPtr() && mRedBuf) { + queue->memcpy(mRedBuf, _hoststor.template RPtr(), mRedSize); + } + queue->wait(); + } + + void toHost(bestla::storage::gemm::StorageWeightKBlockNInteger& _hoststor, sycl::queue* queue) { + if (mQBuf) { + queue->memcpy(_hoststor.template WPtr(), mQBuf, mWSize); + } + if (mSBuf) { + queue->memcpy(_hoststor.template SPtr(), mSBuf, mCSize); + } + if (mZpBuf) { + queue->memcpy(_hoststor.template ZPtr(), mZpBuf, mZpSize); + } + if (mRedBuf) { + 
queue->memcpy(_hoststor.template RPtr(), mRedBuf, mRedSize); + } + queue->wait(); + } +}; +} // namespace sycl_storage +} // namespace bestla diff --git a/bestla/bestla/sycl/sycl_utils.h b/bestla/bestla/sycl/sycl_utils.h index 2cdf01626..606355a4e 100644 --- a/bestla/bestla/sycl/sycl_utils.h +++ b/bestla/bestla/sycl/sycl_utils.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once #include "sycl_device.h" -#include "bestla_utils.h" +#include "bestla/bestla_utils.h" namespace bestla { namespace sycl_utils { diff --git a/bestla/bestla/sycl/sycl_wrapper.h b/bestla/bestla/sycl/sycl_wrapper.h index 29dd84997..49fb53de5 100644 --- a/bestla/bestla/sycl/sycl_wrapper.h +++ b/bestla/bestla/sycl/sycl_wrapper.h @@ -16,7 +16,7 @@ #ifdef BTLA_SYCL #include -#include "bestla_utils.h" +#include "bestla/bestla_utils.h" #include "sycl_utils.h" #include "sycl_device.h" #include "sycl_gemm.h" @@ -56,12 +56,12 @@ class Launcher { int ldb = _param.paramB.ldb; int ldc = _param.paramC.ldc; int m_pad = utils::padto(utils::updiv(m, GemmCore::TileM), GemmCore::WgM); - sycl::range<2> problem{m_pad, n / GemmCore::TileN}; + sycl::range<2> problem{static_cast(m_pad), static_cast(n) / GemmCore::TileN}; auto ev = q->submit([&](sycl::handler& cgh) { sycl::local_accessor slm_b(sycl::range(GemmCore::SLM_B_Size), cgh); cgh.parallel_for( sycl::nd_range<2>(problem, group), - [=](sycl::nd_item<2> it) [[cl::reqd_work_group_size( + [=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size( 1, GemmCore::WgM, GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] { sycl_utils::nd_item_helper helper(it); @@ -148,15 +148,15 @@ class LauncherWOQ { int ldb = _param.paramB.ldb; int ldc = _param.paramC.ldc; int m_pad = utils::padto(utils::updiv(m, GemmCore::TileM), GemmCore::WgM); - sycl::range<2> problem{m_pad, n / GemmCore::TileN}; + sycl::range<2> problem{static_cast(m_pad), static_cast(n) / GemmCore::TileN}; auto ev = q->submit([&](sycl::handler& cgh) { sycl::local_accessor slm_b(sycl::range(GemmCore::SLM_B_Size), cgh); cgh.parallel_for( sycl::nd_range<2>(problem, group), - [=](sycl::nd_item<2> it) [[cl::reqd_work_group_size( + [=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size( 1, GemmCore::WgM, GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] { - nd_item_helper helper(it); + sycl_utils::nd_item_helper helper(it); if constexpr (debug) { compute_tile(k, blocksize, B, B_scale, ldb, slm_b, A, lda, C, ldc, it); } else { diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp index f3c09ac03..f593704cc 100644 --- a/bestla/bestla/ut/bestla_benchmark.cpp +++ b/bestla/bestla/ut/bestla_benchmark.cpp @@ -581,7 +581,7 @@ class UTWOQ_CompFp32 { } auto psize = (size_t)m * n * k * 2; int blks = k / blocksize; - int nbits = utils::bestla_dtype_bits(qtype); + size_t nbits = utils::bestla_dtype_bits(qtype); auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float); tm.start(); while (tm.stop() < timems) { @@ -700,7 +700,7 @@ class UTWOQ_CompBf16 { } auto psize = (size_t)m * n * k * 2; int blks = k / blocksize; - int nbits = utils::bestla_dtype_bits(qtype); + size_t nbits = utils::bestla_dtype_bits(qtype); auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float); tm.start(); while (tm.stop() < timems) { @@ -816,7 +816,7 @@ class UTWOQ_CompInt8 { quanA.assign(bufferA.data()); auto psize = (size_t)m * n * k * 2; int blks = 
k / blocksize; - int nbits = utils::bestla_dtype_bits(qtype); + auto nbits = utils::bestla_dtype_bits(qtype); auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float); if (isasym) { memsize += n * blks * sizeof(int8_t); @@ -878,7 +878,7 @@ class UTWOQ_CompInt8 { quanA.assign(bufferA.data()); auto psize = (size_t)m * n * k * 2; int blks = k / blocksize; - int nbits = utils::bestla_dtype_bits(qtype); + auto nbits = utils::bestla_dtype_bits(qtype); auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float); if (isasym) { memsize += n * blks * sizeof(int8_t); diff --git a/bestla/bestla/ut/sycl_benchmark.cpp b/bestla/bestla/ut/sycl_benchmark.cpp index 3de21c57f..c1309f4a8 100644 --- a/bestla/bestla/ut/sycl_benchmark.cpp +++ b/bestla/bestla/ut/sycl_benchmark.cpp @@ -42,14 +42,14 @@ class Benchmark_Fp32Fp32 { auto C_d = C; auto psize = (size_t)m * n * k * 2; sycl::range<2> group{SGemmT::WgM, SGemmT::WgN}; - sycl::range<2> problem{m / SGemmT::TileM, n / SGemmT::TileN}; + sycl::range<2> problem{static_cast(m) / SGemmT::TileM, static_cast(n) / SGemmT::TileN}; utils::GemmProblem gp(1, m, n, k); tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}}); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}}); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -88,6 +88,7 @@ class Benchmark_Fp32Fp32 { #ifdef BTLA_UT_SYCL static Benchmark_Fp32Fp32 sBenchmark_Fp32Fp32; #endif + class Benchmark_Fp16Fp16 { public: Benchmark_Fp16Fp16() { @@ -122,9 +123,9 @@ class Benchmark_Fp16Fp16 { tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}}); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}}); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -160,6 +161,7 @@ class Benchmark_S4Fp32Fp32 { Benchmark_S4Fp32Fp32() { UT_START(); benchmark_all(1, 4096, 4096); + benchmark_all(1, 4096, 11008); benchmark_all(1, 4096, 4096 * 3); benchmark_all(1, 4096 * 3, 4096); benchmark_all(1024, 4096, 4096); @@ -192,14 +194,14 @@ class Benchmark_S4Fp32Fp32 { auto C_d = C; auto psize = (size_t)m * n * k * 2; sycl::range<2> group{SGemmT::WgM, SGemmT::WgN}; - sycl::range<2> problem{m / SGemmT::TileM, n / SGemmT::TileN}; + sycl::range<2> problem{static_cast(m) / SGemmT::TileM, static_cast(n) / SGemmT::TileN}; utils::GemmProblem gp(1, m, n, k); tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = KernelLauncher::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, n}, {C, n}}); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = KernelLauncher::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, n}, {C, n}}); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -224,9 +226,9 @@ class Benchmark_S4Fp32Fp32 { tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = KernelLauncherT::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, blks}, {C, n}}); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + 
auto ev = KernelLauncherT::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, blks}, {C, n}}); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -257,10 +259,10 @@ class Benchmark_S4Fp32Fp32 { int constexpr TileK = 32; int constexpr GroupK = SgSize * TileK; sycl::range<1> group{SgSize}; - sycl::range<1> problem{n * SgSize}; - auto e_esimd = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + sycl::range<1> problem{static_cast(n) * SgSize}; + auto ev = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -315,6 +317,7 @@ class Benchmark_S4Fp16Fp16 { Benchmark_S4Fp16Fp16() { UT_START(); benchmark_all(1, 4096, 4096, 128); + benchmark_all(1, 4096, 11008, 128); benchmark_all(1, 4096, 4096 * 4, 128); benchmark_all(1, 4096 * 3, 4096, 128); benchmark_all(1024, 4096, 4096, 32); @@ -352,9 +355,9 @@ class Benchmark_S4Fp16Fp16 { tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {B_d, S_d, n}, {C_d, n}}); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {B_d, S_d, n}, {C_d, n}}); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -364,7 +367,7 @@ class Benchmark_S4Fp16Fp16 { double flops = double(psize) / log.min_val / 1e6; printf(" %s Flops:%.3f\n", log.get_log_str(), flops); } - +#if 0 template void benchmark_gemmT(int m, int n, int k, int blocksize, int batch, AType* A, uint8_t* B, BType* B_scale, CType* C, float timems) { @@ -381,9 +384,9 @@ class Benchmark_S4Fp16Fp16 { tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -410,9 +413,9 @@ class Benchmark_S4Fp16Fp16 { tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -422,6 +425,7 @@ class Benchmark_S4Fp16Fp16 { double flops = double(psize) / log.min_val / 1e6; printf(" %s Flops:%.3f\n", log.get_log_str(), flops); } +#endif template void benchmark_gemv_T2(int m, int n, int k, int blocksize, int batch, AType* A, uint8_t* B, BType* B_scale, CType* C, float timems) { @@ -438,9 +442,9 @@ class Benchmark_S4Fp16Fp16 { tm.start(); while (tm.stop() < timems) { for (size_t i = 0; i < batch; i++) { - auto e_esimd = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q); + ev.wait(); + 
log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= timems) { break; } @@ -505,8 +509,8 @@ class Benchmark_DequantS4 { for (int i = 0; i < k; i += 2) { auto tmp = srcptr[i / 2 + j * k / 2]; auto noffset = i / blocksize + j * blks; - ref[i + j * k] = static_cast(static_cast(tmp.x) << 4) * scale[noffset]; - ref[i + 1 + j * k] = static_cast(static_cast(tmp.y) << 4) * scale[noffset]; + ref[i + j * k] = static_cast(static_cast(tmp.x) - 8) * scale[noffset]; + ref[i + 1 + j * k] = static_cast(static_cast(tmp.y) - 8) * scale[noffset]; } } sycl_vector dS(scale.size(), q), dequantB(n * k, q); @@ -524,9 +528,9 @@ class Benchmark_DequantS4 { tm.start(); while (tm.stop() < TestMs) { for (size_t i = 0; i < 1; i++) { - auto e_esimd = + auto ev = ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); - log.add(event_helper::execute_time(e_esimd) * 1000); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= TestMs) { break; } @@ -557,8 +561,8 @@ class Benchmark_DequantS4 { auto tmp = srcptr[i / 2 + j * k / 2]; auto noffset = i / blocksize + j * blks; auto s = float(scale[noffset]); - ref[i + j * k] = static_cast(static_cast(tmp.x) << 4) * s; - ref[i + 1 + j * k] = static_cast(static_cast(tmp.y) << 4) * s; + ref[i + j * k] = static_cast(static_cast(tmp.x) - 8) * s; + ref[i + 1 + j * k] = static_cast(static_cast(tmp.y) - 8) * s; } } sycl_vector dS(scale.size(), q), dequantB(n * k, q); @@ -576,9 +580,9 @@ class Benchmark_DequantS4 { tm.start(); while (tm.stop() < TestMs) { for (size_t i = 0; i < 1; i++) { - auto e_esimd = + auto ev = ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); - log.add(event_helper::execute_time(e_esimd) * 1000); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= TestMs) { break; } @@ -608,8 +612,8 @@ class Benchmark_DequantS4 { for (int i = 0; i < k; i += 2) { auto tmp = srcptr[i / 2 + j * k / 2]; auto noffset = i / blocksize + j * blks; - ref[i + j * k] = static_cast(static_cast(tmp.x) << 4) * scale[noffset]; - ref[i + 1 + j * k] = static_cast(static_cast(tmp.y) << 4) * scale[noffset]; + ref[i + j * k] = static_cast(static_cast(tmp.x) - 8) * scale[noffset]; + ref[i + 1 + j * k] = static_cast(static_cast(tmp.y) - 8) * scale[noffset]; } } sycl_vector dS(scale.size(), q), dequantB(n * k, q); @@ -627,8 +631,8 @@ class Benchmark_DequantS4 { tm.start(); while (tm.stop() < TestMs) { for (size_t i = 0; i < 1; i++) { - auto e_esimd = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); - log.add(event_helper::execute_time(e_esimd) * 1000); + auto ev = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= TestMs) { break; } @@ -657,8 +661,8 @@ class Benchmark_DequantS4 { for (int j = 0; j < n; j += 2) { auto tmp = srcptr[i * n / 2 + j / 2]; auto noffset = i / blocksize * n + j; - ref[i * n + j + 0] = static_cast(static_cast(tmp.x) << 4) * scale[noffset + 0]; - ref[i * n + j + 1] = static_cast(static_cast(tmp.y) << 4) * scale[noffset + 1]; + ref[i * n + j + 0] = static_cast(static_cast(tmp.x) - 8) * scale[noffset + 0]; + ref[i * n + j + 1] = static_cast(static_cast(tmp.y) - 8) * scale[noffset + 1]; } } sycl_vector dS(scale.size(), q), dequantB(n * k, q); @@ -673,10 +677,10 @@ class Benchmark_DequantS4 { tm.start(); while (tm.stop() < TestMs) { for (size_t i = 0; i < 1; i++) { - auto e_esimd = ProB::dequant_s4(n, k, blocksize, {dB.data(), dS.data(), n}, - dequantB.data(), q); - e_esimd.wait(); - log.add(event_helper::execute_time(e_esimd) * 
1000); + auto ev = ProB::dequant_s4(n, k, blocksize, {dB.data(), dS.data(), n}, + dequantB.data(), q); + ev.wait(); + log.add(event_helper::execute_time(ev) * 1000); if (tm.stop() >= TestMs) { break; } diff --git a/bestla/bestla/ut/sycl_gemm.cpp b/bestla/bestla/ut/sycl_gemm.cpp index d715e3afd..19c58946a 100644 --- a/bestla/bestla/ut/sycl_gemm.cpp +++ b/bestla/bestla/ut/sycl_gemm.cpp @@ -2,7 +2,7 @@ #include "sycl_ut.h" #include "../sycl/sycl_wrapper.h" #include "bestla_prologue_b.h" - +#undef BTLA_UT_SYCL namespace bestla { using namespace ut; using namespace utils; @@ -40,8 +40,8 @@ class UT_SyclSGemm { auto A_d = dA.data(); auto B_d = dB.data(); auto C_d = dC.data(); - auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}}); - e_esimd.wait(); + auto ev = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}}); + ev.wait(); q->memcpy(matC.data(), C_d, matC.size() * 4).wait(); buffer_error(ref.data(), matC.data(), ref.size(), 0.001f); } @@ -83,8 +83,8 @@ class UT_SyclHGemm { auto A_d = dA.data(); auto B_d = dB.data(); auto C_d = dC.data(); - auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}}); - e_esimd.wait(); + auto ev = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}}); + ev.wait(); q->memcpy(matC.data(), C_d, matC.size() * 2).wait(); buffer_error(ref.data(), matC.data(), ref.size(), utils::fp16(0.2f)); } @@ -97,6 +97,9 @@ class UT_SyclS4SGemm { public: UT_SyclS4SGemm() { UT_START(); + utT(6, 4096, 11008, 128); + ut(6, 32000, 4096, 128); + utT(6, 32000, 4096, 128); ut(300, 1024, 1024, 32); ut(1024, 1024, 1024, 32); utT(1024, 1024, 1024, 32); @@ -129,8 +132,8 @@ class UT_SyclS4SGemm { for (int j = 0; j < n; j += 2) { auto tmp = srcptr[i * n / 2 + j / 2]; auto noffset = i / blocksize * n + j; - matB[i * n + j + 0] = static_cast(static_cast(tmp.x) << 4) * B_scale[noffset + 0]; - matB[i * n + j + 1] = static_cast(static_cast(tmp.y) << 4) * B_scale[noffset + 1]; + matB[i * n + j + 0] = static_cast(static_cast(tmp.x) - 8) * B_scale[noffset + 0]; + matB[i * n + j + 1] = static_cast(static_cast(tmp.y) - 8) * B_scale[noffset + 1]; } } gemmref_fp32fp32fp32(m, n, k, matA.data(), matB.data(), ref.data(), k, n, n); @@ -146,8 +149,8 @@ class UT_SyclS4SGemm { auto B_scale_d = dB_scale.data(); auto C_d = dC.data(); utils::GemmProblem gp(1, m, n, k); - auto e_esimd = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}}); - e_esimd.wait(); + auto ev = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}}); + ev.wait(); q->memcpy(matC.data(), C_d, matC.size() * 4).wait(); buffer_error(ref.data(), matC.data(), ref.size(), 0.001f); } @@ -168,8 +171,8 @@ class UT_SyclS4SGemm { for (int j = 0; j < k; j += 2) { auto tmp = srcptr[i * k / 2 + j / 2]; auto noffset = i * blks + j / blocksize; - matB[i * k + j + 0] = static_cast(static_cast(tmp.x) << 4) * B_scale[noffset]; - matB[i * k + j + 1] = static_cast(static_cast(tmp.y) << 4) * B_scale[noffset]; + matB[i * k + j + 0] = static_cast(static_cast(tmp.x) - 8) * B_scale[noffset]; + matB[i * k + j + 1] = static_cast(static_cast(tmp.y) - 8) * B_scale[noffset]; } } avector matBNT(k * n); @@ -187,8 +190,8 @@ class UT_SyclS4SGemm { auto B_scale_d = dB_scale.data(); auto C_d = dC.data(); utils::GemmProblem gp(1, m, n, k); - auto e_esimd = KernelTLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}}); - e_esimd.wait(); + auto ev = KernelTLauncher::compute(q, m, n, k, 
blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}}); + ev.wait(); q->memcpy(matC.data(), C_d, matC.size() * 4).wait(); buffer_error(ref.data(), matC.data(), ref.size(), 0.001f); } @@ -233,8 +236,8 @@ class UT_SyclS4HGemm { for (int j = 0; j < n; j += 2) { auto tmp = srcptr[i * n / 2 + j / 2]; auto noffset = i / blocksize * n + j; - matB[i * n + j + 0] = static_cast(static_cast(tmp.x) << 4) * float(B_scale[noffset + 0]); - matB[i * n + j + 1] = static_cast(static_cast(tmp.y) << 4) * float(B_scale[noffset + 1]); + matB[i * n + j + 0] = static_cast(static_cast(tmp.x) - 8) * float(B_scale[noffset + 0]); + matB[i * n + j + 1] = static_cast(static_cast(tmp.y) - 8) * float(B_scale[noffset + 1]); } } gemmref_fp16fp16fp16(m, n, k, matA.data(), matB.data(), ref.data(), k, n, n); @@ -247,8 +250,8 @@ class UT_SyclS4HGemm { auto Bs8_d = dBs8.data(); auto B_scale_d = dB_scale.data(); auto C_d = dC.data(); - auto e_esimd = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}}); - e_esimd.wait(); + auto ev = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}}); + ev.wait(); q->memcpy(matC.data(), C_d, matC.size() * 2).wait(); buffer_error(ref.data(), matC.data(), ref.size(), utils::fp16(0.2f)); } @@ -269,8 +272,8 @@ class UT_SyclS4HGemm { for (int j = 0; j < k; j += 2) { auto tmp = srcptr[i * k / 2 + j / 2]; auto noffset = i * blks + j / blocksize; - matB[i * k + j + 0] = static_cast(static_cast(tmp.x) << 4) * float(B_scale[noffset]); - matB[i * k + j + 1] = static_cast(static_cast(tmp.y) << 4) * float(B_scale[noffset]); + matB[i * k + j + 0] = static_cast(static_cast(tmp.x) - 8) * float(B_scale[noffset]); + matB[i * k + j + 1] = static_cast(static_cast(tmp.y) - 8) * float(B_scale[noffset]); } } avector matBNT(k * n); @@ -285,8 +288,8 @@ class UT_SyclS4HGemm { auto Bs8_d = dBs8.data(); auto B_scale_d = dB_scale.data(); auto C_d = dC.data(); - auto e_esimd = KernelTLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}}); - e_esimd.wait(); + auto ev = KernelTLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}}); + ev.wait(); q->memcpy(matC.data(), C_d, matC.size() * 2).wait(); buffer_error(ref.data(), matC.data(), ref.size(), utils::fp16(0.2f)); } @@ -317,8 +320,8 @@ class UT_SyclInt4Dequant { for (int j = 0; j < n; j += 2) { auto tmp = srcptr[i * n / 2 + j / 2]; auto noffset = i / blocksize * n + j; - ref[i * n + j + 0] = static_cast(static_cast(tmp.x) << 4) * scale[noffset + 0]; - ref[i * n + j + 1] = static_cast(static_cast(tmp.y) << 4) * scale[noffset + 1]; + ref[i * n + j + 0] = static_cast(static_cast(tmp.x) - 8) * scale[noffset + 0]; + ref[i * n + j + 1] = static_cast(static_cast(tmp.y) - 8) * scale[noffset + 1]; } } using ProB = sycl_prologue_b::WeightS4; @@ -329,8 +332,8 @@ class UT_SyclInt4Dequant { auto S_d = dS.data(); auto B_d = dB.data(); auto DB_d = dequantB.data(); - auto e_esimd = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, n}, DB_d, q); - e_esimd.wait(); + auto ev = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, n}, DB_d, q); + ev.wait(); q->memcpy(dequant.data(), DB_d, dequant.size() * 4).wait(); buffer_error(ref.data(), dequant.data(), dequant.size(), 0.001f); } @@ -349,8 +352,8 @@ class UT_SyclInt4Dequant { for (int j = 0; j < k; j += 2) { auto tmp = srcptr[i * k / 2 + j / 2]; auto noffset = i * blks + j / blocksize; - ref[i * k + j + 0] = static_cast(static_cast(tmp.x) << 4) * scale[noffset]; - ref[i * k + j + 1] = 
static_cast(static_cast(tmp.y) << 4) * scale[noffset]; + ref[i * k + j + 0] = static_cast(static_cast(tmp.x) - 8) * scale[noffset]; + ref[i * k + j + 1] = static_cast(static_cast(tmp.y) - 8) * scale[noffset]; } } using ProB = sycl_prologue_b::WeightS4Trans; @@ -361,15 +364,15 @@ class UT_SyclInt4Dequant { auto S_d = dS.data(); auto B_d = dB.data(); auto DB_d = dequantB.data(); - auto e_esimd = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); - e_esimd.wait(); + auto ev = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); + ev.wait(); q->memcpy(dequant.data(), DB_d, dequant.size() * 4).wait(); buffer_error(ref.data(), dequant.data(), dequant.size(), 0.001f); avector refNT(k * n); kernel::wrapper::Transpose2D::forward(ref.data(), refNT.data(), n, k, k, n); - e_esimd = ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); - e_esimd.wait(); + ev = ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q); + ev.wait(); q->memcpy(dequant.data(), DB_d, dequant.size() * 4).wait(); buffer_error(refNT.data(), dequant.data(), dequant.size(), 0.001f); } @@ -382,7 +385,9 @@ class UT_SyclS4Gemv { public: UT_SyclS4Gemv() { UT_START(); + ut_T(1024, 11008, 32); ut_T(1024, 1024, 32); + ut_half(1024, 11008, 32); ut_half(1024, 1024, 32); } using SGemm_t = xve::DefaultSGemmCore; @@ -411,8 +416,8 @@ class UT_SyclS4Gemv { for (int j = 0; j < k; j += 2) { auto tmp = srcptr[i * k / 2 + j / 2]; auto noffset = i * blks + j / blocksize; - dqB[i + (j + 0) * n] = static_cast(static_cast(tmp.x) << 4) * scale[noffset]; - dqB[i + (j + 1) * n] = static_cast(static_cast(tmp.y) << 4) * scale[noffset]; + dqB[i + (j + 0) * n] = static_cast(static_cast(tmp.x) - 8) * scale[noffset]; + dqB[i + (j + 1) * n] = static_cast(static_cast(tmp.y) - 8) * scale[noffset]; } } gemmref_fp32fp32fp32(1, n, k, A.data(), dqB.data(), refC.data(), k, n, n); @@ -430,8 +435,8 @@ class UT_SyclS4Gemv { auto A_d = dA.data(); auto B_d = dB.data(); auto C_d = dC.data(); - auto e_esimd = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q); - e_esimd.wait(); + auto ev = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q); + ev.wait(); q->memcpy(C.data(), C_d, C.size() * 4).wait(); buffer_error(refC.data(), C.data(), C.size(), 0.001f); } @@ -452,8 +457,8 @@ class UT_SyclS4Gemv { auto tmp = srcptr[i * k / 2 + j / 2]; auto noffset = i * blks + j / blocksize; float fscale = float(scale[noffset]); - dqB[i + (j + 0) * n] = static_cast(static_cast(tmp.x) << 4) * fscale; - dqB[i + (j + 1) * n] = static_cast(static_cast(tmp.y) << 4) * fscale; + dqB[i + (j + 0) * n] = static_cast(static_cast(tmp.x) - 8) * fscale; + dqB[i + (j + 1) * n] = static_cast(static_cast(tmp.y) - 8) * fscale; } } gemmref_fp16fp16fp16(1, n, k, A.data(), dqB.data(), refC.data(), k, n, n); @@ -471,9 +476,9 @@ class UT_SyclS4Gemv { auto A_d = dA.data(); auto B_d = dB.data(); auto C_d = dC.data(); - auto e_esimd = sycl_prologue_b::WeightS4Trans::gemv(A_d, {B_d, S_d, blks}, C_d, - n, k, blocksize, q); - e_esimd.wait(); + auto ev = sycl_prologue_b::WeightS4Trans::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, + blocksize, q); + ev.wait(); q->memcpy(C.data(), C_d, C.size() * 2).wait(); buffer_error(refC.data(), C.data(), C.size(), utils::fp16(0.1f)); } @@ -481,5 +486,277 @@ class UT_SyclS4Gemv { #ifdef BTLA_UT_SYCL static UT_SyclS4Gemv sUT_SyclS4Gemv; #endif + +void mha_sref(float* Q, float* K, float* V, float* S, float* O, int batch, int seq, int seqA, int hnum, int hsize) { + avector tmps(seqA); + int nf = hnum * hsize; + const float 
attn_scale = 1.0f / sqrtf(static_cast(hsize)); + int n_past = seqA - seq; + for (int i = 0; i < batch; i++) { + for (int j = 0; j < seq; j++) { + for (int ii = 0; ii < hnum; ii++) { + float maxs = 0.f; + for (int jj = 0; jj < seqA; jj++) { + float tmp = 0.f; + if (jj <= j + n_past) { + for (int kk = 0; kk < hsize; kk++) { + tmp += + Q[i * seq * nf + j * nf + ii * hsize + kk] * K[i * nf * seqA + ii * seqA * hsize + jj * hsize + kk]; + } + tmp *= attn_scale; + } else { + tmp = -INFINITY; + } + + tmps[jj] = tmp; + maxs = std::max(maxs, tmp); + } + float sums = 0.f; + for (int jj = 0; jj < seqA; jj++) { + tmps[jj] = std::expf(tmps[jj] - maxs); + sums += tmps[jj]; + } + sums = 1.f / sums; + for (int jj = 0; jj < seqA; jj++) { + tmps[jj] *= sums; + S[i * seq * hnum * seqA + j * hnum * seqA + ii * seqA + jj] = tmps[jj]; + } + for (int kk = 0; kk < hsize; kk++) { + float tmp = 0.f; + for (int jj = 0; jj < seqA; jj++) { + tmp += tmps[jj] * V[i * nf * seqA + ii * hsize * seqA + kk * seqA + jj]; + } + O[i * seq * nf + j * nf + ii * hsize + kk] = tmp; + } + } + } + } +} + +class UT_MHASgemm { + public: + UT_MHASgemm() { + UT_START(); + ut_T(1, 1, 1, 32, 128); + ut_T(1, 1, 64, 32, 128); + ut_T(4, 1, 64, 32, 128); + ut_T(4, 64, 64, 32, 128); + } + template + class MHA { + public: + template + static sycl::event forward(int batch, int seq, int seq_acc, int hnum, int hsize, const T* Q, const T* K, const T* V, + T_DST* O, sycl::queue* q) { + const float attn_scale = 1.0f / sqrtf(static_cast(hsize)); + int constexpr SgSize = 16; + assert(hsize % SgSize == 0); + int n_past = seq_acc - seq; + if constexpr (Mask) { + assert(seq > 1); + } + int WgSize = SgSize; + int seq_acc_pad = utils::padto_le(seq_acc, WgSize * 2); + int nf = hnum * hsize; + auto ev = q->submit([&](sycl::handler& cgh) { + sycl::local_accessor slm(sycl::range(std::max(seq_acc, 1024)), cgh); + cgh.parallel_for(sycl::nd_range<1>(WgSize * batch * seq * hnum, WgSize), + [=](auto it) [[intel::reqd_sub_group_size(SgSize)]] { + auto sg = it.get_sub_group(); + auto sg_idx = sg.get_group_id()[0]; + auto wg_idx = it.get_group(0); + auto wg_loc_id = it.get_local_id(); + auto lane_id = sg.get_local_id()[0]; + + int i = wg_idx; + int ih = i % hnum; + i /= hnum; + int is = i % seq; + i /= seq; + int ib = i % batch; + size_t Q_off = ib * seq * nf + is * nf + ih * hsize; + size_t K_off = ib * seq_acc * nf + ih * hsize * seq_acc; + size_t V_off = ib * seq_acc * nf + ih * hsize * seq_acc; + size_t O_off = ib * seq * nf + is * nf + ih * hsize; + typedef sycl::vec TC; + T maxs = -INFINITY; + for (int jj = 0; jj < seq_acc; jj++) { + TC tmp = {0, 0}; + if constexpr (Mask) { + if (jj <= is + n_past) { + for (int ik = wg_loc_id * 2; ik < hsize; ik += WgSize * 2) { + tmp += *(TC*)&Q[Q_off + ik] * *(TC*)&K[K_off + jj * hsize + ik]; + } + tmp *= attn_scale; + } else { + tmp = {-INFINITY, -INFINITY}; + } + } else { + for (int ik = wg_loc_id * 2; ik < hsize; ik += WgSize * 2) { + tmp += *(TC*)&Q[Q_off + ik] * *(TC*)&K[K_off + jj * hsize + ik]; + } + tmp *= attn_scale; + } + T tmp_sum = tmp[0] + tmp[1]; + T sum = 0; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmp_sum, i); + } + slm[jj] = sum; + maxs = std::max(maxs, sum); + } + float fsums = 0.f; + float fmax = float(maxs); + int jj = wg_loc_id * 2; + for (; jj < seq_acc_pad; jj += WgSize * 2) { + auto s2 = *(TC*)&slm[jj]; + s2[0] = std::expf(s2[0] - fmax); + s2[1] = std::expf(s2[1] - fmax); + fsums += s2[0]; + fsums += s2[1]; + *(TC*)&slm[jj] = s2; + } + if (jj < seq_acc) { + slm[jj] = 
std::expf(float(slm[jj]) - fmax); + fsums += slm[jj]; + if (jj + 1 < seq_acc) { + slm[jj + 1] = std::expf(float(slm[jj + 1]) - fmax); + fsums += slm[jj + 1]; + } + } + float gsum = 0; + for (int i = 0; i < SgSize; i += 1) { + gsum += sg.shuffle(fsums, i); + } + T scale = 1.f / gsum; + jj = wg_loc_id * 2; + for (; jj < seq_acc_pad; jj += WgSize * 2) { + auto s2 = *(TC*)&slm[jj]; + s2 *= scale; + *(TC*)&slm[jj] = s2; + } + if (jj < seq_acc) { + slm[jj] *= scale; + if (jj + 1 < seq_acc) { + slm[jj + 1] *= scale; + } + } + + for (int kk = 0; kk < hsize; kk++) { + TC tmp = {0, 0}; + jj = wg_loc_id * 2; + for (; jj < seq_acc_pad; jj += WgSize * 2) { + auto s2 = *(TC*)&slm[jj]; + auto v2 = *(TC*)&V[V_off + kk * seq_acc + jj]; + tmp += s2 * v2; + } + if (jj < seq_acc) { + tmp[0] += slm[jj] * V[V_off + kk * seq_acc + jj]; + if (jj + 1 < seq_acc) { + tmp[1] += slm[jj + 1] * V[V_off + kk * seq_acc + jj + 1]; + } + } + T tmp_sum = tmp[0] + tmp[1]; + T sum = 0; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmp_sum, i); + } + O[O_off + kk] = sum; + } + }); + }); + return ev; + } + }; + + void ut_T(int batch, int seq, int seqA, int hnum, int hsize) { + auto dev = UT_Device::get(); + auto q = dev->getQueue(); + assert(seqA >= seq); + printf("Test Case %s: %d %d %d %d %d Device:%s\n", __FUNCTION__, batch, seq, seqA, hnum, hsize, + dev->getName().c_str()); + avector Q(batch * seq * hnum * hsize), K(batch * seqA * hnum * hsize), V(batch * seqA * hnum * hsize); + fill_buffer_randn(Q.data(), Q.size(), -0.5f, 0.5f); + fill_buffer_randn(K.data(), K.size(), -0.5f, 0.5f); + fill_buffer_randn(V.data(), V.size(), -0.5f, 0.5f); + avector S(batch * seq * hnum * seqA), O(batch * seq * hnum * hsize); + mha_sref(Q.data(), K.data(), V.data(), S.data(), O.data(), batch, seq, seqA, hnum, hsize); + sycl_vector dQ(batch * seq * hnum * hsize, q), dK(batch * seqA * hnum * hsize, q), + dV(batch * seqA * hnum * hsize, q); + sycl_vector dS(batch * seq * hnum * seqA, q), dO(batch * seq * hnum * hsize, q); + q->memcpy(dQ.data(), Q.data(), Q.size() * sizeof(Q[0])); + q->memcpy(dK.data(), K.data(), K.size() * sizeof(K[0])); + q->memcpy(dV.data(), V.data(), V.size() * sizeof(V[0])); + q->wait(); + auto Qptr = dQ.data(); + auto Kptr = dK.data(); + auto Vptr = dV.data(); + auto Sptr = dS.data(); + auto Optr = dO.data(); + int nf = hnum * hsize; + sycl::range<1> num_items{batch * seq * hnum}; + int n_past = seqA - seq; + const float attn_scale = 1.0f / sqrtf(static_cast(hsize)); + if (seq > 1) { + MHA::forward(batch, seq, seqA, hnum, hsize, Qptr, Kptr, Vptr, Optr, q).wait(); + } else { + MHA::forward(batch, seq, seqA, hnum, hsize, Qptr, Kptr, Vptr, Optr, q).wait(); + } + // auto ev = q->submit([&](sycl::handler& cgh) { + // cgh.parallel_for(num_items, [=](auto it) { + // int i = it; + // int ih = i % hnum; + // i /= hnum; + // int is = i % seq; + // i /= seq; + // int ib = i % batch; + // float maxs = 0.f; + // float tmps[64]; + // for (int jj = 0; jj < seqA; jj++) { + // float tmp = 0.f; + // if (jj <= is + n_past) { + // for (int kk = 0; kk < hsize; kk++) { + // tmp += Qptr[ib * seq * nf + is * nf + ih * hsize + kk] * + // Kptr[ib * nf * seqA + kk + ih * seqA * hsize + jj * hsize]; + // } + // tmp *= attn_scale; + // } else { + // tmp = -INFINITY; + // } + + // tmps[jj] = tmp; + // maxs = std::max(maxs, tmp); + // } + // float sums = 0.f; + // for (int jj = 0; jj < seqA; jj++) { + // tmps[jj] = std::expf(tmps[jj] - maxs); + // sums += tmps[jj]; + // } + // sums = 1.f / sums; + // for (int jj = 0; jj < seqA; jj++) { + // 
tmps[jj] *= sums; + // Sptr[ib * seq * hnum * seqA + is * hnum * seqA + ih * seqA + jj] = tmps[jj]; + // } + // for (int kk = 0; kk < hsize; kk++) { + // float tmp = 0.f; + // for (int jj = 0; jj < seqA; jj++) { + // tmp += tmps[jj] * Vptr[ib * seqA * nf + jj + ih * hsize * seqA + kk * seqA]; + // } + // Optr[ib * seq * nf + is * nf + ih * hsize + kk] = tmp; + // } + // }); + //}); + q->wait(); + avector STar(batch * seq * hnum * seqA), OTar(batch * seq * hnum * hsize); + q->memcpy(STar.data(), Sptr, STar.size() * sizeof(STar[0])); + q->memcpy(OTar.data(), Optr, OTar.size() * sizeof(OTar[0])); + q->wait(); + // buffer_error(S.data(), STar.data(), S.size(), 0.001f); + buffer_error(O.data(), OTar.data(), O.size(), 0.001f); + } +}; +#ifdef BTLA_UT_SYCL +#endif +static UT_MHASgemm sUT_MHASgemm; } // namespace sycl_ut } // namespace bestla diff --git a/bestla/bestla/ut/sycl_misc.cpp b/bestla/bestla/ut/sycl_misc.cpp index e81521959..2a41bfaa9 100644 --- a/bestla/bestla/ut/sycl_misc.cpp +++ b/bestla/bestla/ut/sycl_misc.cpp @@ -1,8 +1,13 @@ #include "bestla_ut.h" +#include "bestla_prologue_b.h" #include "sycl_ut.h" -#include "../sycl/sycl_device.h" -#include "../sycl/sycl_utils.h" +#include "sycl/sycl_device.h" +#include "sycl/sycl_utils.h" +#include "sycl/sycl_storage.h" +#include "sycl/sycl_gemm.h" +#include "sycl/sycl_prologue_b.h" namespace bestla { +using namespace ut; using namespace utils; namespace sycl_ut { class UT_SyclDevice { @@ -13,7 +18,7 @@ class UT_SyclDevice { dev->print(); } }; -// static UT_SyclDevice sUT_SyclDevice; +static UT_SyclDevice sUT_SyclDevice; class UT_SyclVector { public: @@ -30,5 +35,142 @@ class UT_SyclVector { } }; // static UT_SyclVector sUT_SyclVector; + +class UT_StorageMemCheck { + public: + UT_StorageMemCheck() { + UT_START(); + ut_s4(4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + } + + void ut_s4(int n, int k, int blocksize, BTLA_DTYPE qtype, bool asym = false) { + printf("Test C type Case: %d %d %d %s\n", n, k, blocksize, asym ? 
"asym" : "sym"); + int ldb = n; + int kblk_num = utils::updiv(k, blocksize); + using GemmCore = sycl_gemm::xve::DefaultSGemmCore; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; + + auto packedW = + PrologueB::createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, asym); + avector buf0(packedW.mSize), buf1(packedW.mSize); + packedW.assign(buf0.data()); + auto dev = UT_Device::get(); + auto q = dev->getQueue(); + sycl_storage::StorageWeightKBlockNInteger sycl_stor(packedW); + sycl_utils::sycl_vector dbuf(sycl_stor.getDeviceSize(), q); + sycl_stor.assign(dbuf.data()); + sycl_stor.fromHost(packedW, q); + storage::gemm::StorageWeightKBlockNInteger tmp = packedW; + tmp.assign(buf1.data()); + sycl_stor.toHost(tmp, q); + buffer_error(buf0.data(), buf1.data(), buf0.size()); + } +}; +static UT_StorageMemCheck sUT_StorageMemCheck; + +class UT_BlockQunatize_S3S4 { + public: + UT_BlockQunatize_S3S4() { + UT_START(); + using GemmCore = sycl_gemm::xve::DefaultSGemmCore; + + ut(4096, 4096, 32, BTLA_DTYPE::S4_CLIP); + } + template + void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T, bool isAsym = false) { + auto constexpr RuntimeISA = BTLA_ISA::AVX2; + auto dev = UT_Device::get(); + auto q = dev->getQueue(); + printf("%s DType %s %d: %d %d %d Asym:%d\n", __FUNCTION__, utils::bestla_dtype_str(QUANT_T), int(RuntimeISA), n, k, + blocksize, isAsym); + utils::aligned_vector raw(n * k); + ut::fill_buffer_randn(raw.data(), raw.size(), -0.5f, 0.5f); + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; + auto ptr = PrologueB::createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::F32, BTLA_DTYPE::F32, isAsym); + avector buffer(ptr.mSize); + ptr.assign(buffer.data()); + PrologueB::packTransposeWeight(n, k, raw.data(), k, &ptr, UT_Threading::get()); + auto transtor = ptr.toTrans(); + avector buffer1(transtor.mSize); + transtor.assign(buffer1.data()); + PrologueB::convertTransStorage(ptr, transtor, UT_Threading::get()); + sycl_storage::StorageWeightKBlockNInteger sycl_stor(transtor); + sycl_utils::sycl_vector dbuf(sycl_stor.getDeviceSize(), q); + sycl_stor.assign(dbuf.data()); + sycl_stor.fromHost(transtor, q); + avector dequant(n * k, 0); + using ProB = sycl_prologue_b::WeightS4Trans; + sycl_utils::sycl_vector dequantB(n * k, q); + int blks = updiv(k, blocksize); + auto evt = ProB::dequant_s4( + n, k, blocksize, {(uint8_t*)sycl_stor.mQBuf, (float*)sycl_stor.mSBuf, blks}, dequantB.data(), q); + evt.wait(); + q->memcpy(dequant.data(), dequantB.data(), dequantB.size() * 4).wait(); + ut::buffer_error(raw.data(), dequant.data(), dequant.size(), 0.01f); + } +}; +#ifdef BTLA_UT_PROLOGUE_B +// no proper threshold for this UT +// +#endif +static UT_BlockQunatize_S3S4 sUT_BlockQunatize_S3S4; + +class UT_CompFp32 { + public: + UT_CompFp32() { + UT_START(); + ut_s4(); + } + + void ut_s4() { + using GemmCore = sycl_gemm::xve::DefaultSGemmCore; + ut(1, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, false); + ut(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, false); + } + + template + void ut(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE stype, bool isAsym) { + auto dev = UT_Device::get(); + auto q = dev->getQueue(); + printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s Asym:%d\n", __FUNCTION__, m, n, k, blocksize, + bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), bestla_dtype_str(stype), isAsym); + auto constexpr RuntimeISA = BTLA_ISA::AVX2; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; + blocksize = blocksize == -1 ? 
k : blocksize; + auto packedw = PrologueB::createStorage(n, k, blocksize, qtype, stype, bestla_dtype, isAsym); + utils::avector buffer(packedw.mSize); + packedw.assign(buffer.data()); + avector matBf32(k * n), matAf32(m * k), matC(m * n), refC(m * n); + fill_buffer_randn(matBf32.data(), matBf32.size(), -0.5f, 0.5f); + fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); + PrologueB::packWeight(n, k, matBf32.data(), n, &packedw, UT_Threading::get()); + gemmref_fp32fp32fp32(m, n, k, matAf32.data(), matBf32.data(), refC.data(), k, n, n); + sycl_utils::sycl_vector dC(n, q), dA(k * m, q); + q->memcpy(dA.data(), matAf32.data(), matAf32.size() * 4).wait(); + using ProBTransT = sycl_prologue_b::WeightS4Trans; + auto transtor = packedw.toTrans(); + avector buffer1(transtor.mSize); + transtor.assign(buffer1.data()); + PrologueB::convertTransStorage(packedw, transtor, UT_Threading::get()); + sycl_storage::StorageWeightKBlockNInteger sycl_stor(transtor); + sycl_utils::sycl_vector dbuf(sycl_stor.getDeviceSize(), q); + sycl_stor.assign(dbuf.data()); + sycl_stor.fromHost(transtor, q); + int blks = updiv(k, blocksize); + auto ev = ProBTransT::gemv(dA.data(), {(uint8_t*)sycl_stor.mQBuf, (float*)sycl_stor.mSBuf, blks}, dC.data(), n, k, + blocksize, q); + ev.wait(); + q->memcpy(matC.data(), dC.data(), matC.size() * 4).wait(); + + auto err = get_ut_err(qtype); + auto dbits = bestla_dtype_bits(qtype); + auto type = bestla_dtype_type(qtype); + auto constexpr dtype_int = bestla_dtype_type(BTLA_DTYPE::TypeInt); + buffer_error(refC.data(), matC.data(), refC.size(), err); + } +}; +#ifdef BTLA_UT_PROLOGUE_B +#endif +static UT_CompFp32 sUT_CompFp32; } // namespace sycl_ut } // namespace bestla diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp index d632e8602..0822775c8 100644 --- a/neural_speed/application/main_pybind.cpp +++ b/neural_speed/application/main_pybind.cpp @@ -454,7 +454,7 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_ token_eos = false; curr_input_ids.clear(); curr_input_ids.resize(params.max_request_num); - ctx = model_init_from_gpt_params(params); + ctx = model_init_from_gpt_params(params, nullptr); n_vocab = model_n_vocab(ctx); n_ctx = model_n_ctx(ctx); last_n_tokens.resize(params.max_request_num); diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp index 745c3665a..0a441d18a 100644 --- a/neural_speed/application/main_run.cpp +++ b/neural_speed/application/main_run.cpp @@ -37,6 +37,8 @@ #include "models/model_utils/model_config.h" #include "models/model_utils/model_utils.h" +#include "core/ne_bestla.h" + #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) #include #include @@ -132,16 +134,21 @@ int main(int argc, char** argv) { // NOLINT if (params.random_prompt) { params.prompt = gpt_random_prompt(rng); } + bestla_set_threads(params.n_threads); model_init_backend(); + ne_sycl_context* dev_ctx = model_init_sycl(false); model_context* ctx; g_ctx = &ctx; - + if (dev_ctx == nullptr) { + params.n_gpu_layers = 0; + } // load the model and apply lora adapter, if any - ctx = model_init_from_gpt_params(params); + ctx = model_init_from_gpt_params(params, dev_ctx); if (ctx == nullptr) { fprintf(stderr, "%s: error: unable to load model\n", __func__); + model_release_sycl(dev_ctx); return 1; } diff --git a/neural_speed/application/pybind_gptj.cpp b/neural_speed/application/pybind_gptj.cpp index f98dc7a0a..36ceb8334 100644 --- 
a/neural_speed/application/pybind_gptj.cpp +++ b/neural_speed/application/pybind_gptj.cpp @@ -97,7 +97,7 @@ void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, fl model_init_backend(); model_context* ctx; g_ctx = &ctx; - ctx = model_init_from_gpt_params(params); + ctx = model_init_from_gpt_params(params, nullptr); if (ctx == nullptr) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return nullptr; diff --git a/neural_speed/core/CMakeLists.txt b/neural_speed/core/CMakeLists.txt index 3c5eb694b..487273c4c 100644 --- a/neural_speed/core/CMakeLists.txt +++ b/neural_speed/core/CMakeLists.txt @@ -36,9 +36,10 @@ endif() if(NOT WIN32) target_link_libraries(ne_layers PUBLIC rt) else() - target_link_options(ne_layers PUBLIC /STACK:5242880) + target_link_options(ne_layers PUBLIC /STACK:5242880 /F5242880) endif() + if (NS_BUILD_TESTS) function(add_test_target src) # ARGN: additional source diff --git a/neural_speed/core/layers/bestla_common.hpp b/neural_speed/core/layers/bestla_common.hpp index c830e101e..ee895bff4 100644 --- a/neural_speed/core/layers/bestla_common.hpp +++ b/neural_speed/core/layers/bestla_common.hpp @@ -28,8 +28,11 @@ class ne_threading { static bestla::parallel::IThreading* get() { GetCPUDevice(); static bestla::parallel::StdThreading OptmizedThreading; -#ifdef NS_USE_OMP +#if (BTLA_OPENMP && NS_USE_OMP) static bestla::parallel::OMPThreading DefaultThreading; +#ifdef NS_SYCL + return &DefaultThreading; +#endif if (!_cd->isHybrid()) { return &DefaultThreading; } diff --git a/neural_speed/core/layers/ne_bestla.cpp b/neural_speed/core/layers/ne_bestla.cpp index e9e8f1024..471845f00 100644 --- a/neural_speed/core/layers/ne_bestla.cpp +++ b/neural_speed/core/layers/ne_bestla.cpp @@ -162,3 +162,115 @@ void bestla_add(int batch, int vsize, const float* tensor, const float* vector, pth->parallel_for(threadfunc); } } + +static inline bool ne_is_contiguous(const struct ne_tensor* tensor) { + static_assert(NE_MAX_DIMS == 4, "NE_MAX_DIMS is not 4 - update this function"); + return tensor->nb[0] <= tensor->nb[1] && tensor->nb[1] <= tensor->nb[2] && tensor->nb[2] <= tensor->nb[3]; +} + +static inline int ne_nrows(const struct ne_tensor* tensor) { + static_assert(NE_MAX_DIMS == 4, "NE_MAX_DIMS is not 4 - update this function"); + return tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; +} + +ne_backend bestla_backend_support(struct ne_tensor* src0, struct ne_tensor* src1, enum ne_op op) { + ne_backend bk = NE_BACKEND_CPU; +#ifdef NS_SYCL + bool src_on_device = src0->backend == NE_BACKEND_SYCL; + if (src1) { + src_on_device |= src1->backend == NE_BACKEND_SYCL; + } + switch (op) { + case NE_OP_MUL_MAT: { + struct ne_tensor* wei = src0; + if (src0->type == NE_TYPE_BTLA) { + bk = src_on_device ? NE_BACKEND_SYCL : NE_BACKEND_CPU; + } + } break; + case NE_OP_RMS_NORM: + case NE_OP_SILU: + case NE_OP_ADD: + case NE_OP_MUL: { + if (src0->type == NE_TYPE_F32) { + bk = src_on_device ? 
NE_BACKEND_SYCL : NE_BACKEND_CPU; + } + } break; + default: + break; + } +#endif + return bk; +} + +bool bestla_support(struct ne_tensor* node, int n_threads, size_t* workspace, size_t* dev_workspace) { + size_t ws_h = 0; + size_t ws_d = 0; + bool support = false; + if (node->backend == NE_BACKEND_SYCL) { + support = true; + } + switch (node->op) { + case NE_OP_MUL_MAT_ID: + case NE_OP_MUL_MAT_BIAS: + case NE_OP_MUL_MAT: { + struct ne_tensor* wei = node->src0; + if (node->op == NE_OP_MUL_MAT_ID) { + wei = node->opt[0]; + } + if (node->src0->type == NE_TYPE_BTLA) { + if (node->src0->backend == NE_BACKEND_CPU) { + ws_h = bestla_f32f32_get_workspace_size(node->src1->ne[1], wei->ne[1], node->src1->ne[0], wei->data); + } + support = true; + } + } break; + case NE_OP_ROPE: + if (node->type == NE_TYPE_BTLA) support = true; + break; + case NE_OP_MUL: + case NE_OP_ADD: { + if (ne_is_contiguous(node->src1) && ne_is_contiguous(node->src0) && + (ne_nrows(node->src1) == 1 || ne_nrows(node->src1) == ne_nrows(node->src0)) && + node->src0->ne[0] == node->src1->ne[0] && node->nb[0] == sizeof(float)) { + support = true; + } + } break; + case NE_OP_MUL_FFN_SILU: + case NE_OP_MUL_FFN_GELU: + case NE_OP_MUL_FFN_GELU_MUL: + case NE_OP_MUL_FFN_ADD_GELU: { + if (node->src0->backend == NE_BACKEND_CPU) { + ws_h = bestla_fusion_FFN_f32f32_get_workspace_size(node->src0->ne[1], node->src0->ne[0], node->src1->ne[1], + node->opt[0]->ne[1], node->src1->data, node->opt[0]->data); + support = true; + } + } break; + case NE_OP_MUL_ID_FFN_GELU: + case NE_OP_MUL_ID_FFN_SILU: { + if (node->src0->backend == NE_BACKEND_CPU) { + ws_h = bestla_fusion_FFN_f32f32_get_workspace_size(node->src0->ne[1], node->src0->ne[0], node->opt[0]->ne[1], + node->opt[9]->ne[1], node->opt[0]->data, node->opt[9]->data); + support = true; + } + } break; + case NE_OP_MUL_QKV: { + ws_h = bestla_fusion_QKV_f32f32_get_workspace_size(node->src0->ne[1], node->src1->ne[1], node->src1->ne[0], + node->src1->data); + support = true; + } break; + case NE_OP_NORM: + case NE_OP_RMS_NORM: { + if (ne_is_contiguous(node->src0)) { + support = true; + } + } break; + default: + break; + } + if (support) { + node->n_tasks = 1; + } + *workspace = ws_h; + *dev_workspace = ws_d; + return support; +} diff --git a/neural_speed/core/layers/ne_bestla_sycl.cpp b/neural_speed/core/layers/ne_bestla_sycl.cpp new file mode 100644 index 000000000..170f5c503 --- /dev/null +++ b/neural_speed/core/layers/ne_bestla_sycl.cpp @@ -0,0 +1,880 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
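+// SYCL backend glue for neural_speed: device/queue lifetime helpers,
+// device memory management and weight-storage upload (bestla_device_load_storage),
+// plus SYCL kernels for mul/add, elementwise, RMSNorm, RoPE, dup and MHA
+// used by the NE_BACKEND_SYCL compute path.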
+#include "bestla_common.hpp" +#include "bestla_gemm.h" +using namespace bestla; // NOLINT +using namespace ne_bestla; // NOLINT + +#ifdef NS_SYCL +#include "bestla/sycl/sycl_device.h" +#include "bestla/sycl/sycl_storage.h" +#include "bestla/sycl/sycl_gemm.h" +#include "bestla/sycl/sycl_prologue_b.h" +#include "bestla/sycl/sycl_wrapper.h" + +void* bestla_create_device(bool profile) { + auto ptr = new sycl_device::SyclDevice(profile); + ptr->print(); + return ptr; +} + +void* bestla_get_device_queue(void* device) { + if (device) { + auto ptr = (sycl_device::SyclDevice*)device; + auto q = ptr->getQueue(); + return q; + } + return NULL; +} + +void bestla_release_device(void* device) { + if (device) { + auto ptr = (sycl_device::SyclDevice*)device; + delete ptr; + } +} + +size_t bestla_device_gmem_size(void* device) { + if (device) { + auto ptr = (sycl_device::SyclDevice*)device; + return ptr->getGlobalMemSize(); + } +} + +void* bestla_device_malloc(size_t size, void* queue) { + if (queue) { + auto ptr = (sycl::queue*)queue; + auto tmp = sycl::malloc_device(size, *ptr); + return tmp; + } +} + +void bestla_device_free(void* obj, void* queue) { + if (queue && obj) { + auto ptr = (sycl::queue*)queue; + sycl::free(obj, *ptr); + } +} + +void bestla_device_memcpy_sync(void* dstptr, const void* srcptr, size_t size, void* queue) { + if (queue && srcptr && dstptr) { + auto ptr = (sycl::queue*)queue; + ptr->memcpy(dstptr, srcptr, size); + ptr->wait(); + } +} + +void bestla_device_memcpy(void* dstptr, const void* srcptr, size_t size, void* queue) { + if (queue && srcptr && dstptr) { + auto ptr = (sycl::queue*)queue; + ptr->memcpy(dstptr, srcptr, size); + } +} + +void bestla_device_sync(void* queue) { + if (queue) { + auto ptr = (sycl::queue*)queue; + ptr->wait(); + } +} + +size_t bestla_device_storage_size() { return sizeof(sycl_storage::StorageWeightKBlockNInteger); } + +void bestla_device_load_storage(void* hoststor, void* devstor, void* deviceptr, void* device_queue) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(const_cast(hoststor)); + GetCPUDevice(); + if (ptr && devstor && deviceptr) { + auto dstor = (sycl_storage::StorageWeightKBlockNInteger*)devstor; + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto sptr = reinterpret_cast(ptr); + auto transtor = sptr->toTrans(); + utils::avector buffer1(transtor.mSize); + transtor.assign(buffer1.data()); + auto coretype = sptr->mCoreId; + auto NTile = gemm::CoreAttr::get_mask_val(sptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(sptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(sptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE) { + prologue_b::gemm::WeightKBlockNInteger::convertTransStorage(*sptr, transtor, + ne_bestla::ne_threading::get()); + } else if (NTile == tAVX2::NTILE) { + prologue_b::gemm::WeightKBlockNInteger::convertTransStorage(*sptr, transtor, + ne_bestla::ne_threading::get()); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE) { + prologue_b::gemm::WeightKBlockNInteger::convertTransStorage( + *sptr, transtor, ne_bestla::ne_threading::get()); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE) { + prologue_b::gemm::WeightKBlockNInteger::convertTransStorage( + *sptr, transtor, ne_bestla::ne_threading::get()); + } else if (NTile == tAVX_VNNI_KBlock::NTILE) { + 
prologue_b::gemm::WeightKBlockNInteger::convertTransStorage(*sptr, transtor, + ne_bestla::ne_threading::get()); + } + } + if (btype == gemm::CompType::tBF16 && PackRow == 2) { + if (NTile == tAMX_BF16::NTILE) { + prologue_b::gemm::WeightKBlockNInteger::convertTransStorage(*sptr, transtor, + ne_bestla::ne_threading::get()); + } + } + *dstor = sycl_storage::StorageWeightKBlockNInteger(transtor); + dstor->assign((int8_t*)deviceptr); + dstor->fromHost(transtor, (sycl::queue*)device_queue); + } + } + if (ptr) { + delete ptr; + } +} + +template +using ProAT = sycl_prologue_a::ActivationBase; +template +using ProBTransT = sycl_prologue_b::WeightS4Trans; +template +using EpiT = sycl_epilogue::OutputBase; +void bestla_device_f32f32_forward(float* activation, void* weiptr, float* output, int _m, int _n, int _k, int lda, + int ldo, void* workspace, void* queue) { + using GemmCore = sycl_gemm::xve::DefaultSGemmCore; + auto dstor = (sycl_storage::StorageWeightKBlockNInteger*)weiptr; + auto q = (sycl::queue*)queue; + if (_m == 1) { + using ProB = ProBTransT; + ProB::gemv(activation, {(uint8_t*)dstor->mQBuf, (float*)dstor->mSBuf, dstor->mCStep}, output, _n, _k, + dstor->mBlockSize, q); + } else { + using KernelTLauncher = sycl_wrapper::LauncherWOQ; + utils::GemmProblem gp(1, _m, _n, _k); + KernelTLauncher::compute( + q, _m, _n, _k, dstor->mBlockSize, + {{activation, lda}, {(uint8_t*)dstor->mQBuf, (float*)dstor->mSBuf, dstor->mCStep}, {output, ldo}}); + } + if (sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} + +void bestla_device_mul_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + + auto q = (sycl::queue*)params->dev_queue; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = ne11 == 1 ? 0 : src1->nb[1]; + const size_t nb12 = ne12 == 1 ? 0 : src1->nb[2]; + const size_t nb13 = ne13 == 1 ? 
0 : src1->nb[3]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + sycl::range<1> num_items{ne00 * ne01 * ne02 * ne03}; + auto src0ptr = (float*)src0->data; + auto src1ptr = (float*)src1->data; + auto dstptr = (float*)dst->data; + auto ev = q->submit([&](sycl::handler& cgh) { + cgh.parallel_for(num_items, [=](auto it) { + int i = it; + int i00 = i % ne00; + i /= ne00; + int i01 = i % ne01; + i /= ne01; + int i02 = i % ne02; + i /= ne02; + int i03 = i % ne03; + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + float* dst_ptr = (float*)((char*)dstptr + i03 * nb3 + i02 * nb2 + i01 * nb1); + float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01); + float* src1_ptr = (float*)((char*)src1ptr + i13 * nb13 + i12 * nb12 + i11 * nb11); + dst_ptr[i00] = src0_ptr[i00] * src1_ptr[i00]; + }); + }); + if (sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} + +void bestla_device_add_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + + auto q = (sycl::queue*)params->dev_queue; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = ne11 == 1 ? 0 : src1->nb[1]; + const size_t nb12 = ne12 == 1 ? 0 : src1->nb[2]; + const size_t nb13 = ne13 == 1 ? 
0 : src1->nb[3]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + sycl::range<1> num_items{ne00 * ne01 * ne02 * ne03}; + auto src0ptr = (float*)src0->data; + auto src1ptr = (float*)src1->data; + auto dstptr = (float*)dst->data; + auto ev = q->submit([&](sycl::handler& cgh) { + cgh.parallel_for(num_items, [=](auto it) { + int i = it; + int i00 = i % ne00; + i /= ne00; + int i01 = i % ne01; + i /= ne01; + int i02 = i % ne02; + i /= ne02; + int i03 = i % ne03; + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + float* dst_ptr = (float*)((char*)dstptr + i03 * nb3 + i02 * nb2 + i01 * nb1); + float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01); + float* src1_ptr = (float*)((char*)src1ptr + i13 * nb13 + i12 * nb12 + i11 * nb11); + dst_ptr[i00] = src0_ptr[i00] + src1_ptr[i00]; + }); + }); + if (sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} + +void bestla_device_elewise_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + + auto q = (sycl::queue*)params->dev_queue; + auto op = dst->op; + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + auto srcptr = (float*)src0->data; + auto dstptr = (float*)dst->data; + sycl::range<1> num_items{ne00 * ne01 * ne02 * ne03}; + auto ev = q->submit([&](sycl::handler& cgh) { + cgh.parallel_for(num_items, [=](auto it) { + int i = it; + float srcval = srcptr[i]; + if (op == NE_OP_SILU) { + srcval = ne_silu_f32(srcval); + } + dstptr[i] = srcval; + }); + }); + if (sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} + +void bestla_device_rms_norm_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + auto q = (sycl::queue*)params->dev_queue; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + int64_t constexpr WgSize = 1024; + int constexpr SgSize = 16; + int64_t ne00_ = bestla::utils::padto_le(ne00, WgSize); + auto src0ptr = (float*)src0->data; + auto dstptr = (float*)dst->data; + auto ev = q->submit([&](sycl::handler& cgh) { + sycl::local_accessor slm(sycl::range(WgSize), cgh); + cgh.parallel_for(sycl::nd_range<1>(ne01 * ne02 * ne03 * WgSize, WgSize), + [=](auto it) [[intel::reqd_sub_group_size(SgSize)]] { + auto sg = it.get_sub_group(); + auto sg_idx = sg.get_group_id()[0]; + auto wg_idx = it.get_group(0); + auto wg_loc_id = it.get_local_id(); + auto lane_id = sg.get_local_id()[0]; + int i = wg_idx; + int i01 = i % ne01; + i /= ne01; + int i02 = i % ne02; + i /= ne02; + int i03 = i % ne03; + + float* dst_ptr = (float*)((char*)dstptr + i03 * nb3 + i02 * nb2 + i01 * nb1); + float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01); + float sum = 0.0; + int64_t i00 = wg_loc_id; + for (; i00 < ne00_; i00 += 
WgSize) { + sum += (src0_ptr[i00] * src0_ptr[i00]); + } + if (i00 < ne00) { + sum += (src0_ptr[i00] * src0_ptr[i00]); + } + slm[wg_loc_id] = sum; + it.barrier(sycl::access::fence_space::local_space); + if (sg_idx == 0) { + for (size_t i = wg_loc_id; i < WgSize - SgSize; i += SgSize) { + sum += slm[i + SgSize]; + } + float gsum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + gsum += sg.shuffle(sum, i); + } + float mean = gsum / ne00; + const float scale = 1.0f / sqrtf(mean + eps); + slm[0] = scale; + } + it.barrier(sycl::access::fence_space::local_space); + + float scale = slm[0]; + i00 = wg_loc_id; + for (; i00 < ne00_; i00 += WgSize) { + dst_ptr[i00] = src0_ptr[i00] * scale; + } + if (i00 < ne00) { + dst_ptr[i00] = src0_ptr[i00] * scale; + } + }); + }); + if (sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} + +extern void ggml_rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, + float dims[2]); + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / std::max(0.001f, high - low); + return 1.0f - std::min(1.0f, std::max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn(float theta_extrap, float freq_scale, float corr_dims0, float corr_dims1, int64_t i0, + float ext_factor, float mscale, float* cos_theta, float* sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims0, corr_dims1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +void bestla_device_rope_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + auto q = (sycl::queue*)params->dev_queue; + const int bs = src0->ne[3]; + NE_ASSERT(src1->type == NE_TYPE_I32); + + const float freq_base = ((float*)(dst->op_params))[0]; + const float freq_scale = 1 / ((float*)(dst->op_params))[1]; + const int n_orig_ctx = (int)((float*)(dst->op_params))[2]; + const float ext_factor = ((float*)(dst->op_params))[3]; + const float attn_factor = ((float*)(dst->op_params))[4]; + const float beta_fast = ((float*)(dst->op_params))[5]; + const float beta_slow = ((float*)(dst->op_params))[6]; + const float scale_factor = ((float*)(dst->op_params))[7]; +#define ROPE_PARAMS_NUM 5 +#define ROPE_NPAST_IDX 0 +#define ROPE_NDIMS_IDX 1 +#define ROPE_MODE_IDX 2 +#define ROPE_PROMPTSIZE_IDX 3 +#define ROPE_NKEEP_IDX 4 +#define ROPE_PADDING_IDX 5 + const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX]; + const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX]; + const int64_t mode = ((int32_t*)src1->data)[ROPE_MODE_IDX]; + const int64_t prompt_size = ((int32_t*)src1->data)[ROPE_PROMPTSIZE_IDX]; + const int64_t n_keep = ((int32_t*)src1->data)[ROPE_NKEEP_IDX]; + assert(n_past >= 0); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + 
const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const int nr = ne1 * ne2 * ne3; + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + const float inv_ndims = -1.f / n_dims; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + float corr_dims0 = corr_dims[0]; + float corr_dims1 = corr_dims[1]; + int constexpr SgSize = 16; + auto src0ptr = (float*)src0->data; + auto dstptr = (float*)dst->data; + auto ev = q->submit([&](sycl::handler& cgh) { + // sycl::local_accessor slm(sycl::range(WgSize), cgh); + cgh.parallel_for(sycl::nd_range<1>(nr * SgSize, SgSize), [=](auto it) [[intel::reqd_sub_group_size(SgSize)]] { + auto sg = it.get_sub_group(); + auto sg_idx = sg.get_group_id()[0]; + auto wg_idx = it.get_group(0); + auto wg_loc_id = it.get_local_id(); + auto lane_id = sg.get_local_id()[0]; + int i = wg_idx; + int i1 = i % ne1; + i /= ne1; + int i2 = i % ne2; + i /= ne2; + int i3 = i % ne3; + + const int64_t p = n_past + i2; + float theta_base = (float)p; + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + float cos_theta, sin_theta; + rope_yarn(theta_base, freq_scale, corr_dims0, corr_dims1, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + theta_base *= theta_scale; + + const float* const src = (float*)((char*)src0ptr + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00); + float* dst_data = (float*)((char*)dstptr + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0 * cos_theta - x1 * sin_theta; + dst_data[1] = x0 * sin_theta + x1 * cos_theta; + } + }); + }); + if (sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} + +void bestla_device_dup_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + + auto q = (sycl::queue*)params->dev_queue; + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + auto srcptr = (float*)src0->data; + auto dstptr = (float*)dst->data; + auto dtype = dst->type; + sycl::range<1> num_items{ne0 * ne1 * ne2 * ne3}; + auto ev = q->submit([&](sycl::handler& cgh) { + cgh.parallel_for(num_items, [=](auto it) { + int i = it; + int i0 = i % ne0; + i /= ne0; + int i1 = i % ne1; + i /= ne1; + int i2 = i % ne2; + i /= ne2; + int i3 = i % ne3; + float srcval = *(float*)((char*)srcptr + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03); + auto dptr = (char*)dstptr + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3; + if (dtype == NE_TYPE_F32) { + *(float*)dptr = srcval; + } else if (dtype == NE_TYPE_F16) { + *(sycl::half*)dptr = srcval; + } + }); + }); + if 
(sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} + +template +class MHA { + public: + template + static sycl::event forward(int batch, int seq, int seq_acc, int hnum, int hsize, int n_ctx, const T* Q, const T* K, + const T* V, T_DST* O, float attn_scale, sycl::queue* q) { + int constexpr SgSize = 16; + assert(hsize % SgSize == 0); + int n_past = seq_acc - seq; + if constexpr (Mask) { + assert(seq > 1); + } + int WgSize = SgSize; + int seq_acc_pad = utils::padto_le(seq_acc, WgSize * 2); + int nf = hnum * hsize; + auto ev = q->submit([&](sycl::handler& cgh) { + sycl::local_accessor slm(sycl::range(std::max(seq_acc, 1024)), cgh); + cgh.parallel_for(sycl::nd_range<1>(WgSize * batch * seq * hnum, WgSize), + [=](auto it) [[intel::reqd_sub_group_size(SgSize)]] { + auto sg = it.get_sub_group(); + auto sg_idx = sg.get_group_id()[0]; + auto wg_idx = it.get_group(0); + auto wg_loc_id = it.get_local_id(); + auto lane_id = sg.get_local_id()[0]; + + int i = wg_idx; + int ih = i % hnum; + i /= hnum; + int is = i % seq; + i /= seq; + int ib = i % batch; + size_t Q_off = ib * seq * nf + is * nf + ih * hsize; + size_t K_off = ib * n_ctx * nf + ih * hsize * n_ctx; + size_t V_off = ib * n_ctx * nf + ih * hsize * n_ctx; + size_t O_off = ib * seq * nf + is * nf + ih * hsize; + typedef sycl::vec TC; + T maxs = -INFINITY; + for (int jj = 0; jj < seq_acc; jj++) { + TC tmp = {0, 0}; + if constexpr (Mask) { + if (jj <= is + n_past) { + for (int ik = wg_loc_id * 2; ik < hsize; ik += WgSize * 2) { + tmp += *(TC*)&Q[Q_off + ik] * *(TC*)&K[K_off + jj * hsize + ik]; + } + tmp *= attn_scale; + } else { + tmp = {-INFINITY, -INFINITY}; + } + } else { + for (int ik = wg_loc_id * 2; ik < hsize; ik += WgSize * 2) { + tmp += *(TC*)&Q[Q_off + ik] * *(TC*)&K[K_off + jj * hsize + ik]; + } + tmp *= attn_scale; + } + T tmp_sum = tmp[0] + tmp[1]; + T sum = 0; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmp_sum, i); + } + slm[jj] = sum; + maxs = std::max(maxs, sum); + } + float fsums = 0.f; + float fmax = float(maxs); + int jj = wg_loc_id * 2; + for (; jj < seq_acc_pad; jj += WgSize * 2) { + auto s2 = *(TC*)&slm[jj]; + s2[0] = std::expf(s2[0] - fmax); + s2[1] = std::expf(s2[1] - fmax); + fsums += s2[0]; + fsums += s2[1]; + *(TC*)&slm[jj] = s2; + } + if (jj < seq_acc) { + slm[jj] = std::expf(float(slm[jj]) - fmax); + fsums += slm[jj]; + if (jj + 1 < seq_acc) { + slm[jj + 1] = std::expf(float(slm[jj + 1]) - fmax); + fsums += slm[jj + 1]; + } + } + float gsum = 0; + for (int i = 0; i < SgSize; i += 1) { + gsum += sg.shuffle(fsums, i); + } + T scale = 1.f / gsum; + jj = wg_loc_id * 2; + for (; jj < seq_acc_pad; jj += WgSize * 2) { + auto s2 = *(TC*)&slm[jj]; + s2 *= scale; + *(TC*)&slm[jj] = s2; + } + if (jj < seq_acc) { + slm[jj] *= scale; + if (jj + 1 < seq_acc) { + slm[jj + 1] *= scale; + } + } + + for (int kk = 0; kk < hsize; kk++) { + TC tmp = {0, 0}; + jj = wg_loc_id * 2; + for (; jj < seq_acc_pad; jj += WgSize * 2) { + auto s2 = *(TC*)&slm[jj]; + auto v2 = *(TC*)&V[V_off + kk * n_ctx + jj]; + tmp += s2 * v2; + } + if (jj < seq_acc) { + tmp[0] += slm[jj] * V[V_off + kk * n_ctx + jj]; + if (jj + 1 < seq_acc) { + tmp[1] += slm[jj + 1] * V[V_off + kk * n_ctx + jj + 1]; + } + } + T tmp_sum = tmp[0] + tmp[1]; + T sum = 0; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmp_sum, i); + } + O[O_off + kk] = sum; + } + }); + }); + return ev; + } + + template + static sycl::event forward1(int batch, int seq, int seq_acc, int hnum, int hsize, int n_ctx, const T* Q, const T* K, + const T* V, T_DST* O, float 
attn_scale, sycl::queue* q) { + int constexpr SgSize = 16; + static_assert(HSize % SgSize == 0); + int constexpr SgUnroll = HSize / SgSize; + assert(hsize % HSize == 0); + assert(hsize % SgSize == 0); + int n_past = seq_acc - seq; + if constexpr (Mask) { + assert(seq > 1); + } + int constexpr WgSize = SgSize; + int seq_acc_pad = utils::padto_le(seq_acc, WgSize); + int nf = hnum * hsize; + auto ev = q->submit([&](sycl::handler& cgh) { + sycl::local_accessor slm(sycl::range(std::max(seq_acc, 1024)), cgh); + cgh.parallel_for(sycl::nd_range<1>(WgSize * batch * seq * hnum, WgSize), + [=](auto it) [[intel::reqd_sub_group_size(SgSize)]] { + auto sg = it.get_sub_group(); + auto sg_idx = sg.get_group_id()[0]; + auto wg_idx = it.get_group(0); + auto wg_loc_id = it.get_local_id(); + auto lane_id = sg.get_local_id()[0]; + + int i = wg_idx; + int ih = i % hnum; + i /= hnum; + int is = i % seq; + i /= seq; + int ib = i % batch; + size_t Q_off = ib * seq * nf + is * nf + ih * hsize; + size_t K_off = ib * n_ctx * nf + ih * hsize * n_ctx; + size_t V_off = ib * n_ctx * nf + ih * hsize * n_ctx; + size_t O_off = ib * seq * nf + is * nf + ih * hsize; + + T maxs = -INFINITY; + for (int jj = 0; jj < seq_acc; jj++) { + T tmp = 0; + if constexpr (Mask) { + if (jj <= is + n_past) { + for (int ik = wg_loc_id * SgUnroll; ik < hsize; ik += SgUnroll * SgSize) { +#pragma unroll + for (int ir = 0; ir < SgUnroll; ir++) { + tmp += Q[Q_off + ik + ir] * K[K_off + jj * hsize + ik + ir]; + } + } + tmp *= attn_scale; + } else { + tmp = -INFINITY; + } + } else { + for (int ik = wg_loc_id * SgUnroll; ik < hsize; ik += SgUnroll * SgSize) { +#pragma unroll + for (int ir = 0; ir < SgUnroll; ir++) { + tmp += Q[Q_off + ik + ir] * K[K_off + jj * hsize + ik + ir]; + } + } + tmp *= attn_scale; + } + T sum = 0; +#pragma unroll + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmp, i); + } + slm[jj] = sum; + maxs = std::max(maxs, sum); + } + float fsums = 0.f; + float fmax = float(maxs); + int jj = wg_loc_id; + for (; jj < seq_acc_pad; jj += SgSize) { + auto s = slm[jj]; + s = std::expf(s - fmax); + fsums += s; + slm[jj] = s; + } + if (jj < seq_acc) { + auto s = std::expf(float(slm[jj]) - fmax); + fsums += s; + slm[jj] = s; + } + float gsum = 0; +#pragma unroll + for (int i = 0; i < SgSize; i += 1) { + gsum += sg.shuffle(fsums, i); + } + T scale = 1.f / gsum; + jj = wg_loc_id; + for (; jj < seq_acc_pad; jj += WgSize) { + slm[jj] *= scale; + } + if (jj < seq_acc) { + slm[jj] *= scale; + } + + T tmp[SgUnroll]; + for (int kk = wg_loc_id * SgUnroll; kk < hsize; kk += SgUnroll * SgSize) { +#pragma unroll + for (int ir = 0; ir < SgUnroll; ir++) { + tmp[ir] = 0; + } + for (int ijj = 0; ijj < seq_acc; ijj += 1) { + auto s = slm[ijj]; +#pragma unroll + for (int ir = 0; ir < SgUnroll; ir++) { + auto v = V[V_off + (kk + ir) * n_ctx + ijj]; + tmp[ir] += s * v; + } + } +#pragma unroll + for (int ir = 0; ir < SgUnroll; ir++) { + O[O_off + kk + ir] = tmp[ir]; + } + } + }); + }); + return ev; + } +}; +void bestla_device_mha_f32(const struct ne_compute_params* params, const struct ne_tensor* _q, + const struct ne_tensor* k, const struct ne_tensor* v, struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + auto q = (sycl::queue*)params->dev_queue; + const int64_t neq0 = _q->ne[0]; + const int64_t neq1 = _q->ne[1]; + const int64_t neq2 = _q->ne[2]; + const int64_t neq3 = _q->ne[3]; + + const int64_t nek0 = k->ne[0]; + const int64_t nek1 = k->ne[1]; + const int64_t nek2 = k->ne[2]; + // 
const int64_t nek3 = k->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + const int64_t headsize = neq0; + const int64_t headnum = neq1; + const int64_t heads_kv = nek2; + const int64_t embedsize = headnum * headsize; + const int64_t seq_cur = neq2; + const int64_t seq_all = nek1; + const int64_t batch = neq3; + auto scale = *(float*)dst->padding; + auto n_ctx = *(uint32_t*)&dst->padding[4]; + auto Qptr = (float*)_q->data; + auto Kptr = (float*)k->data; + auto Vptr = (float*)v->data; + auto Optr = (float*)dst->data; + if (seq_cur > 1) { + MHA::forward1(batch, seq_cur, seq_all, headnum, headsize, n_ctx, Qptr, Kptr, Vptr, Optr, + scale, q); + } else { + MHA::forward1(batch, seq_cur, seq_all, headnum, headsize, n_ctx, Qptr, Kptr, Vptr, Optr, + scale, q); + } + if (sycl_device::SyclDevice::is_cpu(q)) { + q->wait(); + } +} +#endif diff --git a/neural_speed/core/ne.h b/neural_speed/core/ne.h index 9e029b6d9..862934473 100644 --- a/neural_speed/core/ne.h +++ b/neural_speed/core/ne.h @@ -49,11 +49,7 @@ #define NE_SIZE_CALC -1 -#if __AVX512F__ #define NE_ALIGNMENT 64 -#else -#define NE_ALIGNMENT 32 -#endif #define NE_ASSERT(x) \ do { \ @@ -98,7 +94,7 @@ struct ne_context; enum ne_backend { NE_BACKEND_CPU = 0, - NE_BACKEND_CUDA = 1, + NE_BACKEND_SYCL = 1, }; // ne object @@ -124,6 +120,19 @@ struct ne_scratch { // ne context // +#define MAX_SYCL_BUFFER_SIZE (4000ull << 20) // 4GB +#define MAX_SYCL_BUFFER_COUNT 64 // 32*4GB=128GB + +struct ne_sycl_context { + void* dev; + void* queue; + int n_buffers; + void* buffers[MAX_SYCL_BUFFER_COUNT]; + size_t offs[MAX_SYCL_BUFFER_COUNT]; + size_t offs_save[MAX_SYCL_BUFFER_COUNT]; + size_t sizes[MAX_SYCL_BUFFER_COUNT]; +}; + struct ne_context { size_t mem_size; void* mem_buffer; @@ -137,6 +146,9 @@ struct ne_context { struct ne_scratch scratch; struct ne_scratch scratch_save; + + struct ne_object* objects_save; + struct ne_sycl_context* dev_ctx; }; struct ne_context_container { @@ -197,6 +209,9 @@ struct ne_cgraph { size_t work_size; struct ne_tensor* work; + size_t dev_work_size; + struct ne_tensor* dev_work; + struct ne_tensor* nodes[NE_MAX_NODES]; struct ne_tensor* grads[NE_MAX_NODES]; struct ne_tensor* leafs[NE_MAX_NODES]; @@ -232,6 +247,11 @@ struct ne_compute_params { // work buffer for all threads size_t wsize; void* wdata; + + size_t dev_wsize; + void* dev_wdata; + + void* dev_queue; }; #ifdef __cplusplus diff --git a/neural_speed/core/ne_bestla.h b/neural_speed/core/ne_bestla.h index 77713d35a..414c7d81d 100644 --- a/neural_speed/core/ne_bestla.h +++ b/neural_speed/core/ne_bestla.h @@ -78,6 +78,38 @@ void bestla_layernormalization(int norm_count, int norm_size, bool isrms, float void bestla_mul(int batch, int vsize, const float* tensor, const float* vector, int vstep, float* out); void bestla_add(int batch, int vsize, const float* tensor, const float* vector, int vstep, float* out); + +enum ne_backend bestla_backend_support(struct ne_tensor* src0, struct ne_tensor* src1, enum ne_op op); +bool bestla_support(struct ne_tensor* node, int n_threads, size_t* workspace, size_t* dev_workspace); + +#ifdef NS_SYCL +void* bestla_create_device(bool profile); +void* bestla_get_device_queue(void* device); +void bestla_release_device(void* device); +size_t bestla_device_gmem_size(void* device); +void* bestla_device_malloc(size_t size, void* queue); +void bestla_device_free(void* ptr, void* queue); +void bestla_device_memcpy(void* dstptr, const void* srcptr, size_t size, void* queue); +void bestla_device_memcpy_sync(void* dstptr, const 
void* srcptr, size_t size, void* queue); +void bestla_device_sync(void* queue); +size_t bestla_device_storage_size(); +void bestla_device_load_storage(void* hoststor, void* devstor, void* deviceptr, void* queue); +void bestla_device_f32f32_forward(float* activation, void* weiptr, float* output, int _m, int _n, int _k, int lda, + int ldo, void* workspace, void* queue); +void bestla_device_mul_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst); +void bestla_device_add_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst); +void bestla_device_elewise_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst); +void bestla_device_rms_norm_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst); +void bestla_device_rope_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + const struct ne_tensor* src1, struct ne_tensor* dst); +void bestla_device_dup_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, struct ne_tensor* dst); +void bestla_device_mha_f32(const struct ne_compute_params* params, const struct ne_tensor* q, const struct ne_tensor* k, + const struct ne_tensor* v, struct ne_tensor* dst); +#endif #ifdef __cplusplus } #endif diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index 32452fc8c..278ec50c7 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -581,7 +581,9 @@ int ne_nrows(const struct ne_tensor* tensor) { size_t ne_nbytes(const struct ne_tensor* tensor) { static_assert(NE_MAX_DIMS == 4, "NE_MAX_DIMS is not 4 - update this function"); - + if (tensor->type == NE_TYPE_BTLA) { + return tensor->size; + } return (ne_nelements(tensor) * NE_TYPE_SIZE[tensor->type]) / NE_BLCK_SIZE[tensor->type]; } @@ -794,27 +796,26 @@ struct ne_context* ne_init(struct ne_init_params params) { const size_t mem_size = (params.mem_size + NE_MEM_ALIGN - 1) & ~(NE_MEM_ALIGN - 1); - *ctx = (struct ne_context){ - /*.mem_size =*/mem_size, - /*.mem_buffer =*/params.mem_buffer ? params.mem_buffer : NE_ALIGNED_MALLOC(mem_size), - /*.mem_buffer_owned =*/params.mem_buffer ? false : true, - /*.no_alloc =*/params.no_alloc, - /*.n_objects =*/0, - /*.objects_begin =*/NULL, - /*.objects_end =*/NULL, - /*.scratch =*/ - { - 0, - 0, - NULL, - }, - /*.scratch_save =*/ - { - 0, - 0, - NULL, - }, - }; + *ctx = + (struct ne_context){/*.mem_size =*/mem_size, + /*.mem_buffer =*/params.mem_buffer ? params.mem_buffer : NE_ALIGNED_MALLOC(mem_size), + /*.mem_buffer_owned =*/params.mem_buffer ? false : true, + /*.no_alloc =*/params.no_alloc, + /*.n_objects =*/0, + /*.objects_begin =*/NULL, + /*.objects_end =*/NULL, + /*.scratch =*/ + { + 0, + 0, + NULL, + }, + /*.scratch_save =*/ + { + 0, + 0, + NULL, + }}; NE_ASSERT(ctx->mem_buffer != NULL); @@ -860,6 +861,18 @@ size_t ne_used_mem(const struct ne_context* ctx) { return ctx->objects_end == NULL ? 
0 : ctx->objects_end->offs + ctx->objects_end->size; } +void ne_buffer_save(struct ne_context* ctx) { + if (ctx->dev_ctx) { + memcpy(ctx->dev_ctx->offs_save, ctx->dev_ctx->offs, ctx->dev_ctx->n_buffers * sizeof(ctx->dev_ctx->offs_save[0])); + } +} + +void ne_buffer_load(struct ne_context* ctx) { + if (ctx->dev_ctx) { + memcpy(ctx->dev_ctx->offs, ctx->dev_ctx->offs_save, ctx->dev_ctx->n_buffers * sizeof(ctx->dev_ctx->offs_save[0])); + } +} + size_t ne_set_scratch(struct ne_context* ctx, struct ne_scratch scratch) { const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0; @@ -888,8 +901,169 @@ static void ne_set_op_params(struct ne_tensor* tensor, const void* params, size_ memcpy(tensor->op_params, params, params_size); } +#ifdef NS_SYCL +struct ne_tensor* ne_new_device_tensor_impl(struct ne_context* ctx, enum ne_type type, int n_dims, const int64_t* ne, + void* data, size_t size) { + // always insert objects at the end of the context's memory pool + struct ne_object* obj_cur = ctx->objects_end; + + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; + + size_t size_needed = 0; + assert(!ctx->no_alloc); + char* const mem_buffer = (char* const)ctx->mem_buffer; + struct ne_object* const obj_new = (struct ne_object*)(mem_buffer + cur_end); + char* dptr = (char*)data; + int SYCL_ALIGN = 256; + if (type == NE_TYPE_BTLA) { + size_needed = size; + size_needed = ((size_needed + SYCL_ALIGN - 1) / SYCL_ALIGN) * SYCL_ALIGN; + } else { + size_needed += NE_TYPE_SIZE[type] * (ne[0] / NE_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { + size_needed *= ne[i]; + } + size_needed = ((size_needed + SYCL_ALIGN - 1) / SYCL_ALIGN) * SYCL_ALIGN; + } + int buf_idx = -1; + for (int i = 0; i < ctx->dev_ctx->n_buffers; i++) { + if (ctx->dev_ctx->offs[i] + size_needed <= ctx->dev_ctx->sizes[i]) { + buf_idx = i; + break; + } + } + if (buf_idx == -1) { + NE_PRINT("%s: %d device memory pool is not enough, please increase\n", __func__, __LINE__); + assert(false); + return NULL; + } + if (dptr == NULL) { + dptr = (char* const)ctx->dev_ctx->buffers[buf_idx] + ctx->dev_ctx->offs[buf_idx]; + ctx->dev_ctx->offs[buf_idx] += size_needed; + } + size_t obj_size = sizeof(struct ne_tensor); + if (type == NE_TYPE_BTLA) { + obj_size += bestla_device_storage_size(); + } + if (ctx->scratch.data == NULL) { + if (cur_end + obj_size + NE_OBJECT_SIZE > ctx->mem_size) { + NE_PRINT( + "%s: %d Context's memory pool is not enough(current %zu MB, ctx->mem_size available %zu MB), please increase " + "the scratch_size_ratio.\n", + __func__, __LINE__, (cur_end + size_needed + NE_OBJECT_SIZE) / 1024 / 1024, ctx->mem_size / 1024 / 1024); + assert(false); + return NULL; + } + + *obj_new = (struct ne_object){ + .offs = cur_end + NE_OBJECT_SIZE, + .size = obj_size, + .next = NULL, + }; + } else { + if (cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE > ctx->mem_size) { + NE_PRINT("%s: %d not enough space in the context's memory pool (needed %zu, ctx->mem_size available %zu)\n", + __func__, __LINE__, cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + *obj_new = (struct ne_object){ + .offs = cur_end + NE_OBJECT_SIZE, + .size = sizeof(struct ne_tensor), + .next = NULL, + }; + if (type == NE_TYPE_BTLA) { + if (ctx->scratch.offs + bestla_device_storage_size() > ctx->scratch.size) { + NE_PRINT("%s: %d not enough space in the context's memory pool (needed %zu, ctx->mem_size 
available %zu)\n", + __func__, __LINE__, cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + data = (char* const)ctx->scratch.data + ctx->scratch.offs; + ctx->scratch.offs += bestla_device_storage_size(); + } + } + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + struct ne_tensor* const result = (struct ne_tensor*)(mem_buffer + obj_new->offs); + + *result = (struct ne_tensor){ + .type = type, + .backend = NE_BACKEND_SYCL, + .n_dims = n_dims, + .ne = {1, 1, 1, 1}, + .nb = {0, 0, 0, 0}, + .op = NE_OP_NONE, + .is_param = false, + .op_params = {0}, + .grad = NULL, + .src0 = NULL, + .src1 = NULL, + .opt = {NULL}, + .n_tasks = 0, + .perf_runs = 0, + .perf_cycles = 0, + .perf_time_us = 0, + .data = NULL, + .size = size_needed, + .name = {0}, + .padding = {0}, + }; + if (type == NE_TYPE_BTLA) { + result->data = (void*)(result + 1); + memcpy(result->padding, &dptr, sizeof(dptr)); + } else { + result->data = dptr; + } + if (result->data == data) { + result->size = size; + } + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = NE_TYPE_SIZE[type]; + if (type != NE_TYPE_BTLA) { + result->nb[1] = result->nb[0] * (result->ne[0] / NE_BLCK_SIZE[type]); + } + if (size == NE_SIZE_CALC) { + size_needed = NE_TYPE_SIZE[type] * (ne[0] / NE_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { + size_needed *= ne[i]; + } + result->size = size_needed; + } + + for (int i = 2; i < NE_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1] * result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} +#endif + struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, int n_dims, const int64_t* ne, - void* data, size_t size) { + void* data, size_t size, enum ne_backend bk) { + if (bk == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + return ne_new_device_tensor_impl(ctx, type, n_dims, ne, data, size); +#else + NE_ASSERT(0); + return NULL; +#endif + } // always insert objects at the end of the context's memory pool struct ne_object* obj_cur = ctx->objects_end; @@ -911,7 +1085,7 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, } } - char* const mem_buffer = ctx->mem_buffer; + char* const mem_buffer = (char* const)ctx->mem_buffer; struct ne_object* const obj_new = (struct ne_object*)(mem_buffer + cur_end); if (ctx->scratch.data == NULL || data != NULL) { @@ -988,7 +1162,7 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, .perf_cycles = 0, .perf_time_us = 0, .data = (data == NULL && !ctx->no_alloc) ? (void*)(result + 1) : data, - .size = size_needed, + .size = data ? 
size : size_needed, .name = {0}, .padding = {0}, }; @@ -999,6 +1173,14 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, result->nb[0] = NE_TYPE_SIZE[type]; if (type != NE_TYPE_BTLA) { result->nb[1] = result->nb[0] * (result->ne[0] / NE_BLCK_SIZE[type]); + if (size == NE_SIZE_CALC) { + size_needed = 0; + size_needed += NE_TYPE_SIZE[type] * (ne[0] / NE_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { + size_needed *= ne[i]; + } + result->size = size_needed; + } } for (int i = 2; i < NE_MAX_DIMS; i++) { @@ -1010,35 +1192,38 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, return result; } -struct ne_tensor* ne_new_tensor(struct ne_context* ctx, enum ne_type type, int n_dims, const int64_t* ne, size_t size) { - return ne_new_tensor_impl(ctx, type, n_dims, ne, NULL, size); +struct ne_tensor* ne_new_tensor(struct ne_context* ctx, enum ne_type type, int n_dims, const int64_t* ne, size_t size, + enum ne_backend bk) { + return ne_new_tensor_impl(ctx, type, n_dims, ne, NULL, size, bk); } -struct ne_tensor* ne_new_tensor_1d(struct ne_context* ctx, enum ne_type type, int64_t ne0, size_t size) { - return ne_new_tensor(ctx, type, 1, &ne0, size); +struct ne_tensor* ne_new_tensor_1d(struct ne_context* ctx, enum ne_type type, int64_t ne0, size_t size, + enum ne_backend bk) { + return ne_new_tensor(ctx, type, 1, &ne0, size, bk); } -struct ne_tensor* ne_new_tensor_2d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, size_t size) { +struct ne_tensor* ne_new_tensor_2d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, size_t size, + enum ne_backend bk) { const int64_t ne[2] = {ne0, ne1}; - return ne_new_tensor(ctx, type, 2, ne, size); + return ne_new_tensor(ctx, type, 2, ne, size, bk); } struct ne_tensor* ne_new_tensor_3d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, int64_t ne2, - size_t size) { + size_t size, enum ne_backend bk) { const int64_t ne[3] = {ne0, ne1, ne2}; - return ne_new_tensor(ctx, type, 3, ne, size); + return ne_new_tensor(ctx, type, 3, ne, size, bk); } struct ne_tensor* ne_new_tensor_4d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, int64_t ne2, - int64_t ne3, size_t size) { + int64_t ne3, size_t size, enum ne_backend bk) { const int64_t ne[4] = {ne0, ne1, ne2, ne3}; - return ne_new_tensor(ctx, type, 4, ne, size); + return ne_new_tensor(ctx, type, 4, ne, size, bk); } struct ne_tensor* ne_new_i32(struct ne_context* ctx, int32_t value) { ne_scratch_save(ctx); - struct ne_tensor* result = ne_new_tensor_1d(ctx, NE_TYPE_I32, 1, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_1d(ctx, NE_TYPE_I32, 1, NE_SIZE_CALC, NE_BACKEND_CPU); ne_scratch_load(ctx); @@ -1050,7 +1235,7 @@ struct ne_tensor* ne_new_i32(struct ne_context* ctx, int32_t value) { struct ne_tensor* ne_new_f32(struct ne_context* ctx, float value) { ne_scratch_save(ctx); - struct ne_tensor* result = ne_new_tensor_1d(ctx, NE_TYPE_F32, 1, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_1d(ctx, NE_TYPE_F32, 1, NE_SIZE_CALC, NE_BACKEND_CPU); ne_scratch_load(ctx); @@ -1060,7 +1245,11 @@ struct ne_tensor* ne_new_f32(struct ne_context* ctx, float value) { } struct ne_tensor* ne_dup_tensor(struct ne_context* ctx, const struct ne_tensor* src) { - return ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL, src->size); + return ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL, src->size, src->backend); +} + +struct ne_tensor* ne_dup_tensor_bk(struct 
ne_context* ctx, const struct ne_tensor* src, enum ne_backend bk) { + return ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL, src->size, bk); } struct ne_tensor* ne_set_zero(struct ne_tensor* tensor) { @@ -1073,7 +1262,7 @@ struct ne_tensor* ne_set_i32(struct ne_tensor* tensor, int32_t value) { const int nc = tensor->ne[0]; const size_t n1 = tensor->nb[1]; - char* const data = tensor->data; + char* const data = (char* const)tensor->data; switch (tensor->type) { case NE_TYPE_I8: { @@ -1119,7 +1308,7 @@ struct ne_tensor* ne_set_f32(struct ne_tensor* tensor, float value) { const int nc = tensor->ne[0]; const size_t n1 = tensor->nb[1]; - char* const data = tensor->data; + char* const data = (char* const)tensor->data; switch (tensor->type) { case NE_TYPE_I8: { @@ -1295,7 +1484,19 @@ void ne_set_name(struct ne_tensor* tensor, const char* name) { } struct ne_tensor* ne_view_tensor(struct ne_context* ctx, const struct ne_tensor* src) { - struct ne_tensor* result = ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data, src->size); + struct ne_tensor* result = + ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data, src->size, src->backend); + + result->nb[0] = src->nb[0]; + result->nb[1] = src->nb[1]; + result->nb[2] = src->nb[2]; + result->nb[3] = src->nb[3]; + + return result; +} + +struct ne_tensor* ne_view_tensor_bk(struct ne_context* ctx, const struct ne_tensor* src, enum ne_backend bk) { + struct ne_tensor* result = ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data, src->size, bk); result->nb[0] = src->nb[0]; result->nb[1] = src->nb[1]; @@ -1318,7 +1519,7 @@ struct ne_tensor* ne_dump_tensor(struct ne_context* ctx, struct ne_tensor* a) { struct ne_tensor* result = ne_view_tensor(ctx, a); result->op = NE_OP_DUMP_TENSOR; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; @@ -1335,10 +1536,9 @@ struct ne_tensor* ne_dup_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_DUP; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1369,14 +1569,14 @@ struct ne_tensor* ne_add_impl(struct ne_context* ctx, struct ne_tensor* a, struc if (!inplace && (a->grad || b->grad)) { is_node = true; } + enum ne_op op = NE_OP_ADD; + enum ne_backend bk = bestla_backend_support(a, b, op); + struct ne_tensor* result = inplace ? ne_view_tensor_bk(ctx, a, bk) : ne_dup_tensor_bk(ctx, a, bk); - struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); - - result->op = NE_OP_ADD; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -1403,10 +1603,9 @@ struct ne_tensor* ne_add1_impl(struct ne_context* ctx, struct ne_tensor* a, stru struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_ADD1; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -1437,7 +1636,7 @@ struct ne_tensor* ne_acc_impl(struct ne_context* ctx, struct ne_tensor* a, struc ne_scratch_save(ctx); - struct ne_tensor* c = ne_new_tensor_1d(ctx, NE_TYPE_I32, 5, NE_SIZE_CALC); + struct ne_tensor* c = ne_new_tensor_1d(ctx, NE_TYPE_I32, 5, NE_SIZE_CALC, a->backend); ((int32_t*)c->data)[0] = nb1; ((int32_t*)c->data)[1] = nb2; @@ -1448,11 +1647,10 @@ struct ne_tensor* ne_acc_impl(struct ne_context* ctx, struct ne_tensor* a, struc ne_scratch_load(ctx); result->op = NE_OP_ACC; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; result->opt[0] = c; - return result; } @@ -1507,7 +1705,7 @@ struct ne_tensor* ne_tp_concat(struct ne_context* ctx, struct ne_tensor* a, enum ne_scratch_load(ctx); result->op = NE_OP_TP_CONCAT; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; result->opt[0] = c; @@ -1527,7 +1725,7 @@ struct ne_tensor* ne_all_reduce(struct ne_context* ctx, struct ne_tensor* a) { struct ne_tensor* result = ne_view_tensor(ctx, a); result->op = NE_OP_ALL_REDUCE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; @@ -1583,7 +1781,7 @@ struct ne_tensor* ne_split(struct ne_context* ctx, struct ne_tensor* a, enum par ne_scratch_load(ctx); result->op = NE_OP_SPLIT; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; result->opt[0] = c; @@ -1606,10 +1804,9 @@ struct ne_tensor* ne_sub_impl(struct ne_context* ctx, struct ne_tensor* a, struc struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_SUB; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -1639,14 +1836,14 @@ struct ne_tensor* ne_mul_impl(struct ne_context* ctx, struct ne_tensor* a, struc if (inplace) { NE_ASSERT(is_node == false); } + enum ne_op op = NE_OP_MUL; + enum ne_backend bk = bestla_backend_support(a, b, op); + struct ne_tensor* result = inplace ? ne_view_tensor_bk(ctx, a, bk) : ne_dup_tensor_bk(ctx, a, bk); - struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); - - result->op = NE_OP_MUL; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -1656,9 +1853,8 @@ struct ne_tensor* ne_tanh(struct ne_context* ctx, struct ne_tensor* a) { struct ne_tensor* result = ne_dup_tensor(ctx, a); result->op = NE_OP_TANH; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; - return result; } @@ -1688,10 +1884,9 @@ struct ne_tensor* ne_div_impl(struct ne_context* ctx, struct ne_tensor* a, struc struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_DIV; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -1715,10 +1910,9 @@ struct ne_tensor* ne_sqr_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_SQR; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1742,10 +1936,9 @@ struct ne_tensor* ne_sqrt_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_SQRT; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1769,10 +1962,9 @@ struct ne_tensor* ne_log_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_LOG; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1793,13 +1985,12 @@ struct ne_tensor* ne_sum(struct ne_context* ctx, struct ne_tensor* a) { is_node = true; } - struct ne_tensor* result = ne_new_tensor_1d(ctx, a->type, 1, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_1d(ctx, a->type, 1, NE_SIZE_CALC, a->backend); result->op = NE_OP_SUM; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1817,13 +2008,12 @@ struct ne_tensor* ne_sum_rows(struct ne_context* ctx, struct ne_tensor* a) { ne[i] = a->ne[i]; } - struct ne_tensor* result = ne_new_tensor(ctx, a->type, a->n_dims, ne, a->size); + struct ne_tensor* result = ne_new_tensor(ctx, a->type, a->n_dims, ne, a->size, a->backend); result->op = NE_OP_SUM_ROWS; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1838,13 +2028,12 @@ struct ne_tensor* ne_mean(struct ne_context* ctx, struct ne_tensor* a) { } int64_t ne[NE_MAX_DIMS] = {1, a->ne[1], a->ne[2], a->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, a->n_dims, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, a->n_dims, ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_MEAN; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1863,13 +2052,12 @@ struct ne_tensor* ne_repeat(struct ne_context* ctx, struct ne_tensor* a, struct return a; } - struct ne_tensor* result = ne_new_tensor(ctx, a->type, b->n_dims, b->ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, a->type, b->n_dims, b->ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_REPEAT; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -1885,10 +2073,9 @@ struct ne_tensor* ne_abs_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_ABS; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1912,10 +2099,9 @@ struct ne_tensor* ne_sgn_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_SGN; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1939,10 +2125,9 @@ struct ne_tensor* ne_neg_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_NEG; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1966,10 +2151,9 @@ struct ne_tensor* ne_step_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_STEP; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -1993,10 +2177,9 @@ struct ne_tensor* ne_relu_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_RELU; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2020,10 +2203,9 @@ struct ne_tensor* ne_gelu_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_GELU; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2043,14 +2225,14 @@ struct ne_tensor* ne_silu_impl(struct ne_context* ctx, struct ne_tensor* a, bool if (!inplace && (a->grad)) { is_node = true; } + enum ne_op op = NE_OP_SILU; + enum ne_backend bk = bestla_backend_support(a, NULL, op); + struct ne_tensor* result = inplace ? ne_view_tensor_bk(ctx, a, bk) : ne_dup_tensor_bk(ctx, a, bk); - struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); - - result->op = NE_OP_SILU; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2075,10 +2257,9 @@ struct ne_tensor* ne_silu_back(struct ne_context* ctx, struct ne_tensor* a, stru struct ne_tensor* result = ne_dup_tensor(ctx, a); result->op = NE_OP_SILU_BACK; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -2091,15 +2272,15 @@ struct ne_tensor* ne_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool NE_ASSERT(false); // TODO: implement backward is_node = true; } - - struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); + enum ne_op op = NE_OP_NORM; + enum ne_backend bk = bestla_backend_support(a, NULL, op); + struct ne_tensor* result = inplace ? ne_view_tensor_bk(ctx, a, bk) : ne_dup_tensor_bk(ctx, a, bk); ne_set_op_params(result, &eps, sizeof(eps)); - result->op = NE_OP_NORM; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; - return result; } @@ -2118,14 +2299,15 @@ struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a, is_node = true; } - struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); + enum ne_op op = NE_OP_RMS_NORM; + enum ne_backend bk = bestla_backend_support(a, NULL, op); + struct ne_tensor* result = inplace ? 
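+  // let bestla_backend_support() decide where the RMS-norm result lives: a SYCL device tensor when the op is supported there, a plain CPU tensor otherwise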
ne_view_tensor_bk(ctx, a, bk) : ne_dup_tensor_bk(ctx, a, bk); ne_set_op_params(result, &eps, sizeof(eps)); - result->op = NE_OP_RMS_NORM; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; - return result; } @@ -2148,10 +2330,9 @@ struct ne_tensor* ne_rms_norm_back(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* result = ne_dup_tensor(ctx, a); result->op = NE_OP_RMS_NORM_BACK; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -2168,13 +2349,11 @@ struct ne_tensor* ne_mul_mat(struct ne_context* ctx, struct ne_tensor* a, struct } const int64_t ne[4] = {a->ne[1], b->ne[1], b->ne[2], b->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, MAX(a->n_dims, b->n_dims), ne, NE_SIZE_CALC); - + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, MAX(a->n_dims, b->n_dims), ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_MUL_MAT; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -2192,10 +2371,10 @@ struct ne_tensor* ne_mul_mat_with_bias(struct ne_context* ctx, struct ne_tensor* } const int64_t ne[4] = {w->ne[1], a->ne[1], w->ne[2], a->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, a->type, MIN(w->n_dims, a->n_dims), ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, a->type, MIN(w->n_dims, a->n_dims), ne, NE_SIZE_CALC, w->backend); result->op = NE_OP_MUL_MAT_BIAS; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = w; result->src1 = a; result->opt[0] = b; @@ -2218,11 +2397,12 @@ struct ne_tensor* ne_mul_mat_id(struct ne_context* ctx, struct ne_tensor* const } const int64_t ne[4] = {as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne, NE_SIZE_CALC); + struct ne_tensor* result = + ne_new_tensor(ctx, NE_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne, NE_SIZE_CALC, as[0]->backend); int params[] = {id, n_as}; ne_set_op_params(result, ¶ms, sizeof(params)); result->op = NE_OP_MUL_MAT_ID; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = ids; result->src1 = b; @@ -2233,7 +2413,6 @@ struct ne_tensor* ne_mul_mat_id(struct ne_context* ctx, struct ne_tensor* const NE_ASSERT(!ne_is_transposed(a)); result->opt[i] = a; } - return result; } @@ -2258,14 +2437,14 @@ struct ne_tensor* ne_mul_id_ffn_silu(struct ne_context* ctx, struct ne_tensor* c is_node = true; } const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC, w2->backend); const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); - struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); + struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); int params[] = {id, n_as}; ne_set_op_params(result, ¶ms, sizeof(params)); result->op = NE_OP_MUL_ID_FFN_SILU; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = src; result->src1 = ids; for (int i = 0; i < n_as; i++) { @@ -2303,14 +2482,14 @@ struct ne_tensor* ne_mul_id_ffn_gelu(struct ne_context* ctx, struct ne_tensor* c is_node = true; } const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC, w2->backend); const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); - struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); + struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); int params[] = {id, n_as}; ne_set_op_params(result, ¶ms, sizeof(params)); result->op = NE_OP_MUL_ID_FFN_GELU; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = src; result->src1 = ids; for (int i = 0; i < n_as; i++) { @@ -2329,12 +2508,11 @@ struct ne_tensor* ne_mul_id_ffn_gelu(struct ne_context* ctx, struct ne_tensor* c struct ne_tensor* ne_argsort(struct ne_context* ctx, struct ne_tensor* a) { bool is_node = false; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_I32, NE_MAX_DIMS, a->ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_I32, NE_MAX_DIMS, a->ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_ARGSORT; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; - return result; } @@ -2347,7 +2525,6 @@ struct ne_tensor* ne_top_k(struct ne_context* ctx, struct ne_tensor* a, int k) { result = ne_view_4d(ctx, result, k, result->ne[1], result->ne[2], result->ne[3], result->nb[1], result->nb[2], result->nb[3], 0); - return result; } // ne_mul_qkv @@ -2368,15 +2545,14 @@ struct ne_tensor* ne_mul_qkv(struct ne_context* ctx, struct ne_tensor* qw, struc } const int64_t ne[4] = {qw->ne[1], src->ne[1], src->ne[2] * 3, src->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, MIN(src->n_dims, qw->n_dims), ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 4, ne, NE_SIZE_CALC, qw->backend); result->op = NE_OP_MUL_QKV; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = src; result->src1 = qw; result->opt[0] = kw; result->opt[1] = vw; - return result; } @@ -2394,13 +2570,13 @@ struct ne_tensor* ne_ffn_silu(struct ne_context* ctx, struct ne_tensor* w1, stru } const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC, w1->backend); const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); - struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); + struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); result->op = NE_OP_MUL_FFN_SILU; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = src; result->src1 = w1; result->opt[0] = w2; @@ -2421,12 +2597,12 @@ struct ne_tensor* ne_ffn_add_gelu(struct ne_context* ctx, struct ne_tensor* w1, } const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC, w2->backend); const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); result->op = NE_OP_MUL_FFN_ADD_GELU; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = src; result->src1 = w1; result->opt[0] = w2; @@ -2447,12 +2623,12 @@ struct ne_tensor* ne_ffn_gelu(struct ne_context* ctx, struct ne_tensor* w1, stru } const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC, w2->backend); const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); result->op = NE_OP_MUL_FFN_GELU; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = src; result->src1 = w1; result->opt[0] = w2; @@ -2472,13 +2648,13 @@ struct ne_tensor* ne_ffn_gelu_mul(struct ne_context* ctx, struct ne_tensor* w1, } const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC, w2->backend); const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; - struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); - struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); + struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC, w1->backend); result->op = NE_OP_MUL_FFN_GELU_MUL; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = src; result->src1 = w1; result->opt[0] = w2; @@ -2502,10 +2678,9 @@ struct ne_tensor* ne_scale_impl(struct ne_context* ctx, struct ne_tensor* a, str struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_SCALE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -2534,7 +2709,7 @@ struct ne_tensor* ne_set_impl(struct ne_context* ctx, struct ne_tensor* a, struc ne_scratch_save(ctx); - struct ne_tensor* c = ne_new_tensor_1d(ctx, NE_TYPE_I32, 5, NE_SIZE_CALC); + struct ne_tensor* c = ne_new_tensor_1d(ctx, NE_TYPE_I32, 5, NE_SIZE_CALC, a->backend); ((int32_t*)c->data)[0] = nb1; ((int32_t*)c->data)[1] = nb2; @@ -2545,11 +2720,10 @@ struct ne_tensor* ne_set_impl(struct ne_context* ctx, struct ne_tensor* a, struc ne_scratch_load(ctx); result->op = NE_OP_SET; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; result->opt[0] = c; - return result; } @@ -2596,10 +2770,9 @@ struct ne_tensor* ne_cpy_impl(struct ne_context* ctx, struct ne_tensor* a, struc struct ne_tensor* result = ne_view_tensor(ctx, b); result->op = NE_OP_CPY; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -2623,10 +2796,9 @@ struct ne_tensor* ne_cont_impl(struct ne_context* ctx, struct ne_tensor* a, bool struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_CONT; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2656,13 +2828,12 @@ struct ne_tensor* ne_reshape(struct ne_context* ctx, struct ne_tensor* a, struct // NE_ASSERT(false); } - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data, NE_SIZE_CALC, a->backend); result->op = NE_OP_RESHAPE; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2675,15 +2846,16 @@ struct ne_tensor* ne_reshape_1d(struct ne_context* ctx, struct ne_tensor* a, int if (a->grad) { is_node = true; } - + enum ne_op op = NE_OP_RESHAPE; + enum ne_backend bk = bestla_backend_support(a, NULL, op); const int64_t ne[1] = {ne0}; - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 1, ne, a->data, NE_SIZE_CALC); + struct ne_tensor* result = + ne_new_tensor_impl(ctx, a->type, 1, ne, a->backend == bk ? a->data : NULL, NE_SIZE_CALC, bk); - result->op = NE_OP_RESHAPE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2696,15 +2868,14 @@ struct ne_tensor* ne_reshape_2d(struct ne_context* ctx, struct ne_tensor* a, int if (a->grad) { is_node = true; } - + enum ne_op op = NE_OP_RESHAPE; const int64_t ne[2] = {ne0, ne1}; - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 2, ne, a->data, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 2, ne, a->data, NE_SIZE_CALC, a->backend); - result->op = NE_OP_RESHAPE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2719,13 +2890,15 @@ struct ne_tensor* ne_reshape_3d(struct ne_context* ctx, struct ne_tensor* a, int } const int64_t ne[3] = {ne0, ne1, ne2}; - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 3, ne, a->data, NE_SIZE_CALC); + enum ne_op op = NE_OP_RESHAPE; + enum ne_backend bk = bestla_backend_support(a, NULL, op); + struct ne_tensor* result = + ne_new_tensor_impl(ctx, a->type, 3, ne, a->backend == bk ? a->data : NULL, NE_SIZE_CALC, bk); - result->op = NE_OP_RESHAPE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->op = op; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2739,15 +2912,27 @@ struct ne_tensor* ne_reshape_4d(struct ne_context* ctx, struct ne_tensor* a, int if (a->grad) { is_node = true; } - + enum ne_op op = NE_OP_RESHAPE; const int64_t ne[4] = {ne0, ne1, ne2, ne3}; - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 4, ne, a->data, NE_SIZE_CALC); - - result->op = NE_OP_RESHAPE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 4, ne, a->data, a->size, a->backend); + result->op = op; + result->grad = NULL; result->src0 = a; result->src1 = NULL; + return result; +} +struct ne_tensor* ne_device_sync(struct ne_context* ctx, struct ne_tensor* a, enum ne_backend bk) { + enum ne_op op = NE_OP_RESHAPE; + struct ne_tensor* result = + ne_new_tensor_impl(ctx, a->type, a->n_dims, a->ne, a->backend == bk ? a->data : NULL, NE_SIZE_CALC, bk); + for (size_t i = 0; i < a->n_dims; i++) { + result->nb[i] = a->nb[i]; + } + result->op = op; + result->grad = NULL; + result->src0 = a; + result->src1 = NULL; return result; } @@ -2760,17 +2945,17 @@ struct ne_tensor* ne_view_1d(struct ne_context* ctx, struct ne_tensor* a, int64_ is_node = true; } - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 1, &ne0, (char*)a->data + offset, NE_SIZE_CALC); + struct ne_tensor* result = + ne_new_tensor_impl(ctx, a->type, 1, &ne0, (char*)a->data + offset, NE_SIZE_CALC, a->backend); result->op = NE_OP_VIEW; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - return result; } @@ -2786,21 +2971,20 @@ struct ne_tensor* ne_view_2d(struct ne_context* ctx, struct ne_tensor* a, int64_ const int64_t ne[NE_MAX_DIMS] = {ne0, ne1, 1, 1}; - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 2, ne, (char*)a->data + offset, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 2, ne, (char*)a->data + offset, NE_SIZE_CALC, a->backend); result->nb[1] = nb1; result->nb[2] = result->nb[1] * ne1; result->nb[3] = result->nb[2]; result->op = NE_OP_VIEW; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - return result; } @@ -2816,21 +3000,20 @@ struct ne_tensor* ne_view_3d(struct ne_context* ctx, struct ne_tensor* a, int64_ const int64_t ne[NE_MAX_DIMS] = {ne0, ne1, ne2, 1}; - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 3, ne, (char*)a->data + offset, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 3, ne, (char*)a->data + offset, NE_SIZE_CALC, a->backend); result->nb[1] = nb1; result->nb[2] = nb2; result->nb[3] = result->nb[2] * ne2; result->op = NE_OP_VIEW; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - return result; } @@ -2846,21 +3029,20 @@ struct ne_tensor* ne_view_4d(struct ne_context* ctx, struct ne_tensor* a, int64_ const int64_t ne[NE_MAX_DIMS] = {ne0, ne1, ne2, ne3}; - struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 4, ne, (char*)a->data + offset, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_impl(ctx, a->type, 4, ne, (char*)a->data + offset, NE_SIZE_CALC, a->backend); result->nb[1] = nb1; result->nb[2] = nb2; result->nb[3] = nb3; result->op = NE_OP_VIEW; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - return result; } @@ -2900,6 +3082,8 @@ struct ne_tensor* ne_permute(struct ne_context* ctx, struct ne_tensor* a, int ax nb[axis2] = a->nb[2]; nb[axis3] = a->nb[3]; + result->size = a->size; + result->ne[0] = ne[0]; result->ne[1] = ne[1]; result->ne[2] = ne[2]; @@ -2911,7 +3095,7 @@ struct ne_tensor* ne_permute(struct ne_context* ctx, struct ne_tensor* a, int ax result->nb[3] = nb[3]; result->op = NE_OP_PERMUTE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; @@ -2921,7 +3105,6 @@ struct ne_tensor* ne_permute(struct ne_context* ctx, struct ne_tensor* a, int ax result->padding[2] = axis2; result->padding[3] = axis3; } - return result; } @@ -2943,10 +3126,9 @@ struct ne_tensor* ne_transpose(struct ne_context* ctx, struct ne_tensor* a) { result->nb[1] = a->nb[0]; result->op = NE_OP_TRANSPOSE; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -2968,13 +3150,13 @@ struct ne_tensor* ne_get_rows(struct ne_context* ctx, struct ne_tensor* a, struc } // TODO: implement non F32 return // struct ne_tensor * result = ne_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ne_tensor* result = ne_new_tensor_4d(ctx, NE_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2], NE_SIZE_CALC); + struct ne_tensor* result = + ne_new_tensor_4d(ctx, NE_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2], NE_SIZE_CALC, a->backend); result->op = NE_OP_GET_ROWS; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -2993,14 +3175,13 @@ struct ne_tensor* ne_get_rows_back(struct ne_context* ctx, struct ne_tensor* a, // TODO: implement non F32 return // struct ne_tensor * result = ne_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ne_tensor* result = ne_new_tensor_2d(ctx, NE_TYPE_F32, c->ne[0], c->ne[1], NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_2d(ctx, NE_TYPE_F32, c->ne[0], c->ne[1], NE_SIZE_CALC, a->backend); result->op = NE_OP_GET_ROWS_BACK; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; result->opt[0] = c; - return result; } @@ -3015,13 +3196,12 @@ struct ne_tensor* ne_diag(struct ne_context* ctx, struct ne_tensor* a) { } const int64_t ne[4] = {a->ne[0], a->ne[0], a->ne[2], a->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_DIAG; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -3041,7 +3221,7 @@ struct ne_tensor* ne_diag_mask_inf_impl(struct ne_context* ctx, struct ne_tensor ne_scratch_save(ctx); const int bs = a->ne[3]; - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 2 + bs, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 2 + bs, NE_SIZE_CALC, a->backend); ((int32_t*)b->data)[0] = n_past; ((int32_t*)b->data)[1] = inplace ? 1 : 0; @@ -3056,10 +3236,9 @@ struct ne_tensor* ne_diag_mask_inf_impl(struct ne_context* ctx, struct ne_tensor ne_scratch_load(ctx); result->op = NE_OP_DIAG_MASK_INF; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3094,7 +3273,7 @@ struct ne_tensor* ne_diag_mask_zero_impl(struct ne_context* ctx, struct ne_tenso ne_scratch_save(ctx); - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 2, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 2, NE_SIZE_CALC, a->backend); ne_set_name(b, "n_past, inplace"); ((int32_t*)b->data)[0] = n_past; @@ -3103,10 +3282,9 @@ struct ne_tensor* ne_diag_mask_zero_impl(struct ne_context* ctx, struct ne_tenso ne_scratch_load(ctx); result->op = NE_OP_DIAG_MASK_ZERO; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3139,7 +3317,7 @@ struct ne_tensor* ne_padding_mask_inf_impl(struct ne_context* ctx, struct ne_ten #define PM_PADDING_IDX 2 const int bs = a->ne[3]; - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, PM_PARAMS_NUM + bs, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, PM_PARAMS_NUM + bs, NE_SIZE_CALC, a->backend); ((int32_t*)b->data)[PM_NPAST_IDX] = n_past; ((int32_t*)b->data)[PM_INPLACE_IDX] = inplace ? 1 : 0; @@ -3154,10 +3332,9 @@ struct ne_tensor* ne_padding_mask_inf_impl(struct ne_context* ctx, struct ne_ten ne_scratch_load(ctx); result->op = NE_OP_PADDING_MASK_INF; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3181,10 +3358,9 @@ struct ne_tensor* ne_soft_max_impl(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_SOFT_MAX; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = NULL; - return result; } @@ -3224,7 +3400,7 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int #define ROPE_PADDING_IDX 5 const int bs = a->ne[3]; - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, ROPE_PARAMS_NUM + bs, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, ROPE_PARAMS_NUM + bs, NE_SIZE_CALC, NE_BACKEND_CPU); ((int32_t*)b->data)[ROPE_NPAST_IDX] = n_past; ((int32_t*)b->data)[ROPE_NDIMS_IDX] = n_dims; @@ -3249,12 +3425,11 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int ne_set_op_params(result, ¶ms, sizeof(params)); result->op = NE_OP_ROPE; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; result->opt[0] = cossin; result->opt[1] = factor; - return result; } @@ -3307,7 +3482,7 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int ne_scratch_save(ctx); - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC, a->backend); ne_set_name(b, "n_past, n_dims, mode"); ((int32_t*)b->data)[0] = n_past; @@ -3317,10 +3492,9 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int ne_scratch_load(ctx); result->op = NE_OP_ROPE_BACK; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3360,7 +3534,7 @@ struct ne_tensor* ne_alibi(struct ne_context* ctx, struct ne_tensor* a, int n_pa ne_scratch_save(ctx); - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC, a->backend); ((int32_t*)b->data)[0] = n_past; ((int32_t*)b->data)[1] = n_head; @@ -3370,10 +3544,9 @@ struct ne_tensor* ne_alibi(struct ne_context* ctx, struct ne_tensor* a, int n_pa ne_scratch_load(ctx); result->op = NE_OP_ALIBI; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3392,7 +3565,7 @@ struct ne_tensor* ne_clamp(struct ne_context* ctx, struct ne_tensor* a, float mi ne_scratch_save(ctx); - struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC); + struct ne_tensor* b = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC, a->backend); ((float*)b->data)[0] = min; ((float*)b->data)[1] = max; @@ -3400,10 +3573,9 @@ struct ne_tensor* ne_clamp(struct ne_context* ctx, struct ne_tensor* a, float mi ne_scratch_load(ctx); result->op = NE_OP_CLAMP; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3426,13 +3598,12 @@ struct ne_tensor* ne_conv_1d_1s(struct ne_context* ctx, struct ne_tensor* a, str 1, 1, }; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 2, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 2, ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_CONV_1D_1S; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3455,13 +3626,12 @@ struct ne_tensor* ne_conv_1d_2s(struct ne_context* ctx, struct ne_tensor* a, str 1, 1, }; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 2, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 2, ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_CONV_1D_2S; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3487,16 +3657,15 @@ NE_API struct ne_tensor* ne_conv_1d(struct ne_context* ctx, struct ne_tensor* a, 1, 1, }; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 2, ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 2, ne, NE_SIZE_CALC, a->backend); int32_t params[] = {s0, p0, d0}; ne_set_op_params(result, params, sizeof(params)); result->op = NE_OP_CONV_1D; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - return result; } @@ -3510,6 +3679,33 @@ struct ne_tensor* ne_conv_1d_ph(struct ne_context* ctx, struct ne_tensor* a, str struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k, struct ne_tensor* v, float scale, ne_attn_flags_t flags) { + if (q->backend == NE_BACKEND_SYCL) { + int headsize = q->ne[0]; + int headnum = q->ne[1]; + int seq_cur = q->ne[2]; + int batch = q->ne[3]; + int heads_kv = k->ne[2]; + int seq_all = k->ne[1]; + // int seq_past = seq_all - seq_cur; + NE_ASSERT(("headnum must be a multiple of heads_kv", headnum % heads_kv == 0)); + NE_ASSERT(headsize == k->ne[0]); + NE_ASSERT(headsize == v->ne[1]); + NE_ASSERT(seq_all == v->ne[0]); + NE_ASSERT(("n_heads must be the same for K/V!", k->ne[2] == v->ne[2])); + NE_ASSERT(batch == k->ne[3]); + NE_ASSERT(batch == v->ne[3]); + bool is_node = true; + struct ne_tensor* result = + ne_new_tensor_4d(ctx, NE_TYPE_F32, headsize, headnum, seq_cur, batch, NE_SIZE_CALC, q->backend); + result->op = NE_OP_FLASH_ATTN; + result->grad = NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + *(float*)result->padding = scale; + *(uint32_t*)&result->padding[4] = flags; + return result; + } NE_ASSERT(ne_can_mul_mat(k, q)); int batch = q->ne[3]; int headnum = q->ne[2]; @@ -3526,10 +3722,11 @@ struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, str NE_ASSERT(batch == k->ne[3]); NE_ASSERT(batch == v->ne[3]); bool is_node = true; - struct ne_tensor* result = ne_new_tensor_4d(ctx, NE_TYPE_F32, headsize, headnum, seq_cur, batch, NE_SIZE_CALC); - attn_shape_t atte_shape = {batch, headnum, headsize, seq_cur, seq_all}; + struct ne_tensor* result = + ne_new_tensor_4d(ctx, NE_TYPE_F32, headsize, headnum, seq_cur, batch, NE_SIZE_CALC, q->backend); + attn_shape_t atte_shape = {batch, headnum, heads_kv, headsize, seq_cur, seq_all}; size_t tmpsize = bestla_fusion_attn_workspace_size(&atte_shape); - struct ne_tensor* tmp_t = ne_new_tensor_1d(ctx, NE_TYPE_I8, tmpsize, NE_SIZE_CALC); + struct ne_tensor* tmp_t = ne_new_tensor_1d(ctx, NE_TYPE_I8, tmpsize, NE_SIZE_CALC, q->backend); result->op = NE_OP_FLASH_ATTN; result->grad = NULL; result->src0 = q; @@ -3550,7 +3747,7 @@ struct ne_tensor* ne_flash_attn_kv_update(struct ne_context* ctx, struct ne_tens ne_scratch_save(ctx); - struct ne_tensor* params = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC); + struct ne_tensor* params = ne_new_tensor_1d(ctx, NE_TYPE_I32, 3, NE_SIZE_CALC, cache->backend); ((int32_t*)params->data)[0] = n_past; ((int32_t*)params->data)[1] = (int)is_v; @@ -3589,16 +3786,15 @@ struct ne_tensor* ne_flash_ff(struct ne_context* ctx, struct ne_tensor* a, struc } // struct ne_tensor * result = ne_dup_tensor(ctx, a); - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 4, a->ne, NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, 4, a->ne, NE_SIZE_CALC, a->backend); result->op = NE_OP_FLASH_FF; - result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b0; result->opt[0] = b1; result->opt[1] = c0; result->opt[2] = c1; - return result; } @@ -3612,15 +3808,15 @@ struct ne_tensor* ne_map_unary_impl_f32(struct ne_context* ctx, struct ne_tensor is_node = true; } - struct ne_tensor* addr_tensor = ne_new_tensor_1d(ctx, NE_TYPE_I32, sizeof(void*) / sizeof(int32_t), NE_SIZE_CALC); + struct ne_tensor* addr_tensor = + ne_new_tensor_1d(ctx, NE_TYPE_I32, sizeof(void*) / sizeof(int32_t), NE_SIZE_CALC, a->backend); *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_MAP_UNARY; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->opt[0] = addr_tensor; - return result; } @@ -3644,16 +3840,16 @@ struct ne_tensor* ne_map_binary_impl_f32(struct ne_context* ctx, struct ne_tenso is_node = true; } - struct ne_tensor* addr_tensor = ne_new_tensor_1d(ctx, NE_TYPE_I32, sizeof(void*) / sizeof(int32_t), NE_SIZE_CALC); + struct ne_tensor* addr_tensor = + ne_new_tensor_1d(ctx, NE_TYPE_I32, sizeof(void*) / sizeof(int32_t), NE_SIZE_CALC, a->backend); *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); result->op = NE_OP_MAP_BINARY; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; result->opt[0] = addr_tensor; - return result; } @@ -4052,6 +4248,14 @@ static void ne_compute_forward_dup_f16(const struct ne_compute_params* params, c static void ne_compute_forward_dup_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, struct ne_tensor* dst) { NE_ASSERT(ne_nelements(dst) == ne_nelements(src0)); + if (dst->backend == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + bestla_device_dup_f32(params, src0, dst); +#else + NE_ASSERT(false); +#endif + return; + } if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; @@ -4332,7 +4536,7 @@ static void ne_compute_forward_dup_f32(const struct ne_compute_params* params, c static void ne_compute_forward_debug(const struct ne_compute_params* params, const struct ne_tensor* src0, struct ne_tensor* dst) { if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) return; - const ne_debug_callback_t cb = *((void**)(dst->padding)); + const ne_debug_callback_t cb = *((ne_debug_callback_t*)(dst->padding)); cb(src0); } @@ -4360,8 +4564,19 @@ static void ne_compute_forward_dup(const struct ne_compute_params* params, const static void ne_compute_forward_add_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, const struct ne_tensor* src1, struct ne_tensor* dst) { NE_ASSERT(ne_can_repeat_rows(src1, src0) && ne_are_same_shape(src0, dst)); + if (dst->backend == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + bestla_device_add_f32(params, src0, src1, dst); +#else + NE_ASSERT(0); +#endif + return; + } + if (params->type == NE_TASK_INIT) { + return; + } - if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + if (params->type == NE_TASK_FINALIZE) { return; } @@ -4394,6 +4609,9 @@ static void ne_compute_forward_add_f32(const struct ne_compute_params* params, c const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; + float* src0ptr = (float*)src0->data; + float* src1ptr = (float*)src1->data; + float* dstptr = (float*)dst->data; NE_ASSERT(nb0 == 
sizeof(float)); NE_ASSERT(nb00 == sizeof(float)); NE_ASSERT(ne00 == ne10); @@ -4401,7 +4619,7 @@ static void ne_compute_forward_add_f32(const struct ne_compute_params* params, c if ((ne_nrows(src1) == 1 || ne_nrows(src1) == ne_nrows(src0)) && ne10 == ne00) { if (nb10 == sizeof(float)) { int step1 = ne11 == 1 ? 0 : ne10; - bestla_add(nr, ne00, (const float*)src0->data, (const float*)src1->data, step1, (float*)dst->data); + bestla_add(nr, ne00, (const float*)src0ptr, (const float*)src1ptr, step1, dstptr); return; } } @@ -4417,9 +4635,9 @@ static void ne_compute_forward_add_f32(const struct ne_compute_params* params, c const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; - float* dst_ptr = (float*)((char*)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); - float* src0_ptr = (float*)((char*)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); - float* src1_ptr = (float*)((char*)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + float* dst_ptr = (float*)(dstptr + i03 * nb3 + i02 * nb2 + i01 * nb1); + float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01); + float* src1_ptr = (float*)((char*)src1ptr + i13 * nb13 + i12 * nb12 + i11 * nb11); ne_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); } @@ -4436,11 +4654,11 @@ static void ne_compute_forward_add_f32(const struct ne_compute_params* params, c const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; - float* dst_ptr = (float*)((char*)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); - float* src0_ptr = (float*)((char*)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + float* dst_ptr = (float*)(dstptr + i03 * nb3 + i02 * nb2 + i01 * nb1); + float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01); for (int64_t i0 = 0; i0 < ne00; i0++) { - float* src1_ptr = (float*)((char*)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11 + i0 * nb10); + float* src1_ptr = (float*)((char*)src1ptr + i13 * nb13 + i12 * nb12 + i11 * nb11 + i0 * nb10); dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; } @@ -5411,7 +5629,14 @@ static void ne_compute_forward_sub(const struct ne_compute_params* params, const static void ne_compute_forward_mul_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, const struct ne_tensor* src1, struct ne_tensor* dst) { NE_ASSERT(ne_can_repeat_rows(src1, src0) && ne_are_same_shape(src0, dst)); - + if (dst->backend == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + bestla_device_mul_f32(params, src0, src1, dst); +#else + NE_ASSERT(0); +#endif + return; + } if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; } @@ -5444,6 +5669,7 @@ static void ne_compute_forward_mul_f32(const struct ne_compute_params* params, c const size_t nb1 = dst->nb[1]; const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; + if (ne_is_contiguous(src0) && ne_is_contiguous(src1)) { if ((ne_nrows(src1) == 1 || ne_nrows(src1) == ne_nrows(src0)) && ne10 == ne00) { if (nb10 == sizeof(float)) { @@ -6175,7 +6401,12 @@ static void ne_compute_forward_silu_f32(const struct ne_compute_params* params, NE_ASSERT(ne_is_contiguous_except_dim_1(src0)); NE_ASSERT(ne_is_contiguous_except_dim_1(dst)); NE_ASSERT(ne_are_same_shape(src0, dst)); - + if (dst->backend == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + bestla_device_elewise_f32(params, src0, dst); +#endif + return; + } if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; } @@ -6357,7 +6588,14 @@ static void ne_compute_forward_norm(const struct ne_compute_params* params, cons static void 
ne_compute_forward_rms_norm_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, struct ne_tensor* dst) { NE_ASSERT(ne_are_same_shape(src0, dst)); - + if (src0->backend == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + bestla_device_rms_norm_f32(params, src0, dst); +#else + NE_ASSERT(0); +#endif + return; + } if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; } @@ -6764,7 +7002,7 @@ static void ne_compute_forward_mul_mat_f16_f32(const struct ne_compute_params* p // compute by src0 rows if (params->type == NE_TASK_INIT) { - ne_fp16_t* const wdata = params->wdata; + ne_fp16_t* const wdata = (ne_fp16_t* const)params->wdata; size_t id = 0; for (int64_t i13 = 0; i13 < ne13; ++i13) { @@ -6906,7 +7144,7 @@ static void ne_compute_forward_mul_mat_q_f32(const struct ne_compute_params* par // compute by src0 rows if (params->type == NE_TASK_INIT) { - char* wdata = params->wdata; + char* wdata = (char*)params->wdata; const size_t row_size = ne10 * NE_TYPE_SIZE[vec_dot_type] / NE_BLCK_SIZE[vec_dot_type]; for (int64_t i13 = 0; i13 < ne13; ++i13) { @@ -7044,16 +7282,37 @@ static void ne_compute_forward_mul_mat_q_f32_bestla(const struct ne_compute_para // nb01 >= nb00 - src0 is not transposed // compute by src0 rows + int8_t* devwptr = (int8_t*)params->dev_wdata; + float* actptr = src1->backend == NE_BACKEND_CPU ? (float*)devwptr : (float*)src1->data; + if (src1->backend == NE_BACKEND_CPU) { + devwptr += src1->size; + } if (params->type == NE_TASK_INIT) { +#ifdef NS_SYCL + if (params->ith == 0) { + if (dst->backend == NE_BACKEND_SYCL && src1->backend == NE_BACKEND_CPU) { + bestla_device_memcpy_sync(actptr, src1->data, src1->size, params->dev_queue); + } + } +#endif return; } if (params->type == NE_TASK_FINALIZE) { return; } - bestla_f32f32_forward((float*)src1->data, src0->data, (float*)dst->data, ne1, ne0, ne10, nb11 / ne_element_size(src1), - nb1 / ne_element_size(dst), params->wdata); + if (dst->backend == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + bestla_device_f32f32_forward(actptr, src0->data, (float*)dst->data, ne1, ne0, ne10, nb11 / ne_element_size(src1), + nb1 / ne_element_size(dst), devwptr, params->dev_queue); +#else + NE_ASSERT(0); +#endif + } else { + bestla_f32f32_forward((float*)src1->data, src0->data, (float*)dst->data, ne1, ne0, ne10, + nb11 / ne_element_size(src1), nb1 / ne_element_size(dst), params->wdata); + } } static void ne_compute_forward_mul_mat(const struct ne_compute_params* params, const struct ne_tensor* src0, @@ -7160,7 +7419,7 @@ static void ne_compute_forward_mul_mat_id_q_f32(const struct ne_compute_params* if (ith != 0) { return; } - char* wdata = params->wdata; + char* wdata = (char*)params->wdata; const size_t row_size = ne10 * NE_TYPE_SIZE[vec_dot_type] / NE_BLCK_SIZE[vec_dot_type]; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -7437,7 +7696,7 @@ static void ne_compute_forward_mul_mat_id_f16_f32(const struct ne_compute_params // compute by src0 rows if (params->type == NE_TASK_INIT) { - ne_fp16_t* const wdata = params->wdata; + ne_fp16_t* const wdata = (ne_fp16_t* const)params->wdata; size_t id = 0; for (int64_t i13 = 0; i13 < ne13; ++i13) { @@ -8079,6 +8338,18 @@ static void ne_compute_forward_cont(const struct ne_compute_params* params, cons static void ne_compute_forward_reshape(const struct ne_compute_params* params, const struct ne_tensor* src0, struct ne_tensor* dst) { + if (src0->backend != dst->backend) { +#ifdef NS_SYCL + if (params->type == NE_TASK_INIT) { + if (params->ith == 0) { + 
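+        // single-threaded transfer: wait for pending work on the device queue, then copy src0 into the destination backend's buffer; the reshape itself only rewrites metadata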
bestla_device_sync(params->dev_queue); + bestla_device_memcpy_sync(dst->data, src0->data, src0->size, params->dev_queue); + } + } +#else + NE_ASSERT(0); +#endif + } // NOP UNUSED(params); UNUSED(src0); @@ -8201,7 +8472,7 @@ static void ne_compute_forward_get_rows_f16(const struct ne_compute_params* para for (int64_t i10 = 0; i10 < ne10; ++i10) { const int64_t i01 = *(int32_t*)((char*)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); - ne_fp16_to_fp32_row((const void*)((char*)src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03), + ne_fp16_to_fp32_row((const ne_fp16_t*)((char*)src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03), (float*)((char*)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc); } } @@ -8971,6 +9242,15 @@ void ggml_rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base, float static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, const struct ne_tensor* src1, struct ne_tensor* dst) { + if (dst->backend == NE_BACKEND_SYCL) { + assert(src1->backend == NE_BACKEND_CPU); +#ifdef NS_SYCL + bestla_device_rope_f32(params, src0, src1, dst); +#else + NE_ASSERT(0); +#endif + return; + } if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; } @@ -9214,9 +9494,9 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params, if (is_shift) { float theta = n_past * freq_scale; - ne_fp16_t* cossin = (dst->opt[0] != NULL) ? dst->opt[0]->data : NULL; + ne_fp16_t* cossin = (dst->opt[0] != NULL) ? (ne_fp16_t*)dst->opt[0]->data : NULL; if (cossin == NULL) { - cossin = malloc(ne0 * sizeof(ne_fp16_t)); + cossin = (ne_fp16_t*)malloc(ne0 * sizeof(ne_fp16_t)); for (int i0 = 0; i0 < ne0; i0 += 2) { cossin[i0 + 0] = NE_FP32_TO_FP16(cosf(theta)); cossin[i0 + 1] = NE_FP32_TO_FP16(sinf(theta)); @@ -9357,18 +9637,18 @@ static void ne_compute_forward_rope_bestla(const struct ne_compute_params* param const float freq_base = ((float*)(dst->op_params))[0]; const float freq_scale = 1 / ((float*)(dst->op_params))[1]; if (is_shift) { - ne_fp16_t* cossin = (dst->opt[0] != NULL) ? dst->opt[0]->data : NULL; + ne_fp16_t* cossin = (dst->opt[0] != NULL) ? 
(ne_fp16_t*)dst->opt[0]->data : NULL; if (cossin == NULL) { float theta = n_past * freq_scale; const float theta_scale = powf(freq_base, -2.0f / n_dims); - cossin = malloc(head_size * sizeof(ne_fp16_t)); + cossin = (ne_fp16_t*)malloc(head_size * sizeof(ne_fp16_t)); for (int i0 = 0; i0 < head_size; i0 += 2) { cossin[i0 + 0] = NE_FP32_TO_FP16(cosf(theta)); cossin[i0 + 1] = NE_FP32_TO_FP16(sinf(theta)); theta *= theta_scale; } } - bestla_reordered_attn_fp32_shift_rope_k(dst->data, cossin, batch_size, head_num, head_size, seq_len, n_keep); + bestla_reordered_attn_fp32_shift_rope_k((char*)dst->data, cossin, batch_size, head_num, head_size, seq_len, n_keep); if (dst->opt[0] == NULL) free(cossin); return; } @@ -9628,6 +9908,14 @@ static void ne_compute_forward_rope_back(const struct ne_compute_params* params, static void ne_compute_forward_flash_attn_f32(const struct ne_compute_params* params, const struct ne_tensor* q, const struct ne_tensor* k, const struct ne_tensor* v, const bool masked, struct ne_tensor* dst) { + if (dst->backend == NE_BACKEND_SYCL) { +#ifdef NS_SYCL + bestla_device_mha_f32(params, q, k, v, dst); +#else + NE_ASSERT(0); +#endif + return; + } int64_t t0 = ne_perf_time_us(); UNUSED(t0); @@ -9894,7 +10182,7 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa .K_sc = 1.f, .V_sc = 1.f, .dst_sc = 1.f, - .tmp = tmp->data, + .tmp = (char*)tmp->data, .QK_scale = scale, .attn_flags = flags, .batch_size = batch, @@ -9956,7 +10244,7 @@ static void ne_compute_forward_flash_attn_reordered(const struct ne_compute_para .K_sc = 1.f, .V_sc = 1.f, .dst_sc = 1.f, - .tmp = tmp->data, + .tmp = (char*)tmp->data, .QK_scale = scale, .attn_flags = flags, .batch_size = batch, @@ -10227,6 +10515,8 @@ static void ne_compute_forward_flash_attn(const struct ne_compute_params* params ne_compute_forward_flash_attn_f32_f16_f16(params, q, k, v, tmp, dst); } else if (k->type == NE_TYPE_BTLA && v->type == NE_TYPE_BTLA) { ne_compute_forward_flash_attn_reordered(params, q, k, v, tmp, dst); + } else if (k->type == NE_TYPE_F32) { + ne_compute_forward_flash_attn_f32(params, q, k, v, true, dst); } else { NE_ASSERT(false); } @@ -10241,14 +10531,14 @@ static void ne_compute_forward_flash_attn_kv_update(const struct ne_compute_para struct ne_tensor* dst) { if (params->type != NE_TASK_COMPUTE) return; NE_ASSERT(ne_nelements(dst->opt[0]) == 3); // 3 params - const int* p_data = dst->opt[0]->data; + const int* p_data = (const int*)dst->opt[0]->data; const int n_past = p_data[0]; const bool is_v = (bool)p_data[1]; const bool no_zeroing = (bool)p_data[2]; NE_ASSERT(cur->type == NE_TYPE_F32); bestla_fusion_attn_fp32_update_kv_args_t args = { - .src = cur->data, - .cache = cache->data, + .src = (float*)cur->data, + .cache = (char*)cache->data, .batch_size = cur->ne[3], .heads_kv = cur->ne[1], .head_size = cur->ne[0], @@ -10889,7 +11179,8 @@ static void ne_compute_backward(struct ne_context* ctx, struct ne_tensor* tensor int64_t ne[4] = {nc0, ncr, nr0, nrr}; struct ne_tensor* F00 = tensor->grad; - struct ne_tensor* F01 = ne_reshape(ctx, F00, ne_new_tensor(ctx, tensor->grad->type, 4, ne, NE_SIZE_CALC)); + struct ne_tensor* F01 = + ne_reshape(ctx, F00, ne_new_tensor(ctx, tensor->grad->type, 4, ne, NE_SIZE_CALC, tensor->grad->backend)); struct ne_tensor* F02 = ne_permute(ctx, F01, 0, 2, 1, 3); struct ne_tensor* F03 = ne_cont(ctx, F02); struct ne_tensor* F04 = ne_reshape_2d(ctx, F03, nc0 * nr0, ncr * nrr); @@ -11329,17 +11620,19 @@ void ne_build_forward_expand(struct ne_cgraph* cgraph, struct 
ne_tensor* tensor) struct ne_cgraph ne_build_forward(struct ne_tensor* tensor) { struct ne_cgraph result = { - /*.n_nodes =*/0, - /*.n_leafs =*/0, - /*.n_threads =*/NE_DEFAULT_N_THREADS, - /*.work_size =*/0, - /*.work =*/NULL, - /*.nodes =*/{NULL}, - /*.grads =*/{NULL}, - /*.leafs =*/{NULL}, - /*.perf_runs =*/0, - /*.perf_cycles =*/0, - /*.perf_time_us =*/0, + /*.n_nodes =*/0, + /*.n_leafs =*/0, + /*.n_threads =*/NE_DEFAULT_N_THREADS, + /*.work_size =*/0, + /*.work =*/NULL, + /*.dev_work_size =*/0, + /*.dev_work =*/NULL, + /*.nodes =*/{NULL}, + /*.grads =*/{NULL}, + /*.leafs =*/{NULL}, + /*.perf_runs =*/0, + /*.perf_cycles =*/0, + /*.perf_time_us =*/0, }; ne_build_forward_impl(&result, tensor, false); @@ -11392,6 +11685,233 @@ struct ne_cgraph ne_build_backward(struct ne_context* ctx, struct ne_cgraph* gf, // I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops // +bool ne_support(struct ne_tensor* node, int n_threads, size_t* workspace, size_t* dev_workspace) { + size_t ws_h = 0; + size_t ws_d = 0; + bool support = false; + switch (node->op) { + case NE_OP_CPY: { + node->n_tasks = n_threads; // node->ne[0] == 1 ? n_threads : 1; + if (ne_is_quantized(node->type)) { + ws_h = NE_TYPE_SIZE[NE_TYPE_F32] * node->ne[0] * n_threads; + } + support = true; + } break; + case NE_OP_DUP: { + node->n_tasks = n_threads; + if (ne_is_quantized(node->type)) { + ws_h = NE_TYPE_SIZE[NE_TYPE_F32] * node->ne[0] * n_threads; + } + support = true; + } break; + case NE_OP_ADD: + case NE_OP_ADD1: { + if (node->src0->ne[1] > 4) { + node->n_tasks = n_threads; + } else { + node->n_tasks = 1; + } + + if (ne_is_quantized(node->src0->type)) { + ws_h = NE_TYPE_SIZE[NE_TYPE_F32] * node->src0->ne[0] * n_threads; + } + support = true; + } break; + case NE_OP_ACC: { + node->n_tasks = n_threads; + + if (ne_is_quantized(node->src0->type)) { + ws_h = NE_TYPE_SIZE[NE_TYPE_F32] * node->src1->ne[0] * n_threads; + } + support = true; + } break; + case NE_OP_SUB: + case NE_OP_SUM: + case NE_OP_DIV: + case NE_OP_SUM_ROWS: + case NE_OP_TANH: { + node->n_tasks = 1; + support = true; + } break; + case NE_OP_MUL: + case NE_OP_SQR: + case NE_OP_SQRT: + case NE_OP_LOG: + case NE_OP_MEAN: + case NE_OP_ABS: + case NE_OP_ARGSORT: + case NE_OP_SGN: + case NE_OP_NEG: + case NE_OP_STEP: + case NE_OP_RELU: { + if (node->src0->ne[1] > 4) { + node->n_tasks = n_threads; + } else { + node->n_tasks = 1; + } + support = true; + } break; + case NE_OP_NORM: + case NE_OP_RMS_NORM: { + if (node->src0->ne[1] > 4) { + node->n_tasks = n_threads; + } else { + node->n_tasks = 1; + } + support = true; + } break; + case NE_OP_GELU: + case NE_OP_SILU: + case NE_OP_SILU_BACK: + case NE_OP_RMS_NORM_BACK: { + node->n_tasks = n_threads; + support = true; + } break; + case NE_OP_MUL_MAT_ID: + case NE_OP_CONV_1D: + case NE_OP_MUL_MAT: { + node->n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + // const int nr0 = ne_nrows(node->src0); + // const int nr1 = ne_nrows(node->src1); + + // node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); + // printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); + + struct ne_tensor* wei = node->src0; + if (node->op == NE_OP_MUL_MAT_ID) { + wei = node->opt[0]; + } + if (wei->type == NE_TYPE_F16 && node->src1->type == NE_TYPE_F32) { + ws_h = NE_TYPE_SIZE[NE_TYPE_F16] * ne_nelements(node->src1); + } else if (ne_is_quantized(wei->type) && node->src1->type == NE_TYPE_F32) { + const enum ne_type type_q = 
quantize_fns[wei->type].vec_dot_type; + ws_h = NE_TYPE_SIZE[type_q] * ne_nelements(node->src1) / NE_BLCK_SIZE[type_q]; + } else if (wei->type == NE_TYPE_F32 && node->src1->type == NE_TYPE_F32) { + ws_h = 0; + } else { + NE_ASSERT(false); + } + support = true; + } break; + case NE_OP_SCALE: { + node->n_tasks = 1; + support = true; + } break; + case NE_OP_SET: + case NE_OP_CONT: + case NE_OP_RESHAPE: + case NE_OP_VIEW: + case NE_OP_PERMUTE: + case NE_OP_TRANSPOSE: + case NE_OP_GET_ROWS: + case NE_OP_GET_ROWS_BACK: + case NE_OP_REPEAT: + case NE_OP_DIAG: + case NE_OP_DIAG_MASK_ZERO: { + node->n_tasks = 1; + support = true; + } break; + case NE_OP_DIAG_MASK_INF: + case NE_OP_PADDING_MASK_INF: + case NE_OP_ROPE: + node->n_tasks = n_threads; + support = true; + break; + case NE_OP_SOFT_MAX: { + size_t rows = ne_nrows(node->src0); + node->n_tasks = rows > 1 ? n_threads : 1; + support = true; + } break; + case NE_OP_ROPE_BACK: { + node->n_tasks = n_threads; + support = true; + } break; + case NE_OP_ALIBI: { + node->n_tasks = 1; // TODO + support = true; + } break; + case NE_OP_CLAMP: { + node->n_tasks = 1; // TODO + support = true; + } break; + case NE_OP_CONV_1D_1S: + case NE_OP_CONV_1D_2S: { + node->n_tasks = n_threads; + + NE_ASSERT(node->src0->ne[3] == 1); + NE_ASSERT(node->src1->ne[2] == 1); + NE_ASSERT(node->src1->ne[3] == 1); + + const int nk = node->src0->ne[0]; + + if (node->src0->type == NE_TYPE_F16 && node->src1->type == NE_TYPE_F32) { + ws_h = sizeof(ne_fp16_t) * (nk * ne_up32(node->src0->ne[1]) * node->src0->ne[2] + + (2 * (nk / 2) + node->src1->ne[0]) * node->src1->ne[1]); + } else if (node->src0->type == NE_TYPE_F32 && node->src1->type == NE_TYPE_F32) { + ws_h = sizeof(float) * (nk * ne_up32(node->src0->ne[1]) * node->src0->ne[2] + + (2 * (nk / 2) + node->src1->ne[0]) * node->src1->ne[1]); + } else { + NE_ASSERT(false); + } + support = true; + } break; + case NE_OP_FLASH_ATTN_KV_UPDATE: + case NE_OP_FLASH_ATTN: { + node->n_tasks = 1; + support = true; + } break; + case NE_OP_FLASH_FF: { + node->n_tasks = n_threads; + if (node->src1->type == NE_TYPE_F32) { + ws_h = sizeof(float) * node->src1->ne[1] * node->n_tasks; // TODO: this can become (n_tasks-1) + ws_h += sizeof(float) * node->src1->ne[1] * node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == NE_TYPE_F16) { + ws_h = sizeof(float) * node->src1->ne[1] * node->n_tasks; // TODO: this can become (n_tasks-1) + ws_h += sizeof(float) * node->src1->ne[1] * node->n_tasks; // this is overestimated by x2 + } + support = true; + } break; + case NE_OP_MAP_UNARY: + case NE_OP_MAP_BINARY: { + node->n_tasks = 1; + support = true; + } break; + case NE_OP_NONE: { + node->n_tasks = 1; + support = true; + } break; + // split and all_reduce do not use thread pool + case NE_OP_SPLIT: + case NE_OP_ALL_REDUCE: + case NE_OP_TP_CONCAT: + case NE_OP_DEBUG: + case NE_OP_DUMP_TENSOR: { + node->n_tasks = 1; + support = true; + } break; + case NE_OP_COUNT: { + NE_ASSERT(false); + } break; + } + assert(node->backend == NE_BACKEND_CPU); + if (node->src0->backend == NE_BACKEND_SYCL) { + ws_h += node->src0->size; + } + if (node->src1 && node->src1->backend == NE_BACKEND_SYCL) { + ws_h += node->src1->size; + } + if (node->backend == NE_BACKEND_SYCL) { + ws_h += node->size; + } + *workspace = ws_h; + *dev_workspace = ws_d; + return support; +} + void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { int n_threads = cgraph->n_threads; @@ -11399,287 +11919,29 @@ void ne_graph_compute(struct ne_context* ctx, struct 
ne_cgraph* cgraph) { // initialize tasks + work buffer { size_t work_size = 0; + size_t dev_work_size = 0; // thread scheduling for the different operations for (int i = 0; i < cgraph->n_nodes; i++) { struct ne_tensor* node = cgraph->nodes[i]; - - switch (node->op) { - case NE_OP_CPY: { - node->n_tasks = n_threads; // node->ne[0] == 1 ? n_threads : 1; - size_t cur = 0; - if (ne_is_quantized(node->type)) { - cur = NE_TYPE_SIZE[NE_TYPE_F32] * node->ne[0] * n_threads; - } - work_size = MAX(work_size, cur); - } break; - case NE_OP_DUP: { - node->n_tasks = n_threads; - - size_t cur = 0; - if (ne_is_quantized(node->type)) { - cur = NE_TYPE_SIZE[NE_TYPE_F32] * node->ne[0] * n_threads; - } - - work_size = MAX(work_size, cur); - } break; - case NE_OP_ADD: { - if (ne_is_contiguous(node->src1) && ne_is_contiguous(node->src0) && - (ne_nrows(node->src1) == 1 || ne_nrows(node->src1) == ne_nrows(node->src0)) && - node->src0->ne[0] == node->src1->ne[0] && node->nb[0] == sizeof(float)) { - node->n_tasks = 1; - break; - } + size_t cur_work_size = 0; + size_t cur_dev_work_size = 0; + if (!bestla_support(node, n_threads, &cur_work_size, &cur_dev_work_size)) { + if (!ne_support(node, n_threads, &cur_work_size, &cur_dev_work_size)) { + NE_ASSERT(0); } - case NE_OP_ADD1: { - if (node->src0->ne[1] > 4) { - node->n_tasks = n_threads; + } + if (node->src0 && node->src1) { + if (node->src0->backend != node->src1->backend) { + if (node->src1->backend == NE_BACKEND_CPU) { + cur_dev_work_size += node->src1->size; } else { - node->n_tasks = 1; - } - - size_t cur = 0; - - if (ne_is_quantized(node->src0->type)) { - cur = NE_TYPE_SIZE[NE_TYPE_F32] * node->src0->ne[0] * n_threads; - } - - work_size = MAX(work_size, cur); - } break; - case NE_OP_ACC: { - node->n_tasks = n_threads; - - size_t cur = 0; - - if (ne_is_quantized(node->src0->type)) { - cur = NE_TYPE_SIZE[NE_TYPE_F32] * node->src1->ne[0] * n_threads; - } - - work_size = MAX(work_size, cur); - } break; - case NE_OP_SUB: - case NE_OP_SUM: - case NE_OP_DIV: - case NE_OP_SUM_ROWS: - case NE_OP_TANH: { - node->n_tasks = 1; - } break; - case NE_OP_MUL: { - if (ne_is_contiguous(node->src1) && ne_is_contiguous(node->src0) && - (ne_nrows(node->src1) == 1 || ne_nrows(node->src1) == ne_nrows(node->src0)) && - node->src0->ne[0] == node->src1->ne[0] && node->nb[0] == sizeof(float)) { - node->n_tasks = 1; - break; + cur_work_size += node->src1->size; } } - case NE_OP_SQR: - case NE_OP_SQRT: - case NE_OP_LOG: - case NE_OP_MEAN: - case NE_OP_ABS: - case NE_OP_ARGSORT: - case NE_OP_SGN: - case NE_OP_NEG: - case NE_OP_STEP: - case NE_OP_RELU: { - if (node->src0->ne[1] > 4) { - node->n_tasks = n_threads; - } else { - node->n_tasks = 1; - } - } break; - case NE_OP_NORM: - case NE_OP_RMS_NORM: { - if (ne_is_contiguous(node->src0)) { - node->n_tasks = 1; - } else { - if (node->src0->ne[1] > 4) { - node->n_tasks = n_threads; - } else { - node->n_tasks = 1; - } - } - } break; - case NE_OP_GELU: - case NE_OP_SILU: - case NE_OP_SILU_BACK: - case NE_OP_RMS_NORM_BACK: { - node->n_tasks = n_threads; - } break; - case NE_OP_MUL_MAT_BIAS: - case NE_OP_MUL_MAT_ID: - case NE_OP_CONV_1D: - case NE_OP_MUL_MAT: { - node->n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - // const int nr0 = ne_nrows(node->src0); - // const int nr1 = ne_nrows(node->src1); - - // node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); - // printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); - - size_t cur = 0; - struct ne_tensor* wei = 
node->src0; - if (node->op == NE_OP_MUL_MAT_ID) { - wei = node->opt[0]; - } - if (wei->type == NE_TYPE_BTLA) { - cur = bestla_f32f32_get_workspace_size(node->src1->ne[1], wei->ne[1], node->src1->ne[0], wei->data); - node->n_tasks = 1; - } else if (wei->type == NE_TYPE_F16 && node->src1->type == NE_TYPE_F32) { - cur = NE_TYPE_SIZE[NE_TYPE_F16] * ne_nelements(node->src1); - } else if (wei->type == NE_TYPE_F32 && node->src1->type == NE_TYPE_F32) { - cur = 0; - } else if (ne_is_quantized(wei->type) && node->src1->type == NE_TYPE_F32) { - { - const enum ne_type type_q = quantize_fns[wei->type].vec_dot_type; - cur = NE_TYPE_SIZE[type_q] * ne_nelements(node->src1) / NE_BLCK_SIZE[type_q]; - } - } else { - NE_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case NE_OP_MUL_FFN_SILU: - case NE_OP_MUL_FFN_GELU: - case NE_OP_MUL_FFN_GELU_MUL: - case NE_OP_MUL_FFN_ADD_GELU: { - size_t cur = 0; - cur = bestla_fusion_FFN_f32f32_get_workspace_size(node->src0->ne[1], node->src0->ne[0], node->src1->ne[1], - node->opt[0]->ne[1], node->src1->data, node->opt[0]->data); - work_size = MAX(work_size, cur); - node->n_tasks = 1; - } break; - case NE_OP_MUL_ID_FFN_GELU: - case NE_OP_MUL_ID_FFN_SILU: { - size_t cur = 0; - cur = - bestla_fusion_FFN_f32f32_get_workspace_size(node->src0->ne[1], node->src0->ne[0], node->opt[0]->ne[1], - node->opt[9]->ne[1], node->opt[0]->data, node->opt[9]->data); - work_size = MAX(work_size, cur); - node->n_tasks = 1; - } break; - case NE_OP_MUL_QKV: { - size_t cur = 0; - cur = bestla_fusion_QKV_f32f32_get_workspace_size(node->src0->ne[1], node->src1->ne[1], node->src1->ne[0], - node->src1->data); - work_size = MAX(work_size, cur); - node->n_tasks = 1; - } break; - case NE_OP_SCALE: { - node->n_tasks = 1; - } break; - case NE_OP_SET: - case NE_OP_CONT: - case NE_OP_RESHAPE: - case NE_OP_VIEW: - case NE_OP_PERMUTE: - case NE_OP_TRANSPOSE: - case NE_OP_GET_ROWS: - case NE_OP_GET_ROWS_BACK: - case NE_OP_REPEAT: - case NE_OP_DIAG: - case NE_OP_DIAG_MASK_ZERO: { - node->n_tasks = 1; - } break; - case NE_OP_DIAG_MASK_INF: - case NE_OP_PADDING_MASK_INF: - case NE_OP_ROPE: - // only first token use parallel - if (node->type == NE_TYPE_BTLA) - node->n_tasks = 1; - else - node->n_tasks = n_threads; - break; - case NE_OP_SOFT_MAX: { - size_t rows = ne_nrows(node->src0); - node->n_tasks = rows > 1 ? 
n_threads : 1; - } break; - case NE_OP_ROPE_BACK: { - node->n_tasks = n_threads; - } break; - case NE_OP_ALIBI: { - node->n_tasks = 1; // TODO - } break; - case NE_OP_CLAMP: { - node->n_tasks = 1; // TODO - } break; - case NE_OP_CONV_1D_1S: - case NE_OP_CONV_1D_2S: { - node->n_tasks = n_threads; - - NE_ASSERT(node->src0->ne[3] == 1); - NE_ASSERT(node->src1->ne[2] == 1); - NE_ASSERT(node->src1->ne[3] == 1); - - size_t cur = 0; - const int nk = node->src0->ne[0]; - - if (node->src0->type == NE_TYPE_F16 && node->src1->type == NE_TYPE_F32) { - cur = sizeof(ne_fp16_t) * (nk * ne_up32(node->src0->ne[1]) * node->src0->ne[2] + - (2 * (nk / 2) + node->src1->ne[0]) * node->src1->ne[1]); - } else if (node->src0->type == NE_TYPE_F32 && node->src1->type == NE_TYPE_F32) { - cur = sizeof(float) * (nk * ne_up32(node->src0->ne[1]) * node->src0->ne[2] + - (2 * (nk / 2) + node->src1->ne[0]) * node->src1->ne[1]); - } else { - NE_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case NE_OP_FLASH_ATTN_KV_UPDATE: - case NE_OP_FLASH_ATTN: { - node->n_tasks = 1; - work_size = 0LL; - } break; - case NE_OP_FLASH_FF: { - node->n_tasks = n_threads; - - size_t cur = 0; - - if (node->src1->type == NE_TYPE_F32) { - cur = sizeof(float) * node->src1->ne[1] * node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float) * node->src1->ne[1] * node->n_tasks; // this is overestimated by x2 - } - - if (node->src1->type == NE_TYPE_F16) { - cur = sizeof(float) * node->src1->ne[1] * node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float) * node->src1->ne[1] * node->n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case NE_OP_MAP_UNARY: - case NE_OP_MAP_BINARY: { - node->n_tasks = 1; - } break; - case NE_OP_NONE: { - node->n_tasks = 1; - } break; - case NE_OP_COUNT: { - NE_ASSERT(false); - } break; - // split and all_reduce do not use thread pool - case NE_OP_SPLIT: - case NE_OP_ALL_REDUCE: - case NE_OP_TP_CONCAT: - case NE_OP_DEBUG: - case NE_OP_DUMP_TENSOR: { - node->n_tasks = 1; - } break; - // case NE_OP_TP_CONCAT: { - // node->n_tasks = n_threads; - - // size_t cur = 0; - // if (ne_is_quantized(node->type)) { - // cur = NE_TYPE_SIZE[NE_TYPE_F32] * node->ne[0] * n_threads; - // } - - // work_size = MAX(work_size, cur); - // } break; } + work_size = MAX(work_size, cur_work_size); + dev_work_size = MAX(dev_work_size, cur_dev_work_size); } if (cgraph->work != NULL && work_size > cgraph->work_size) { @@ -11690,7 +11952,18 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { cgraph->work_size = work_size + CACHE_LINE_SIZE * (n_threads - 1); NE_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); - cgraph->work = ne_new_tensor_1d(ctx, NE_TYPE_I8, cgraph->work_size, NE_SIZE_CALC); + cgraph->work = ne_new_tensor_1d(ctx, NE_TYPE_I8, cgraph->work_size, NE_SIZE_CALC, NE_BACKEND_CPU); + } + + if (cgraph->dev_work != NULL && dev_work_size > cgraph->dev_work_size) { + NE_ASSERT(false); // TODO: better handling + } + + if (dev_work_size > 0 && cgraph->dev_work == NULL) { + cgraph->dev_work_size = dev_work_size; + + NE_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); + cgraph->dev_work = ne_new_tensor_1d(ctx, NE_TYPE_I8, cgraph->dev_work_size, NE_SIZE_CALC, NE_BACKEND_SYCL); } } @@ -11713,13 +11986,14 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { bestla_timer(true); #endif // INIT - struct ne_compute_params 
params = { - /*.type =*/NE_TASK_INIT, - /*.ith =*/0, - /*.nth =*/node->n_tasks, - /*.wsize =*/cgraph->work ? ne_nbytes(cgraph->work) : 0, - /*.wdata =*/cgraph->work ? cgraph->work->data : NULL, - }; + struct ne_compute_params params = {/*.type =*/NE_TASK_INIT, + /*.ith =*/0, + /*.nth =*/node->n_tasks, + /*.wsize =*/cgraph->work ? ne_nbytes(cgraph->work) : 0, + /*.wdata =*/cgraph->work ? cgraph->work->data : NULL, + /*.dev_wsize =*/cgraph->dev_work ? cgraph->dev_work_size : 0, + /*.dev_wdata =*/cgraph->dev_work ? cgraph->dev_work->data : NULL, + /*.dev_queue =*/ctx->dev_ctx ? ctx->dev_ctx->queue : NULL}; bestla_parallel_for(ne_compute_forward, ¶ms, node); #if NE_DEBUG @@ -12053,16 +12327,18 @@ static enum ne_opt_result ne_opt_adam(struct ne_context* ctx, struct ne_opt_para const float beta2 = params.adam.beta2; const float eps = params.adam.eps; - float* x = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // view of the parameters - float* g1 = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // gradient - float* g2 = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // gradient squared - float* m = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // first moment - float* v = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // second moment - float* mh = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // first moment hat - float* vh = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // second moment hat + float* x = + (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // view of the parameters + float* g1 = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // gradient + float* g2 = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // gradient squared + float* m = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // first moment + float* v = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // second moment + float* mh = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // first moment hat + float* vh = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // second moment hat - float* pf = params.past > 0 ? ne_new_tensor_1d(ctx, NE_TYPE_F32, params.past, NE_SIZE_CALC)->data - : NULL; // past function values + float* pf = params.past > 0 + ? 
(float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, params.past, NE_SIZE_CALC, NE_BACKEND_CPU)->data + : NULL; // past function values // initialize ne_vec_set_f32(nx, m, 0.0f); @@ -12262,7 +12538,7 @@ static enum ne_opt_result linesearch_backtracking(struct ne_context* ctx, const } else { // Armijo condition is satisfied if (params->lbfgs.linesearch == NE_LINESEARCH_BACKTRACKING_ARMIJO) { - return count; + return (enum ne_opt_result)count; } ne_vec_dot_f32(nx, &dg, g, d); @@ -12273,16 +12549,16 @@ static enum ne_opt_result linesearch_backtracking(struct ne_context* ctx, const } else { if (params->lbfgs.linesearch == NE_LINESEARCH_BACKTRACKING_WOLFE) { // regular Wolfe conditions - return count; + return (enum ne_opt_result)count; } if (dg > -params->lbfgs.wolfe * dginit) { width = dec; } else { // strong Wolfe condition (NE_LINESEARCH_BACKTRACKING_STRONG_WOLFE) - return count; + return (enum ne_opt_result)count; } - return count; + return (enum ne_opt_result)count; } } @@ -12332,14 +12608,16 @@ static enum ne_opt_result ne_opt_lbfgs(struct ne_context* ctx, struct ne_opt_par } } - float* x = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // current parameters - float* xp = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // previous parameters - float* g = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // current gradient - float* gp = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // previous gradient - float* d = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; // search direction + float* x = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // current parameters + float* xp = + (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // previous parameters + float* g = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // current gradient + float* gp = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // previous gradient + float* d = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; // search direction - float* pf = params.past > 0 ? ne_new_tensor_1d(ctx, NE_TYPE_F32, params.past, NE_SIZE_CALC)->data - : NULL; // past function values + float* pf = params.past > 0 + ? 
(float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, params.past, NE_SIZE_CALC, NE_BACKEND_CPU)->data + : NULL; // past function values float fx = 0.0f; // cost function value float xnorm = 0.0f; // ||x|| @@ -12350,13 +12628,14 @@ static enum ne_opt_result ne_opt_lbfgs(struct ne_context* ctx, struct ne_opt_par ne_opt_get_params(np, ps, x); // the L-BFGS memory - struct ne_lbfgs_iteration_data* lm = alloca(sizeof(struct ne_lbfgs_iteration_data) * m); + struct ne_lbfgs_iteration_data* lm = + (struct ne_lbfgs_iteration_data*)alloca(sizeof(struct ne_lbfgs_iteration_data) * m); for (int i = 0; i < m; ++i) { lm[i].alpha = 0.0f; lm[i].ys = 0.0f; - lm[i].s = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; - lm[i].y = ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC)->data; + lm[i].s = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; + lm[i].y = (float*)ne_new_tensor_1d(ctx, NE_TYPE_F32, nx, NE_SIZE_CALC, NE_BACKEND_CPU)->data; } // evaluate the function value and its gradient @@ -12420,7 +12699,7 @@ static enum ne_opt_result ne_opt_lbfgs(struct ne_context* ctx, struct ne_opt_par ne_vec_cpy_f32(nx, x, xp); ne_vec_cpy_f32(nx, g, gp); - return ls; + return (enum ne_opt_result)ls; } ne_vec_norm_f32(nx, &xnorm, x); diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h index 3767ed5cf..13c8f1e7b 100644 --- a/neural_speed/core/ne_layers.h +++ b/neural_speed/core/ne_layers.h @@ -112,27 +112,32 @@ NE_API void ne_free(struct ne_context* ctx); NE_API size_t ne_used_mem(const struct ne_context* ctx); +NE_API void ne_buffer_save(struct ne_context* ctx); + +NE_API void ne_buffer_load(struct ne_context* ctx); + NE_API size_t ne_set_scratch(struct ne_context* ctx, struct ne_scratch scratch); NE_API struct ne_tensor* ne_new_tensor(struct ne_context* ctx, enum ne_type type, int n_dims, const int64_t* ne, - size_t size); + size_t size, enum ne_backend bk); -NE_API struct ne_tensor* ne_new_tensor_1d(struct ne_context* ctx, enum ne_type type, int64_t ne0, size_t size); +NE_API struct ne_tensor* ne_new_tensor_1d(struct ne_context* ctx, enum ne_type type, int64_t ne0, size_t size, + enum ne_backend bk); NE_API struct ne_tensor* ne_new_tensor_2d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, - size_t size); + size_t size, enum ne_backend bk); NE_API struct ne_tensor* ne_new_tensor_3d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, - int64_t ne2, size_t size); + int64_t ne2, size_t size, enum ne_backend bk); NE_API struct ne_tensor* ne_new_tensor_4d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, - int64_t ne2, int64_t ne3, size_t size); + int64_t ne2, int64_t ne3, size_t size, enum ne_backend bk); -#define d_ne_new_tensor(...) ne_new_tensor(__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_1d(...) ne_new_tensor_1d(__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_2d(...) ne_new_tensor_2d(__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_3d(...) ne_new_tensor_3d(__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_4d(...) ne_new_tensor_4d(__VA_ARGS__, NE_SIZE_CALC) +#define d_ne_new_tensor(...) ne_new_tensor(__VA_ARGS__, NE_SIZE_CALC, NE_BACKEND_CPU) +#define d_ne_new_tensor_1d(...) ne_new_tensor_1d(__VA_ARGS__, NE_SIZE_CALC, NE_BACKEND_CPU) +#define d_ne_new_tensor_2d(...) ne_new_tensor_2d(__VA_ARGS__, NE_SIZE_CALC, NE_BACKEND_CPU) +#define d_ne_new_tensor_3d(...) ne_new_tensor_3d(__VA_ARGS__, NE_SIZE_CALC, NE_BACKEND_CPU) +#define d_ne_new_tensor_4d(...) 
ne_new_tensor_4d(__VA_ARGS__, NE_SIZE_CALC, NE_BACKEND_CPU) NE_API struct ne_tensor* ne_new_i32(struct ne_context* ctx, int32_t value); NE_API struct ne_tensor* ne_new_f32(struct ne_context* ctx, float value); @@ -140,6 +145,9 @@ NE_API struct ne_tensor* ne_new_f32(struct ne_context* ctx, float value); NE_API struct ne_tensor* ne_dup_tensor(struct ne_context* ctx, const struct ne_tensor* src); NE_API struct ne_tensor* ne_view_tensor(struct ne_context* ctx, const struct ne_tensor* src); +NE_API struct ne_tensor* ne_dup_tensor_bk(struct ne_context* ctx, const struct ne_tensor* src, enum ne_backend bk); +NE_API struct ne_tensor* ne_view_tensor_bk(struct ne_context* ctx, const struct ne_tensor* src, enum ne_backend bk); + NE_API struct ne_tensor* ne_set_zero(struct ne_tensor* tensor); NE_API struct ne_tensor* ne_set_i32(struct ne_tensor* tensor, int32_t value); NE_API struct ne_tensor* ne_set_f32(struct ne_tensor* tensor, float value); @@ -338,6 +346,9 @@ NE_API struct ne_tensor* ne_reshape_3d(struct ne_context* ctx, struct ne_tensor* NE_API struct ne_tensor* ne_reshape_4d(struct ne_context* ctx, struct ne_tensor* a, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3); +// If a is a device tensor, sync it to a host tensor. If a is a host tensor, it equals reshape(a). +NE_API struct ne_tensor* ne_device_sync(struct ne_context* ctx, struct ne_tensor* a, enum ne_backend bk); + // offset in bytes NE_API struct ne_tensor* ne_view_1d(struct ne_context* ctx, struct ne_tensor* a, int64_t ne0, size_t offset); diff --git a/neural_speed/models/bloom/bloom.cpp b/neural_speed/models/bloom/bloom.cpp index d24c2c23a..d83e1319f 100644 --- a/neural_speed/models/bloom/bloom.cpp +++ b/neural_speed/models/bloom/bloom.cpp @@ -170,8 +170,7 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) struct ne_tensor* Q = ne_permute( - ctx0, ne_cpy(ctx0, Qcur, ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_embd / n_head, n_head, N, NE_SIZE_CALC)), 0, 2, - 1, 3); + ctx0, ne_cpy(ctx0, Qcur, d_ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_embd / n_head, n_head, N)), 0, 2, 1, 3); // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) struct ne_tensor* K = ne_permute(ctx0, @@ -207,7 +206,7 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp il * n_ctx * ne_element_size(kv_self.v) * n_embd), n_embd / n_head, n_head, n_past + N), 1, 2, 0, 3), - ne_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd / n_head, n_head, NE_SIZE_CALC)); + d_ne_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd / n_head, n_head)); // KQV = transpose(V) * KQ_soft_max struct ne_tensor* KQV = ne_mul_mat(ctx0, V_trans, KQ_soft_max); @@ -215,7 +214,7 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N)); } else { const auto seq_kv = n_past + N; const auto k_size = kv_cache_info.k_bytes; diff --git a/neural_speed/models/chatglm/chatglm.cpp b/neural_speed/models/chatglm/chatglm.cpp index 51145974a..d75bf4804 100644 --- a/neural_speed/models/chatglm/chatglm.cpp +++ b/neural_speed/models/chatglm/chatglm.cpp @@ -203,8 +203,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const 
model_input* i if (n_past == 0) { // build attention mask for context input - ne_tensor* inf = - ne_new_tensor_4d(ctx0, attn_scores->type, 1, qlen - 1, num_attention_heads, batch_size, NE_SIZE_CALC); + ne_tensor* inf = d_ne_new_tensor_4d(ctx0, attn_scores->type, 1, qlen - 1, num_attention_heads, batch_size); ne_set_f32(inf, -INFINITY); ne_tensor* masked_attn_scores = diff --git a/neural_speed/models/falcon/falcon.cpp b/neural_speed/models/falcon/falcon.cpp index de9304fb1..aa981121d 100644 --- a/neural_speed/models/falcon/falcon.cpp +++ b/neural_speed/models/falcon/falcon.cpp @@ -222,7 +222,7 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N)); } else { // Using MHA (GQA/MQA) managed kv-cache const auto seq_kv = n_past + N; const auto k_size = kv_cache_info.k_bytes; @@ -272,7 +272,7 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in lctx.use_buf(ctx0, 1); struct ne_tensor* inpFF = layernorm_output; - struct ne_tensor* attn_out = ne_cpy(ctx0, cur, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC)); + struct ne_tensor* attn_out = ne_cpy(ctx0, cur, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N)); // FFN (pre_layer_norm output) { diff --git a/neural_speed/models/gemma/gemma.cpp b/neural_speed/models/gemma/gemma.cpp index 1bd069e1d..78c920eb6 100644 --- a/neural_speed/models/gemma/gemma.cpp +++ b/neural_speed/models/gemma/gemma.cpp @@ -255,8 +255,7 @@ static bool gemma_model_eval_internal(model_context* ctx, const model_input* inp struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_gqa_embd, N) - cur = ne_cpy(ctx0, KQV_merged, - ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_dim * n_head, N * batch_size, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_dim * n_head, N * batch_size)); } else { const auto seq_kv = n_past + N; const auto k_size = kv_cache_info.k_bytes; diff --git a/neural_speed/models/gptj/gptj.cpp b/neural_speed/models/gptj/gptj.cpp index ef0ee67a0..327cdc8a9 100644 --- a/neural_speed/models/gptj/gptj.cpp +++ b/neural_speed/models/gptj/gptj.cpp @@ -344,8 +344,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu } // for-loop self-attention - struct ne_tensor* KQV_merged_contiguous = - ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); + struct ne_tensor* KQV_merged_contiguous = d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum); size_t off_sl = 0; for (int gi = 0; gi < infer_groups.size(); ++gi) { const int attn_bs = infer_groups[gi].size(); @@ -453,7 +452,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu } else if (attn_n_total == 0 && run_mha_bf16_first) { // non-reordered kv-cache bf16 mha (first token only) auto vnele = ne_nelements(Vcur); - struct ne_tensor* Vtmp = ne_new_tensor_1d(ctx0, NE_TYPE_F16, vnele, NE_SIZE_CALC); + struct ne_tensor* Vtmp = d_ne_new_tensor_1d(ctx0, NE_TYPE_F16, vnele); Vtmp = ne_cpy(ctx0, ne_view_1d(ctx0, Vcur, vnele, 0), Vtmp); Vtmp = ne_view_4d(ctx0, Vtmp, head_size, n_head, attn_sl, attn_bs, ne_element_size(Vtmp) * head_size, ne_element_size(Vtmp) * head_size * n_head, diff 
--git a/neural_speed/models/gptneox/gptneox.cpp b/neural_speed/models/gptneox/gptneox.cpp index ce2e992a6..d9a03825c 100644 --- a/neural_speed/models/gptneox/gptneox.cpp +++ b/neural_speed/models/gptneox/gptneox.cpp @@ -262,7 +262,7 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size)); } else { const auto seq_kv = n_past + N; const auto k_size = kv_cache_info.k_bytes; diff --git a/neural_speed/models/grok/grok.cpp b/neural_speed/models/grok/grok.cpp index 89c02dee9..1c752bab1 100644 --- a/neural_speed/models/grok/grok.cpp +++ b/neural_speed/models/grok/grok.cpp @@ -227,7 +227,7 @@ static bool grok_model_eval_internal(model_context* ctx, const model_input* inpu struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size)); } else { const auto seq_kv = n_past + N; const auto k_size = kv_cache_info.k_bytes; diff --git a/neural_speed/models/llama/llama.cpp b/neural_speed/models/llama/llama.cpp index 749f3c2c6..486c3760e 100644 --- a/neural_speed/models/llama/llama.cpp +++ b/neural_speed/models/llama/llama.cpp @@ -42,6 +42,7 @@ #include "models/model_utils/util.h" #include "models/models.h" +#define SYCL_NDEBUG 1 // evaluate the transformer // // - lctx: model context @@ -139,13 +140,14 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp }; struct ne_context* ctx0 = ne_init(params); + ctx0->dev_ctx = ctx->dev_ctx; // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance ne_cgraph gf = {}; gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads; - const bool run_mha_reordered = kv_self.k->type == NE_TYPE_BTLA; + const bool run_mha_reordered = kv_self.k ? 
kv_self.k->type == NE_TYPE_BTLA : false; kv_cache_info_t kv_cache_info = {0, 0}; if (run_mha_reordered) { NE_ASSERT(kv_self.v->type == NE_TYPE_BTLA); // kv type should be the same @@ -168,8 +170,9 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info); } - struct ne_tensor* embd = ne_new_tensor_1d(ctx0, NE_TYPE_I32, seq_len_sum, NE_SIZE_CALC); + struct ne_tensor* embd = ne_new_tensor_1d(ctx0, NE_TYPE_I32, seq_len_sum, NE_SIZE_CALC, NE_BACKEND_CPU); ne_set_name(embd, "embd"); + int cpy_off = 0; for (int i = 0; i < batch_size; ++i) { memcpy(static_cast(embd->data) + cpy_off, inputs[i].tokens, n_tokens[i] * ne_element_size(embd)); @@ -184,14 +187,21 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp #endif struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd); + int gpu_layer_start = n_layer - model.kv_self.n_gpu_layer; for (int il = 0; il < n_layer; ++il) { + bool cpu_layer = il < gpu_layer_start; + if (!cpu_layer) { + inpL = ne_device_sync(ctx0, inpL, NE_BACKEND_SYCL); + } else { + inpL = ne_device_sync(ctx0, inpL, NE_BACKEND_CPU); + } struct ne_tensor* inpSA = inpL; struct ne_tensor* cur; lctx.use_buf(ctx0, 0); - - // norm + ne_buffer_save(ctx0); + // norm { cur = ne_rms_norm(ctx0, inpL, hparams.norm_eps); @@ -213,277 +223,377 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp infer_bs); Vcur = ne_view_1d(ctx0, QKVcur, qkv_size, 2 * qkv_bytes); } else { +#if SYCL_NDEBUG Qcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_size, n_head, infer_seq_len, infer_bs); Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, infer_seq_len, infer_bs); - Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur); - } - if (concat_multi_seqs) { - size_t off_sl = 0; - // per_request rope - for (int gi = 0; gi < infer_groups.size(); ++gi) { - const int qk_bs = infer_groups[gi].size(); - const int qk_sl = n_tokens[infer_groups[gi].front()]; - const int qk_n_past = n_pasts[infer_groups[gi].front()]; - struct ne_tensor* Qcur_req = - ne_view_4d(ctx0, Qcur, head_size, n_head, qk_sl, qk_bs, ne_element_size(Qcur) * head_size, - ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size * n_head * qk_sl, - off_sl * n_head * ne_element_size(Qcur)); - ne_build_forward_expand( - &gf, ne_rope_inplace(ctx0, Qcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale)); - struct ne_tensor* Kcur_req = ne_view_4d( - ctx0, Kcur, head_size, n_head_kv, qk_sl, qk_bs, ne_element_size(Kcur) * head_size, - ne_element_size(Kcur) * head_size * n_head_kv, ne_element_size(Kcur) * head_size * n_head_kv * qk_sl, - off_sl * n_head_kv * ne_element_size(Kcur)); - ne_build_forward_expand( - &gf, ne_rope_inplace(ctx0, Kcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale)); - off_sl += head_size * qk_bs * qk_sl; - } - } else { - Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base, - hparams.freq_scale); - Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K - ctx0, Kcur, (is_ring_full ? 
n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); - // Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, Vcur, head_size * n_head_kv, N)); + Vcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), head_size, n_head_kv, infer_seq_len, + infer_bs); +#else + cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur); + cur = ne_mul_mat(ctx0, model.layers[il].attn[1], cur); + cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur); + cur = ne_mul_mat(ctx0, model.layers[il].attn[3], cur); +#endif } - ne_set_name(Qcur, "Qcur"); - ne_set_name(Kcur, "Kcur"); - ne_set_name(Vcur, "Vcur"); - // self-attention - const float attn_scale = 1.0f / sqrtf(static_cast(head_size)); - struct ne_tensor* KQV_merged_contiguous = - ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); - if (!run_mha_reordered) { +#if SYCL_NDEBUG +#ifdef NS_SYCL + if (!cpu_layer && infer_groups.size() == 1 && batch_size == 1 && !is_ring_full) { + Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); + Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); + const int attn_sl = n_tokens[infer_groups[0].front()]; + const int attn_block_id = block_ids[infer_groups[0].front()]; + const int attn_n_past = n_pasts[infer_groups[0].front()]; + const int attn_n_total = n_totals[infer_groups[0].front()]; + ne_set_name(Qcur, "Qcur"); + ne_set_name(Kcur, "Kcur"); + ne_set_name(Vcur, "Vcur"); + const float attn_scale = 1.0f / sqrtf(static_cast(head_size)); + Kcur = ne_permute(ctx0, Kcur, 0, 2, 1, 3); // [heads, N, head_size] + Vcur = ne_permute(ctx0, Vcur, 1, 2, 0, 3); // [heads, head_size, N] + + struct ne_tensor* k_cache = ne_view_1d(ctx0, kv_self.k_d, n_ctx * n_embd_gqa * kv_n_ctx_block, + il * n_ctx * ne_element_size(kv_self.k_d) * n_embd_gqa * kv_n_ctx_block); + struct ne_tensor* v_cache = ne_view_1d(ctx0, kv_self.v_d, n_ctx * n_embd_gqa * kv_n_ctx_block, + il * n_ctx * ne_element_size(kv_self.v_d) * n_embd_gqa * kv_n_ctx_block); // store key and value to memory - // important: - // 1. storing RoPE-ed version of K in the KV cache! - // 2. 
for loop self-attention in multi seqs infer (num_request > 1) { - struct ne_tensor* const k_cache = - ne_view_1d(ctx0, kv_self.k, n_ctx * n_embd_gqa * kv_n_ctx_block, - il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block); - struct ne_tensor* const v_cache = - ne_view_1d(ctx0, kv_self.v, n_ctx * n_embd_gqa * kv_n_ctx_block, - il * n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block); - // cache = [tokens, beams, requests, layers], - // tokens = [head_dim, head_num, n_ctx] (may different orders) - size_t off_N_i = 0; - for (int i = 0; i < batch_size; ++i) { - const int block_idx = block_ids[i]; - const int N_i = n_tokens[i]; - const int n_past_i = n_pasts[i]; - // batch K - struct ne_tensor* Kcur_bs_i = - ne_permute(ctx0, - ne_view_4d(ctx0, Kcur, head_size, n_head_kv, N_i, 1, ne_element_size(Kcur) * head_size, - ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * N_i, - ne_element_size(Kcur) * off_N_i), - 0, 2, 1, 3); - struct ne_tensor* k_bs_i = - ne_view_4d(ctx0, k_cache, head_size, N_i, n_head_kv, 1, ne_element_size(k_cache) * head_size, - ne_element_size(k_cache) * head_size * n_ctx, ne_element_size(k_cache) * n_embd_gqa * n_ctx, - block_idx * n_ctx * n_embd_gqa * ne_element_size(k_cache) + - head_size * n_past_i * ne_element_size(k_cache)); - // batch V - struct ne_tensor* Vcur_bs_i = - ne_permute(ctx0, - ne_reshape_4d(ctx0, - ne_view_2d(ctx0, Vcur, n_embd_gqa, N_i, ne_element_size(Vcur) * n_embd_gqa, - ne_element_size(Vcur) * off_N_i), - head_size, n_head_kv, N_i, 1), - 1, 2, 0, 3); - struct ne_tensor* v_bs_i = ne_view_4d( - ctx0, v_cache, N_i, head_size, n_head_kv, 1, n_ctx * ne_element_size(v_cache), - n_ctx * ne_element_size(v_cache) * head_size, n_ctx * ne_element_size(v_cache) * n_embd_gqa, - block_idx * n_ctx * n_embd_gqa * ne_element_size(v_cache) + n_past_i * ne_element_size(v_cache)); - // concat - ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs_i, k_bs_i)); - ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs_i, v_bs_i)); - off_N_i += head_size * n_head_kv * N_i; - } + struct ne_tensor* k_cache_view = ne_view_4d( + ctx0, k_cache, head_size, attn_sl, n_head, infer_bs, ne_element_size(k_cache) * head_size, + ne_element_size(k_cache) * head_size * n_ctx, ne_element_size(k_cache) * head_size * n_ctx * n_head, + attn_n_past * head_size * ne_element_size(k_cache)); // [kv_heads, N, head_size] + + struct ne_tensor* v_cache_view = ne_view_4d( + ctx0, v_cache, attn_sl, head_size, n_head, infer_bs, ne_element_size(v_cache) * n_ctx, + ne_element_size(v_cache) * head_size * n_ctx, ne_element_size(v_cache) * head_size * n_ctx * n_head, + attn_n_past * ne_element_size(v_cache)); // [kv_heads, head_size, N] + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k_cache_view)); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v_cache_view)); } - // for-loop attention - size_t off_sl = 0; - for (int gi = 0; gi < infer_groups.size(); ++gi) { - const int attn_bs = infer_groups[gi].size(); - const int attn_sl = n_tokens[infer_groups[gi].front()]; - const int attn_block_id = block_ids[infer_groups[gi].front()]; - const int attn_n_past = n_pasts[infer_groups[gi].front()]; - const int attn_n_total = n_totals[infer_groups[gi].front()]; - struct ne_tensor* Q = - ne_permute(ctx0, - ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, - ne_element_size(Qcur) * head_size * n_head, - ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), - 0, 2, 1, 3); - std::string suffix = 
std::to_string(gi); - ne_set_name(Q, std::string("Q_" + suffix).c_str()); - const int n_cached_gi = shift_roped_k ? n_cached : attn_n_past + attn_sl; - std::vector attn_block_ids(infer_groups[gi].size()); - for (int j = 0; j < infer_groups[gi].size(); ++j) { - attn_block_ids[j] = block_ids[infer_groups[gi][j]]; +#if 1 + Kcur = ne_view_4d(ctx0, k_cache, head_size, attn_n_past + attn_sl, n_head, infer_bs, + ne_element_size(k_cache) * head_size, ne_element_size(k_cache) * head_size * n_ctx, + ne_element_size(k_cache) * head_size * n_head * n_ctx, + 0); // [kv_heads, klen, head_size] + Vcur = ne_view_4d(ctx0, v_cache, attn_n_past + attn_sl, head_size, n_head, infer_bs, + ne_element_size(v_cache) * n_ctx, ne_element_size(v_cache) * head_size * n_ctx, + ne_element_size(v_cache) * head_size * n_ctx * n_head, + 0); // [kv_heads, head_size, klen] + auto KQV = ne_flash_attn(ctx0, Qcur, Kcur, Vcur, attn_scale, ne_attn_flags_t(n_ctx)); +#else + Qcur = ne_permute(ctx0, Qcur, 0, 2, 1, 3); // [heads, N, head_size] + + Qcur = ne_device_sync(ctx0, Qcur, NE_BACKEND_CPU); + k_cache = ne_device_sync(ctx0, k_cache, NE_BACKEND_CPU); + v_cache = ne_device_sync(ctx0, v_cache, NE_BACKEND_CPU); + Kcur = ne_view_4d(ctx0, k_cache, head_size, n_past + infer_seq_len, n_head, infer_bs, + ne_element_size(k_cache) * head_size, ne_element_size(k_cache) * head_size * n_ctx, + ne_element_size(k_cache) * head_size * n_head * n_ctx, + 0); // [kv_heads, klen, head_size] + Vcur = ne_view_4d(ctx0, v_cache, n_past + infer_seq_len, head_size, n_head, infer_bs, + ne_element_size(v_cache) * n_ctx, ne_element_size(v_cache) * head_size * n_ctx, + ne_element_size(v_cache) * head_size * n_ctx * n_head, + 0); // [kv_heads, head_size, klen] + // attention + struct ne_tensor* KQ = ne_mul_mat(ctx0, Kcur, Qcur); // [heads, N, klen] + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ne_tensor* KQ_scale = ne_new_f32(ctx0, attn_scale); + // KQ_scaled shape [n_cached, N, n_head, 1] + struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, KQ_scale); + + // KQ_masked = mask_past(KQ_scaled) + if (N > 1) { + KQ_scaled = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + } + + // KQ = soft_max(KQ_masked) + struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_scaled); + + struct ne_tensor* KQV = ne_mul_mat(ctx0, Vcur, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + KQV = ne_cont(ctx0, ne_permute(ctx0, KQV, 0, 2, 1, 3)); + +#endif + KQV = ne_reshape_2d(ctx0, KQV, head_size * n_head, infer_seq_len); + // projection (no bias) + cur = ne_mul_mat(ctx0, model.layers[il].attn[3], KQV); + } else +#endif + { + Qcur = ne_device_sync(ctx0, Qcur, NE_BACKEND_CPU); + Kcur = ne_device_sync(ctx0, Kcur, NE_BACKEND_CPU); + Vcur = ne_device_sync(ctx0, Vcur, NE_BACKEND_CPU); + if (concat_multi_seqs) { + size_t off_sl = 0; + // per_request rope + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int qk_bs = infer_groups[gi].size(); + const int qk_sl = n_tokens[infer_groups[gi].front()]; + const int qk_n_past = n_pasts[infer_groups[gi].front()]; + struct ne_tensor* Qcur_req = + ne_view_4d(ctx0, Qcur, head_size, n_head, qk_sl, qk_bs, ne_element_size(Qcur) * head_size, + ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size * n_head * qk_sl, + off_sl * n_head * ne_element_size(Qcur)); + ne_build_forward_expand( + &gf, ne_rope_inplace(ctx0, Qcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale)); + struct ne_tensor* Kcur_req = ne_view_4d( + ctx0, Kcur, head_size, n_head_kv, qk_sl, qk_bs, 
ne_element_size(Kcur) * head_size, + ne_element_size(Kcur) * head_size * n_head_kv, ne_element_size(Kcur) * head_size * n_head_kv * qk_sl, + off_sl * n_head_kv * ne_element_size(Kcur)); + ne_build_forward_expand( + &gf, ne_rope_inplace(ctx0, Kcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale)); + off_sl += head_size * qk_bs * qk_sl; } - struct ne_tensor* K = - model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head_kv, attn_bs, attn_block_ids, il); - if (is_ring_full) { - K = ne_permute(ctx0, K, 0, 2, 1, 3); - struct ne_tensor* cossin_cache = nullptr; - // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N - // in a single eval execution - if (N == 1) cossin_cache = kv_self.cossin; - K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, - hparams.freq_scale); - K = ne_permute(ctx0, K, 0, 2, 1, 3); + } else { + Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base, + hparams.freq_scale); + Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K + ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); + // Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, Vcur, head_size * n_head_kv, N)); + } + ne_set_name(Qcur, "Qcur"); + ne_set_name(Kcur, "Kcur"); + ne_set_name(Vcur, "Vcur"); + // self-attention + const float attn_scale = 1.0f / sqrtf(static_cast(head_size)); + struct ne_tensor* KQV_merged_contiguous = + ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC, NE_BACKEND_CPU); + if (!run_mha_reordered) { + // store key and value to memory + // important: + // 1. storing RoPE-ed version of K in the KV cache! + // 2. 
for loop self-attention in multi seqs infer (num_request > 1) + { + struct ne_tensor* const k_cache = + ne_view_1d(ctx0, kv_self.k, n_ctx * n_embd_gqa * kv_n_ctx_block, + il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block); + struct ne_tensor* const v_cache = + ne_view_1d(ctx0, kv_self.v, n_ctx * n_embd_gqa * kv_n_ctx_block, + il * n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block); + // cache = [tokens, beams, requests, layers], + // tokens = [head_dim, head_num, n_ctx] (may different orders) + size_t off_N_i = 0; + for (int i = 0; i < batch_size; ++i) { + const int block_idx = block_ids[i]; + const int N_i = n_tokens[i]; + const int n_past_i = n_pasts[i]; + // batch K + struct ne_tensor* Kcur_bs_i = + ne_permute(ctx0, + ne_view_4d(ctx0, Kcur, head_size, n_head_kv, N_i, 1, ne_element_size(Kcur) * head_size, + ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * N_i, + ne_element_size(Kcur) * off_N_i), + 0, 2, 1, 3); + struct ne_tensor* k_bs_i = + ne_view_4d(ctx0, k_cache, head_size, N_i, n_head_kv, 1, ne_element_size(k_cache) * head_size, + ne_element_size(k_cache) * head_size * n_ctx, ne_element_size(k_cache) * n_embd_gqa * n_ctx, + block_idx * n_ctx * n_embd_gqa * ne_element_size(k_cache) + + head_size * n_past_i * ne_element_size(k_cache)); + // batch V + struct ne_tensor* Vcur_bs_i = + ne_permute(ctx0, + ne_reshape_4d(ctx0, + ne_view_2d(ctx0, Vcur, n_embd_gqa, N_i, ne_element_size(Vcur) * n_embd_gqa, + ne_element_size(Vcur) * off_N_i), + head_size, n_head_kv, N_i, 1), + 1, 2, 0, 3); + struct ne_tensor* v_bs_i = ne_view_4d( + ctx0, v_cache, N_i, head_size, n_head_kv, 1, n_ctx * ne_element_size(v_cache), + n_ctx * ne_element_size(v_cache) * head_size, n_ctx * ne_element_size(v_cache) * n_embd_gqa, + block_idx * n_ctx * n_embd_gqa * ne_element_size(v_cache) + n_past_i * ne_element_size(v_cache)); + // concat + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs_i, k_bs_i)); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs_i, v_bs_i)); + off_N_i += head_size * n_head_kv * N_i; + } } - // split cached V into n_head heads - struct ne_tensor* V = model_kv_cache_seq_concat(&gf, &lctx, ctx0, n_cached_gi, head_size, n_head_kv, attn_bs, - attn_block_ids, il, false); - ne_set_name(K, std::string("K_" + suffix).c_str()); - ne_set_name(V, std::string("V_" + suffix).c_str()); + // for-loop attention + size_t off_sl = 0; + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int attn_bs = infer_groups[gi].size(); + const int attn_sl = n_tokens[infer_groups[gi].front()]; + const int attn_block_id = block_ids[infer_groups[gi].front()]; + const int attn_n_past = n_pasts[infer_groups[gi].front()]; + const int attn_n_total = n_totals[infer_groups[gi].front()]; + struct ne_tensor* Q = ne_permute( + ctx0, + ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, + ne_element_size(Qcur) * head_size * n_head, + ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), + 0, 2, 1, 3); + std::string suffix = std::to_string(gi); + ne_set_name(Q, std::string("Q_" + suffix).c_str()); + const int n_cached_gi = shift_roped_k ? 
n_cached : attn_n_past + attn_sl; + std::vector attn_block_ids(infer_groups[gi].size()); + for (int j = 0; j < infer_groups[gi].size(); ++j) { + attn_block_ids[j] = block_ids[infer_groups[gi][j]]; + } + struct ne_tensor* K = model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head_kv, attn_bs, + attn_block_ids, il); + if (is_ring_full) { + K = ne_permute(ctx0, K, 0, 2, 1, 3); + struct ne_tensor* cossin_cache = nullptr; + // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N + // in a single eval execution + if (N == 1) cossin_cache = kv_self.cossin; + K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, + hparams.freq_scale); + K = ne_permute(ctx0, K, 0, 2, 1, 3); + } + + // split cached V into n_head heads + struct ne_tensor* V = model_kv_cache_seq_concat(&gf, &lctx, ctx0, n_cached_gi, head_size, n_head_kv, attn_bs, + attn_block_ids, il, false); + ne_set_name(K, std::string("K_" + suffix).c_str()); + ne_set_name(V, std::string("V_" + suffix).c_str()); - // K * Q - struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); - ne_set_name(KQ, std::string("KQ_" + suffix).c_str()); + // K * Q + struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); + ne_set_name(KQ, std::string("KQ_" + suffix).c_str()); - // KQ_scaled = KQ / sqrt(n_embd/n_head) - struct ne_tensor* KQ_scale = ne_new_f32(ctx0, attn_scale); - ne_set_name(KQ_scale, std::string("1/sqrt(n_embd/n_head)_" + suffix).c_str()); + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ne_tensor* KQ_scale = ne_new_f32(ctx0, attn_scale); + ne_set_name(KQ_scale, std::string("1/sqrt(n_embd/n_head)_" + suffix).c_str()); - // KQ_scaled shape [n_cached, N, n_head, 1] - struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, KQ_scale); - ne_set_name(KQ_scaled, std::string("KQ_scaled_" + suffix).c_str()); + // KQ_scaled shape [n_cached, N, n_head, 1] + struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, KQ_scale); + ne_set_name(KQ_scaled, std::string("KQ_scaled_" + suffix).c_str()); - // KQ_masked = mask_past(KQ_scaled) - if (N > 1 || !shift_roped_k || attn_n_total == 0) { // TODO(Yi): shift roped-k with N > 1 next-token - KQ_scaled = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, attn_n_past); - ne_set_name(KQ_scaled, std::string("KQ_masked_" + suffix).c_str()); - } + // KQ_masked = mask_past(KQ_scaled) + if (N > 1 || !shift_roped_k || attn_n_total == 0) { // TODO(Yi): shift roped-k with N > 1 next-token + KQ_scaled = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, attn_n_past); + ne_set_name(KQ_scaled, std::string("KQ_masked_" + suffix).c_str()); + } - // KQ = soft_max(KQ_masked) - struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_scaled); - ne_set_name(KQ_soft_max, std::string("KQ_soft_max_" + suffix).c_str()); + // KQ = soft_max(KQ_masked) + struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_scaled); + ne_set_name(KQ_soft_max, std::string("KQ_soft_max_" + suffix).c_str()); - struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); - ne_set_name(KQV, std::string("KQV_" + suffix).c_str()); + struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); + ne_set_name(KQV, std::string("KQV_" + suffix).c_str()); - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ne_tensor* KQV_merged_gi = ne_permute(ctx0, KQV, 0, 2, 1, 3); - ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ne_tensor* KQV_merged_gi = ne_permute(ctx0, KQV, 0, 2, 1, 3); + ne_set_name(KQV_merged_gi, 
std::string("KQV_merged_" + suffix).c_str()); - ne_build_forward_expand(&gf, - ne_cpy(ctx0, KQV_merged_gi, - ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, - head_size * n_head * ne_element_size(KQV_merged_contiguous), - ne_element_size(KQV_merged_contiguous) * off_sl))); - off_sl += head_size * n_head * attn_sl * attn_bs; - } - ne_set_name(KQV_merged_contiguous, "KQV_merged_contiguous"); - // projection (no bias) - cur = ne_mul_mat(ctx0, model.layers[il].attn[3], KQV_merged_contiguous); - } else { - const auto k_size = kv_cache_info.k_bytes; - const auto v_size = kv_cache_info.v_bytes; - // store key and value to memory - { - size_t off_sl = 0; - for (int gi = 0; gi < infer_groups.size(); ++gi) { - const int update_bs = infer_groups[gi].size(); - const int update_sl = n_tokens[infer_groups[gi].front()]; - const int update_block_id = block_ids[infer_groups[gi].front()]; - const int update_n_past = n_pasts[infer_groups[gi].front()]; - const auto k_cache_g = ne_view_4d(ctx0, kv_self.k, // tensor - head_size, n_ctx, n_head_kv, update_bs, // ne - 0, 0, k_size, // nb (bestla managed) - il * kv_n_ctx_block * k_size + update_block_id * k_size); // offset - const auto k_cur_g = - ne_view_4d(ctx0, Kcur, head_size, n_head_kv, update_sl, update_bs, ne_element_size(Kcur) * head_size, - ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * update_sl, - ne_element_size(Kcur) * off_sl); - ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache_g, k_cur_g, update_n_past, is_ring_full)); - struct ne_tensor* v_cache_g = - ne_view_4d(ctx0, kv_self.v, // tensor - head_size, n_ctx, n_head_kv, update_bs, // ne - 0, 0, v_size, // nb (bestla managed) - il * kv_n_ctx_block * v_size + update_block_id * v_size); // offset); - // bestla always view V as (D, n_head, seq, bs) - const auto v_cur_g = - ne_view_4d(ctx0, Vcur, head_size, n_head_kv, update_sl, update_bs, ne_element_size(Vcur) * head_size, - ne_element_size(Vcur) * n_embd_gqa, ne_element_size(Vcur) * n_embd_gqa * update_sl, - ne_element_size(Vcur) * off_sl); - ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache_g, v_cur_g, update_n_past, is_ring_full)); - off_sl += n_embd_gqa * update_sl * update_bs; + ne_build_forward_expand(&gf, + ne_cpy(ctx0, KQV_merged_gi, + ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, + head_size * n_head * ne_element_size(KQV_merged_contiguous), + ne_element_size(KQV_merged_contiguous) * off_sl))); + off_sl += head_size * n_head * attn_sl * attn_bs; + } + ne_set_name(KQV_merged_contiguous, "KQV_merged_contiguous"); + // projection (no bias) + cur = ne_mul_mat(ctx0, model.layers[il].attn[3], KQV_merged_contiguous); + } else { + const auto k_size = kv_cache_info.k_bytes; + const auto v_size = kv_cache_info.v_bytes; + // store key and value to memory + { + size_t off_sl = 0; + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int update_bs = infer_groups[gi].size(); + const int update_sl = n_tokens[infer_groups[gi].front()]; + const int update_block_id = block_ids[infer_groups[gi].front()]; + const int update_n_past = n_pasts[infer_groups[gi].front()]; + const auto k_cache_g = ne_view_4d(ctx0, kv_self.k, // tensor + head_size, n_ctx, n_head_kv, update_bs, // ne + 0, 0, k_size, // nb (bestla managed) + il * kv_n_ctx_block * k_size + update_block_id * k_size); // offset + const auto k_cur_g = + ne_view_4d(ctx0, Kcur, head_size, n_head_kv, update_sl, update_bs, ne_element_size(Kcur) * head_size, + ne_element_size(Kcur) * 
n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * update_sl, + ne_element_size(Kcur) * off_sl); + ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache_g, k_cur_g, update_n_past, is_ring_full)); + struct ne_tensor* v_cache_g = + ne_view_4d(ctx0, kv_self.v, // tensor + head_size, n_ctx, n_head_kv, update_bs, // ne + 0, 0, v_size, // nb (bestla managed) + il * kv_n_ctx_block * v_size + update_block_id * v_size); // offset); + // bestla always view V as (D, n_head, seq, bs) + const auto v_cur_g = + ne_view_4d(ctx0, Vcur, head_size, n_head_kv, update_sl, update_bs, ne_element_size(Vcur) * head_size, + ne_element_size(Vcur) * n_embd_gqa, ne_element_size(Vcur) * n_embd_gqa * update_sl, + ne_element_size(Vcur) * off_sl); + ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache_g, v_cur_g, update_n_past, is_ring_full)); + off_sl += n_embd_gqa * update_sl * update_bs; + } } - } - // for-loop attention - size_t off_sl = 0; - for (int gi = 0; gi < infer_groups.size(); ++gi) { - const int attn_bs = infer_groups[gi].size(); - const int attn_sl = n_tokens[infer_groups[gi].front()]; - const int attn_block_id = block_ids[infer_groups[gi].front()]; - const int attn_n_past = n_pasts[infer_groups[gi].front()]; - const int attn_n_total = n_totals[infer_groups[gi].front()]; - struct ne_tensor* Q = - ne_permute(ctx0, - ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, - ne_element_size(Qcur) * head_size * n_head, - ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), - 0, 2, 1, 3); - std::string suffix = std::to_string(gi); - ne_set_name(Q, std::string("Q_" + suffix).c_str()); - const int n_cached_gi = shift_roped_k ? n_cached : attn_n_past + attn_sl; - struct ne_tensor* K = - ne_view_4d(ctx0, kv_self.k, // tensor - head_size, n_cached_gi, n_head_kv, attn_bs, // ne - kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, k_size, // nb (bestla managed) - il * kv_n_ctx_block * k_size + attn_block_id * k_size); // offset - *reinterpret_cast(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout - if (is_ring_full) { - struct ne_tensor* cossin_cache = nullptr; - // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N - // in a single eval execution - if (N == 1) cossin_cache = kv_self.cossin; - K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, - hparams.freq_scale); + // for-loop attention + size_t off_sl = 0; + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int attn_bs = infer_groups[gi].size(); + const int attn_sl = n_tokens[infer_groups[gi].front()]; + const int attn_block_id = block_ids[infer_groups[gi].front()]; + const int attn_n_past = n_pasts[infer_groups[gi].front()]; + const int attn_n_total = n_totals[infer_groups[gi].front()]; + struct ne_tensor* Q = ne_permute( + ctx0, + ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, + ne_element_size(Qcur) * head_size * n_head, + ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), + 0, 2, 1, 3); + std::string suffix = std::to_string(gi); + ne_set_name(Q, std::string("Q_" + suffix).c_str()); + const int n_cached_gi = shift_roped_k ? 
n_cached : attn_n_past + attn_sl; + struct ne_tensor* K = + ne_view_4d(ctx0, kv_self.k, // tensor + head_size, n_cached_gi, n_head_kv, attn_bs, // ne + kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, k_size, // nb (bestla managed) + il * kv_n_ctx_block * k_size + attn_block_id * k_size); // offset + *reinterpret_cast(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout + if (is_ring_full) { + struct ne_tensor* cossin_cache = nullptr; + // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N + // in a single eval execution + if (N == 1) cossin_cache = kv_self.cossin; + K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, + hparams.freq_scale); + } + struct ne_tensor* V = + ne_view_4d(ctx0, kv_self.v, // tensor + n_cached_gi, head_size, n_head_kv, attn_bs, // ne + kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, + v_size, // nb (bestla managed) + il * kv_n_ctx_block * v_size + attn_block_id * v_size); // use nb0 for layout + *reinterpret_cast(&V->nb[0]) = kv_cache_info.v_layout; + ne_set_name(K, std::string("K_" + suffix).c_str()); + ne_set_name(V, std::string("V_" + suffix).c_str()); + + ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE; + if (hparams.mha_prefer_f32) attn_flags |= NE_ATTN_FLAG_PREFER_FP32; + if (n_total == 0 || !shift_roped_k) + attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases + struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); + struct ne_tensor* KQV_merged_gi = ne_view_2d(ctx0, KQV_Out, head_size * n_head, attn_sl * attn_bs, + head_size * n_head * ne_element_size(KQV_Out), 0); + ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); + ne_build_forward_expand(&gf, + ne_cpy(ctx0, KQV_merged_gi, + ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, + head_size * n_head * ne_element_size(KQV_merged_contiguous), + ne_element_size(KQV_merged_contiguous) * off_sl))); + off_sl += head_size * n_head * attn_sl * attn_bs; } - struct ne_tensor* V = ne_view_4d(ctx0, kv_self.v, // tensor - n_cached_gi, head_size, n_head_kv, attn_bs, // ne - kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, - v_size, // nb (bestla managed) - il * kv_n_ctx_block * v_size + attn_block_id * v_size); // use nb0 for layout - *reinterpret_cast(&V->nb[0]) = kv_cache_info.v_layout; - ne_set_name(K, std::string("K_" + suffix).c_str()); - ne_set_name(V, std::string("V_" + suffix).c_str()); - - ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE; - if (hparams.mha_prefer_f32) attn_flags |= NE_ATTN_FLAG_PREFER_FP32; - if (n_total == 0 || !shift_roped_k) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases - struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); - struct ne_tensor* KQV_merged_gi = ne_view_2d(ctx0, KQV_Out, head_size * n_head, attn_sl * attn_bs, - head_size * n_head * ne_element_size(KQV_Out), 0); - ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); - ne_build_forward_expand(&gf, - ne_cpy(ctx0, KQV_merged_gi, - ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, - head_size * n_head * ne_element_size(KQV_merged_contiguous), - ne_element_size(KQV_merged_contiguous) * off_sl))); - off_sl += head_size * n_head * attn_sl * attn_bs; + ne_set_name(KQV_merged_contiguous, "KQV_merged_contiguous"); + // projection (no bias) + cur = ne_mul_mat(ctx0, model.layers[il].attn[3], 
KQV_merged_contiguous); } - ne_set_name(KQV_merged_contiguous, "KQV_merged_contiguous"); - // projection (no bias) - cur = ne_mul_mat(ctx0, model.layers[il].attn[3], KQV_merged_contiguous); } + #ifdef NS_TP_MODEL if (enable_tp) { cur = ne_all_reduce(ctx0, cur); } #endif - - lctx.use_buf(ctx0, 1); - +#endif + // lctx.use_buf(ctx0, 1); struct ne_tensor* inpFF = ne_add(ctx0, cur, inpSA); // feed-forward network @@ -509,7 +619,8 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp } } else { // for-loop MOE (deal with sequence one by one) - struct ne_tensor* moe_out = ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); + struct ne_tensor* moe_out = + ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC, NE_BACKEND_CPU); size_t off_sl = 0; for (int bi = 0; bi < batch_size; ++bi) { const int moe_sl = n_tokens[bi]; @@ -585,13 +696,12 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp } cur = ne_add(ctx0, cur, inpFF); + ne_buffer_load(ctx0); // input for next layer inpL = cur; } - lctx.use_buf(ctx0, 0); - // used at the end to optionally extract the embeddings struct ne_tensor* embeddings = nullptr; // norm @@ -606,7 +716,11 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp // lm_head inpL = ne_mul_mat(ctx0, model.others[2], inpL); + inpL = ne_device_sync(ctx0, inpL, NE_BACKEND_CPU); + if (!lctx.embedding.empty()) { + embeddings = ne_device_sync(ctx0, embeddings, NE_BACKEND_CPU); + } lctx.use_buf(ctx0, -1); // logits -> probs @@ -622,6 +736,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp // update kv token count lctx.model.kv_self.n = n_cached; + float* logptr = (float*)inpL->data; // extract logits { @@ -629,7 +744,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp if (lctx.logits_all) { logits_out.resize(n_vocab * seq_len_sum); - memcpy(logits_out.data(), reinterpret_cast(ne_get_data(inpL)), sizeof(float) * n_vocab * seq_len_sum); + memcpy(logits_out.data(), logptr, sizeof(float) * n_vocab * seq_len_sum); } else { // return result for just the last token logits_out.resize(n_vocab * batch_size); @@ -637,12 +752,12 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp reinterpret_cast(bestla_get_thread_handle()); threading->parallel_for_collapse(0, batch_size, 1, [&](int i) { size_t bs_off = std::accumulate(n_tokens.begin(), n_tokens.begin() + i, 0) * n_vocab; - memcpy(logits_out.data() + (i * n_vocab), - reinterpret_cast(ne_get_data(inpL)) + bs_off + (n_vocab * (n_tokens[i] - 1)), + memcpy(logits_out.data() + (i * n_vocab), logptr + bs_off + (n_vocab * (n_tokens[i] - 1)), sizeof(float) * n_vocab); }); } } + // extract embeddings if (!lctx.embedding.empty()) { auto& embedding_out = lctx.embedding; diff --git a/neural_speed/models/llama/llama.h b/neural_speed/models/llama/llama.h index 0e7c31255..a76e9022b 100644 --- a/neural_speed/models/llama/llama.h +++ b/neural_speed/models/llama/llama.h @@ -36,11 +36,19 @@ static const model_scratch llama_mem_req(int n_layers, float scratch_size_ratio static_cast(scratch_size_ratio * 4096) * MB, }; case 32: +#ifdef NS_SYCL + return { + static_cast(scratch_size_ratio * 0) * MB, + static_cast(scratch_size_ratio * 0) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + }; +#else return { static_cast(scratch_size_ratio * 4096) * MB, static_cast(scratch_size_ratio * 2048) * MB, 
static_cast(scratch_size_ratio * 4096) * MB, }; +#endif case 40: return { static_cast(scratch_size_ratio * 4096) * MB, diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp index f78f47e83..1abe9f7aa 100644 --- a/neural_speed/models/llama/llama_utils.cpp +++ b/neural_speed/models/llama/llama_utils.cpp @@ -32,6 +32,7 @@ #include "core/ne.h" #include "core/ne_layers.h" #include "models/llama/llama.h" +#include "models/model_utils/model_utils.h" #include "models/model_utils/model_config.h" #include "models/model_utils/model_files.h" #include "models/model_utils/model_types.h" @@ -93,14 +94,23 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); + int n_cpu_layer = n_layer - n_gpu_layer; + n_cpu_layer = n_cpu_layer < 0 ? 0 : n_cpu_layer; fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); - + auto host_size = + (ctx_size + (50 << 20)) * n_cpu_layer / n_layer + n_embd * n_vocab * sizeof(float); // embedding on CPU + auto device_size = (ctx_size + (50 << 20)) * n_gpu_layer / n_layer + (50 << 20); + fprintf(stderr, "%s: host ctx size = %7.2f MB\n", __func__, host_size / 1024.0 / 1024.0); +#ifdef NS_SYCL + fprintf(stderr, "%s: device ctx size = %7.2f MB\n", __func__, device_size / 1024.0 / 1024.0); +#endif // create the ne context - lctx.model.buf.resize(ctx_size); + lctx.model.buf.resize(host_size); if (use_mlock) { lctx.model.mlock_buf.init(lctx.model.buf.addr); lctx.model.mlock_buf.grow_to(lctx.model.buf.size); } + model_alloc_sycl_mem(lctx.dev_ctx, device_size); struct ne_init_params params = { /*.mem_size =*/lctx.model.buf.size, @@ -112,12 +122,14 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, if (!model.ctx) { throw format("ne_init() failed"); } - + ne_ctx->dev_ctx = ctx->dev_ctx; ml->ne_ctx = ne_ctx; const int i_gpu_start = n_layer - n_gpu_layer; model.layers.resize(n_layer); size_t vram_total = 0; + size_t device_total = 0; + if (ml->verify_tensor("token_embd.weight")) { // GGUF model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU); @@ -170,12 +182,12 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, } } else { // NE Fortmat model.others[0] = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); - model.others[1] = ml->get_tensor("norm.weight", {n_embd}, NE_BACKEND_CPU); - model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, - n_gpu_layer > static_cast(n_layer) ? MODEL_BACKEND_OFFLOAD : NE_BACKEND_CPU); + model.others[1] = ml->get_tensor("norm.weight", {n_embd}, n_gpu_layer ? NE_BACKEND_SYCL : NE_BACKEND_CPU); + model.others[2] = + ml->get_tensor("output.weight", {n_embd, n_vocab}, n_gpu_layer > 0 ? NE_BACKEND_SYCL : NE_BACKEND_CPU); for (uint32_t i = 0; i < n_layer; ++i) { - const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD; + const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : NE_BACKEND_SYCL; auto& layer = model.layers[i]; std::string layers_i = "layers." + std::to_string(i); @@ -214,24 +226,18 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, ml->get_tensor(layers_i + ".ffn_up." 
+ std::to_string(x) + ".weight", {n_embd, n_ff}, backend); } } - if (backend != NE_BACKEND_CPU) { - vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + - ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) + - ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); + auto layer_total = ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) + + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); + if (backend == NE_BACKEND_CPU) { + vram_total += layer_total; + } else { + device_total += layer_total; } } } - - // print memory requirements - // this is the total memory required to run the inference - const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory - scratch.scratch0 + scratch.scratch1 + scratch.eval; - fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); - fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); - fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); - fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); - - (void)n_gpu_layer; + NE_ASSERT(vram_total <= host_size); + NE_ASSERT(device_total <= device_size); // populate `tensors_by_name` for (model_load_tensor& lt : ml->tensors_map.tensors) { @@ -239,7 +245,6 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, } ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : nullptr); - if (progress_callback) { progress_callback(1.0f, progress_callback_user_data); } diff --git a/neural_speed/models/model_utils/model_config.h b/neural_speed/models/model_utils/model_config.h index 575d56984..07797de2f 100644 --- a/neural_speed/models/model_utils/model_config.h +++ b/neural_speed/models/model_utils/model_config.h @@ -121,7 +121,7 @@ std::vector model_tokenize(struct model_context* ctx, const std::st // Model utils // -struct model_context* model_init_from_gpt_params(const gpt_params& params); +struct model_context* model_init_from_gpt_params(const gpt_params& params, ne_sycl_context* dev_ctx); // KV cache elements per layer per batch per beam void get_batch_kv_elements_from_gpt_params(int heads_kv, int head_size, int n_ctx, ne_type wtype, int32_t* k_size, diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h index 444d28976..34f91c169 100644 --- a/neural_speed/models/model_utils/model_files.h +++ b/neural_speed/models/model_utils/model_files.h @@ -1456,13 +1456,13 @@ struct model_model_loader { struct ne_tensor* tensor; if (lt.ne.size() == 2) { if (lt.type == NE_TYPE_BTLA) { - tensor = ne_new_tensor_2d(ne_ctx, lt.type, lt.ne.at(0), lt.ne.at(1), lt.size); + tensor = ne_new_tensor_2d(ne_ctx, lt.type, lt.ne.at(0), lt.ne.at(1), lt.size, backend); } else { - tensor = ne_new_tensor_2d(ne_ctx, lt.type, lt.ne.at(0), lt.ne.at(1), NE_SIZE_CALC); + tensor = ne_new_tensor_2d(ne_ctx, lt.type, lt.ne.at(0), lt.ne.at(1), NE_SIZE_CALC, backend); } } else { MODEL_ASSERT(lt.ne.size() == 1); - tensor = ne_new_tensor_1d(ne_ctx, lt.type, lt.ne.at(0), NE_SIZE_CALC); + tensor = ne_new_tensor_1d(ne_ctx, lt.type, lt.ne.at(0), NE_SIZE_CALC, backend); } ne_set_name(tensor, lt.name.c_str()); 
MODEL_ASSERT(lt.ne_tensor == nullptr); // if this fails, we called get_tensor twice on the same tensor @@ -1503,16 +1503,31 @@ struct model_model_loader { size_t done_size = 0; for (model_load_tensor& lt : tensors_map.tensors) { - if (lt.ne_tensor->backend != NE_BACKEND_CPU) { - continue; - } if (progress_callback) { progress_callback((float)done_size / data_size, progress_callback_user_data); } MODEL_ASSERT(lt.ne_tensor); // unused tensors should have been caught by load_data already lt.data = (uint8_t*)lt.ne_tensor->data; - load_data_for(lt); - lt.ne_tensor->data = lt.data; + if (lt.ne_tensor->backend == NE_BACKEND_CPU) { + load_data_for(lt); + lt.ne_tensor->data = lt.data; + } else { +#ifdef NS_SYCL + lt.data = bestla::utils::amalloc(lt.ne_tensor->size); + load_data_for(lt); + if (lt.ne_tensor->type == NE_TYPE_BTLA) { + void* dptr = NULL; + memcpy(&dptr, lt.ne_tensor->padding, sizeof(dptr)); + bestla_device_load_storage(lt.data, lt.ne_tensor->data, dptr, ne_ctx->dev_ctx->queue); + } else { + bestla_device_memcpy_sync(lt.ne_tensor->data, lt.data, lt.ne_tensor->size, ne_ctx->dev_ctx->queue); + } + + bestla::utils::afree(lt.data); +#else + NE_ASSERT(false); +#endif + } done_size += lt.size; if (use_mmap && lmlock) { lmlock->grow_to(done_size); diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h index e92c9ea7c..c6500d1a1 100644 --- a/neural_speed/models/model_utils/model_types.h +++ b/neural_speed/models/model_utils/model_types.h @@ -200,12 +200,18 @@ struct kv_seq_cell { struct model_kv_cache { struct ne_tensor* k = nullptr; struct ne_tensor* v = nullptr; + struct ne_tensor* k_d = nullptr; + struct ne_tensor* v_d = nullptr; struct ne_tensor* cossin = nullptr; // cached cos/sin value for shifting RoPE struct ne_context* ctx = nullptr; model_ctx_buffer buf; + void* device_buf = nullptr; + size_t device_size = 0; + + int n_gpu_layer = 0; int n; // number of tokens currently in the cache bool has_shift = false; // ring-buffer (for too long text generation like streaming-llm) @@ -350,6 +356,8 @@ struct model_context { int buf_last = 0; size_t buf_max_size[MODEL_MAX_SCRATCH_BUFFERS] = {0}; + ne_sycl_context* dev_ctx = NULL; + void use_buf(struct ne_context* ctx, int i) { #if defined(MODEL_USE_SCRATCH) size_t last_size = 0; @@ -461,6 +469,7 @@ struct model_context_params { model_progress_callback progress_callback; // context pointer passed to the progress callback void* progress_callback_user_data; + ne_sycl_context* dev_ctx; }; class model_name_to_arch { diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp index bd799b173..72e570959 100644 --- a/neural_speed/models/model_utils/model_utils.cpp +++ b/neural_speed/models/model_utils/model_utils.cpp @@ -60,11 +60,13 @@ // non-null pointer of model for kv-cache as components of model->layers[il] (e.g. chatglm) static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_cache& cache, // NOLINT const ne_type wtype, const int n_ctx, const int batch_size, const int beam_size, - const bool shift_roped_k, model_struct* model) { + const bool shift_roped_k, model_struct* model, ne_sycl_context* dev_ctx) { const auto n_layer = hparams.n_layer; auto heads_kv = hparams.n_head_kv > 0 ? hparams.n_head_kv : hparams.n_head; const auto head_size = hparams.n_embd_head_k == 0 ? 
hparams.n_embd / hparams.n_head : hparams.n_embd_head_k; - + if (cache.n_gpu_layer) { + NE_ASSERT(wtype != NE_TYPE_BTLA); + } #ifdef NS_TP_MODEL // when use TP, cached kv will also have smaller size parallel_context* p_ctx = init_parallel_context(); @@ -77,8 +79,11 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c int64_t layer_ne_k = batch_size * beam_size * k_size; int64_t layer_ne_v = batch_size * beam_size * v_size; const auto wsize = wtype == NE_TYPE_BTLA ? 1 : ne_type_size(wtype); - - cache.buf.resize(n_layer * (layer_ne_k + layer_ne_v) * wsize + 2u * MB); + int n_cpu_layer = n_layer - cache.n_gpu_layer; + n_cpu_layer = n_cpu_layer < 0 ? 0 : n_cpu_layer; + size_t size_cpu = (size_t)n_cpu_layer * (layer_ne_k + layer_ne_v) * wsize + 2u * MB; + size_t size_gpu = (size_t)cache.n_gpu_layer * (layer_ne_k + layer_ne_v) * wsize + 2u * MB; + cache.buf.resize(size_cpu); cache.seq_cells.resize(batch_size * beam_size); for (int i = 0; i < cache.seq_cells.size(); ++i) { cache.seq_cells[i].token_cells.resize(n_ctx); @@ -95,25 +100,30 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); return false; } + model_alloc_sycl_mem(dev_ctx, size_gpu); + cache.ctx->dev_ctx = dev_ctx; // NE_TYPE_BTLA can not be allocated memory const auto wtype_alloc = wtype == NE_TYPE_BTLA ? NE_TYPE_I8 : wtype; if (model) { // non-null param of model for kv-cache as components of model->layers[il] for (int il = 0; il < n_layer; ++il) { + const ne_backend backend = il < n_cpu_layer ? NE_BACKEND_CPU : NE_BACKEND_SYCL; auto& k_cache = model->layers[il].k_cache; auto& v_cache = model->layers[il].v_cache; if (wtype == NE_TYPE_F16) { // chatglm does not support fp32 kv-cache in original impl of chatglm_util.cpp const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k; const int heads_kv = hparams.multi_query_group_num > 0 ? 
hparams.multi_query_group_num : hparams.n_head; - k_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, head_size, n_ctx, heads_kv, batch_size * beam_size); - v_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, n_ctx, head_size, heads_kv, batch_size * beam_size); + k_cache = ne_new_tensor_4d(model->ctx, NE_TYPE_F16, head_size, n_ctx, heads_kv, batch_size * beam_size, + NE_SIZE_CALC, backend); + v_cache = ne_new_tensor_4d(model->ctx, NE_TYPE_F16, n_ctx, head_size, heads_kv, batch_size * beam_size, + NE_SIZE_CALC, backend); } else if (wtype == NE_TYPE_BTLA) { - k_cache = ne_new_tensor_1d(model->ctx, wtype_alloc, layer_ne_k + NE_ALIGNMENT, NE_SIZE_CALC); + k_cache = ne_new_tensor_1d(model->ctx, wtype_alloc, layer_ne_k + NE_ALIGNMENT, NE_SIZE_CALC, backend); const auto k_align_off = reinterpret_cast(k_cache->data) % NE_ALIGNMENT; k_cache = ne_view_1d(model->ctx, k_cache, layer_ne_k, NE_ALIGNMENT - k_align_off); k_cache->type = wtype; - v_cache = ne_new_tensor_1d(model->ctx, wtype_alloc, layer_ne_v + NE_ALIGNMENT, NE_SIZE_CALC); + v_cache = ne_new_tensor_1d(model->ctx, wtype_alloc, layer_ne_v + NE_ALIGNMENT, NE_SIZE_CALC, backend); const auto v_align_off = reinterpret_cast(v_cache->data) % NE_ALIGNMENT; v_cache = ne_view_1d(model->ctx, v_cache, layer_ne_v, NE_ALIGNMENT - v_align_off); v_cache->type = wtype; @@ -126,21 +136,35 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c const bool run_mha_reordered = model->layers[0].k_cache->type == NE_TYPE_BTLA; fprintf(stderr, "%s: run_mha_reordered = %d\n", __func__, run_mha_reordered); } else { - cache.k = ne_new_tensor_1d(cache.ctx, wtype_alloc, n_layer * layer_ne_k + NE_ALIGNMENT, NE_SIZE_CALC); - const auto k_align_off = reinterpret_cast(cache.k->data) % NE_ALIGNMENT; - cache.k = ne_view_1d(cache.ctx, cache.k, n_layer * layer_ne_k, NE_ALIGNMENT - k_align_off); - cache.k->type = wtype; - cache.v = ne_new_tensor_1d(cache.ctx, wtype_alloc, n_layer * layer_ne_v + NE_ALIGNMENT, NE_SIZE_CALC); - const auto v_align_off = reinterpret_cast(cache.v->data) % NE_ALIGNMENT; - cache.v = ne_view_1d(cache.ctx, cache.v, n_layer * layer_ne_v, NE_ALIGNMENT - v_align_off); - cache.v->type = wtype; - ne_set_name(cache.k, "cache_k"); - ne_set_name(cache.v, "cache_v"); + if (n_cpu_layer) { + cache.k = ne_new_tensor_1d(cache.ctx, wtype_alloc, n_cpu_layer * layer_ne_k + NE_ALIGNMENT, NE_SIZE_CALC, + NE_BACKEND_CPU); + const auto k_align_off = reinterpret_cast(cache.k->data) % NE_ALIGNMENT; + cache.k = ne_view_1d(cache.ctx, cache.k, n_layer * layer_ne_k, NE_ALIGNMENT - k_align_off); + cache.k->type = wtype; + cache.v = ne_new_tensor_1d(cache.ctx, wtype_alloc, n_cpu_layer * layer_ne_v + NE_ALIGNMENT, NE_SIZE_CALC, + NE_BACKEND_CPU); + const auto v_align_off = reinterpret_cast(cache.v->data) % NE_ALIGNMENT; + cache.v = ne_view_1d(cache.ctx, cache.v, n_layer * layer_ne_v, NE_ALIGNMENT - v_align_off); + cache.v->type = wtype; + ne_set_name(cache.k, "cache_k"); + ne_set_name(cache.v, "cache_v"); + } + if (cache.n_gpu_layer) { + cache.k_d = ne_new_tensor_1d(cache.ctx, wtype_alloc, cache.n_gpu_layer * layer_ne_k + NE_ALIGNMENT, NE_SIZE_CALC, + NE_BACKEND_SYCL); + cache.k_d->type = wtype; + cache.v_d = ne_new_tensor_1d(cache.ctx, wtype_alloc, cache.n_gpu_layer * layer_ne_v + NE_ALIGNMENT, NE_SIZE_CALC, + NE_BACKEND_SYCL); + cache.v_d->type = wtype; + ne_set_name(cache.k_d, "cache_k_dev"); + ne_set_name(cache.v_d, "cache_v_dev"); + } } if (shift_roped_k) { // prepare rope helper for fused-attention const auto cossin_dtype = wtype == NE_TYPE_BTLA ? 
NE_TYPE_F16 : wtype; - cache.cossin = ne_new_tensor_1d(cache.ctx, cossin_dtype, head_size, NE_SIZE_CALC); + cache.cossin = ne_new_tensor_1d(cache.ctx, cossin_dtype, head_size, NE_SIZE_CALC, NE_BACKEND_CPU); ne_set_name(cache.cossin, "cossin(-1)"); float freq_base = hparams.freq_base; float theta = -1 * hparams.freq_scale; @@ -169,6 +193,55 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c return true; } +// non-null pointer of model for kv-cache as components of model->layers[il] (e.g. chatglm) +static bool kv_cache_device_init(const struct model_hparams& hparams, struct model_kv_cache& cache, // NOLINT + const ne_type _wtype, const int n_ctx, const int batch_size, const int beam_size, + const bool shift_roped_k, model_struct* model, void* device_queue) { + const auto n_layer = hparams.n_layer; + auto heads_kv = hparams.n_head_kv > 0 ? hparams.n_head_kv : hparams.n_head; + const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k; + + int32_t k_size, v_size; + auto wtype = NE_TYPE_F32; + assert(model == NULL); + assert(!shift_roped_k); + + get_batch_kv_elements_from_gpt_params(heads_kv, head_size, n_ctx, wtype, &k_size, &v_size); + + int64_t layer_ne_k = batch_size * beam_size * k_size; + int64_t layer_ne_v = batch_size * beam_size * v_size; + const auto wsize = ne_type_size(wtype); + cache.device_size = n_layer * (layer_ne_k + layer_ne_v) * wsize; + cache.device_size = (cache.device_size + 255) / 256 * 256; +#ifdef NS_SYCL + cache.device_buf = bestla_device_malloc(cache.device_size, device_queue); +#else + cache.device_buf = nullptr; +#endif + cache.seq_cells.resize(batch_size * beam_size); + for (int i = 0; i < cache.seq_cells.size(); ++i) { + cache.seq_cells[i].token_cells.resize(n_ctx); + } + + struct ne_init_params params; + params.mem_size = 2u * MB; + params.mem_buffer = nullptr; + params.no_alloc = false; + + cache.ctx = ne_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + // NE_TYPE_BTLA can not be allocated memory + cache.k = ne_new_tensor_1d(cache.ctx, wtype, n_layer * layer_ne_k, NE_SIZE_CALC, NE_BACKEND_SYCL); + cache.v = ne_new_tensor_1d(cache.ctx, wtype, n_layer * layer_ne_v, NE_SIZE_CALC, NE_BACKEND_SYCL); + ne_set_name(cache.k, "cache_k"); + ne_set_name(cache.v, "cache_v"); + return true; +} + struct model_context_params model_context_default_params() { struct model_context_params result = { /*.arch =*/MODEL_LLAMA, @@ -213,6 +286,56 @@ void model_init_backend() { } } +ne_sycl_context* model_init_sycl(bool profile) { +#ifdef NS_SYCL + auto ctx = new ne_sycl_context; + auto dev = bestla_create_device(profile); + NE_ASSERT(dev != NULL); + auto queue = bestla_get_device_queue(dev); + memset(ctx->buffers, 0, sizeof(ctx->buffers)); + memset(ctx->sizes, 0, sizeof(ctx->sizes)); + memset(ctx->offs, 0, sizeof(ctx->offs)); + ctx->dev = dev; + ctx->queue = queue; + ctx->n_buffers = 0; + return ctx; +#else + return nullptr; +#endif +} + +void model_alloc_sycl_mem(ne_sycl_context* ctx, size_t size) { +#ifdef NS_SYCL + if (ctx && size) { + auto gsize = bestla_device_gmem_size(ctx->dev); + NE_ASSERT(gsize >= size); + int num_buffers = (size + MAX_SYCL_BUFFER_SIZE - 1) / MAX_SYCL_BUFFER_SIZE; + NE_ASSERT(num_buffers > 0 && num_buffers + ctx->n_buffers < MAX_SYCL_BUFFER_COUNT); + for (size_t i = 0; i < num_buffers; i++) { + size_t size_to_alloc = size > MAX_SYCL_BUFFER_SIZE ? 
MAX_SYCL_BUFFER_SIZE : size; + ctx->buffers[i + ctx->n_buffers] = bestla_device_malloc(size_to_alloc, ctx->queue); + ctx->sizes[i + ctx->n_buffers] = size_to_alloc; + NE_ASSERT(ctx->buffers[i + ctx->n_buffers]); + size -= MAX_SYCL_BUFFER_SIZE; + } + ctx->n_buffers += num_buffers; + } + +#endif +} + +void model_release_sycl(ne_sycl_context* ctx) { +#ifdef NS_SYCL + if (ctx) { + for (size_t i = 0; i < ctx->n_buffers; i++) { + bestla_device_free(ctx->buffers[i], ctx->queue); + } + bestla_release_device(ctx->dev); + delete ctx; + } +#endif +} + int64_t model_time_us() { return ne_time_us(); } // @@ -875,6 +998,7 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ ne_time_init(); model_context* ctx = new model_context; + ctx->dev_ctx = params.dev_ctx; if (params.seed < 0) { params.seed = time(nullptr); @@ -913,7 +1037,6 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ } ctx->cont_batching = params.cont_batching; ctx->generation_conf = params.gen_conf; - ctx->scratch_size_ratio = params.scratch_size_ratio * params.max_request_num * params.beam_size; const model_archs arch = params.arch; @@ -946,17 +1069,18 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ const bool support_bestla_kv = ctx->support_bestla_kv && bestla_reordered_attn_fp32_support(&attn_shape); fprintf(stderr, "%s: support_bestla_kv = %d\n", __func__, support_bestla_kv); - const ne_type memory_type = params.kv_type == KV_MEM_TYPE_F16 ? NE_TYPE_F16 - : params.kv_type == KV_MEM_TYPE_F32 ? NE_TYPE_F32 - : params.kv_type == KV_MEM_TYPE_AUTO - ? (support_bestla_kv ? NE_TYPE_BTLA : NE_TYPE_F16) // fall back to fp16 - : NE_TYPE_COUNT; - NE_ASSERT(memory_type != NE_TYPE_COUNT); + ne_type memory_type = params.kv_type == KV_MEM_TYPE_F16 ? NE_TYPE_F16 + : params.kv_type == KV_MEM_TYPE_F32 ? NE_TYPE_F32 + : params.kv_type == KV_MEM_TYPE_AUTO + ? (support_bestla_kv ? NE_TYPE_BTLA : NE_TYPE_F16) // fall back to fp16 + : NE_TYPE_COUNT; + NE_ASSERT(memory_type != NE_TYPE_COUNT); const bool kv_in_layers = (arch == MODEL_CHATGLM3 || arch == MODEL_CHATGLM2 || arch == MODEL_CHATGLM || arch == MODEL_BAICHUAN); + ctx->model.kv_self.n_gpu_layer = params.n_gpu_layers; if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->n_ctx, ctx->max_request_num, - ctx->beam_size, params.shift_roped_k, (kv_in_layers ? &ctx->model : nullptr))) { + ctx->beam_size, params.shift_roped_k, (kv_in_layers ? &ctx->model : nullptr), ctx->dev_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); model_free(ctx); return nullptr; @@ -966,16 +1090,22 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ const size_t memory_size = params.kv_type == KV_MEM_TYPE_AUTO ? ne_nelements(ctx->model.kv_self.k) + ne_nelements(ctx->model.kv_self.v) : ne_nbytes(ctx->model.kv_self.k) + ne_nbytes(ctx->model.kv_self.v); - fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); - } else if (ctx->model.layers[0].k_cache != nullptr) { + fprintf(stderr, "%s: cpu kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + if (ctx->model.kv_self.k_d != nullptr) { + const size_t g_memory_size = params.kv_type == KV_MEM_TYPE_AUTO + ? 
ne_nelements(ctx->model.kv_self.k_d) + ne_nelements(ctx->model.kv_self.v_d) + : ne_nbytes(ctx->model.kv_self.k_d) + ne_nbytes(ctx->model.kv_self.v_d); + fprintf(stderr, "%s: gpu kv self size = %7.2f MB\n", __func__, g_memory_size / 1024.0 / 1024.0); + } + if (ctx->model.layers[0].k_cache != nullptr) { const auto k_cache = ctx->model.layers[0].k_cache; const auto v_cache = ctx->model.layers[0].v_cache; const size_t layer_memory_size = params.kv_type == KV_MEM_TYPE_AUTO ? ne_nelements(k_cache) + ne_nelements(v_cache) : ne_nbytes(k_cache) + ne_nbytes(v_cache); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, layer_memory_size / 1024.0 / 1024.0 * hparams.n_layer); - } else { - NE_ASSERT(("KV-cache not allocated!", false)); } ctx->model.hparams.mha_prefer_f32 = params.mha_prefer_f32; @@ -989,11 +1119,22 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ if (params.embedding) { ctx->embedding.resize(hparams.n_embd); } - + size_t act_mem_per_layer = ctx->batch_size * ctx->beam_size * + (32ULL * ctx->n_ctx * hparams.n_embd + 2ULL * ctx->n_ctx * hparams.ffn_hidden_size) * + ne_type_size(NE_TYPE_F32) + + (10 << 20); +#ifdef NS_SYCL + ctx->buf_compute.resize(act_mem_per_layer); + ctx->buf_scratch[0].resize(act_mem_per_layer); + fprintf(stderr, "%s: cpu activation size = %7.2f MB\n", __func__, act_mem_per_layer / 1024.0 / 1024.0); + model_alloc_sycl_mem(ctx->dev_ctx, act_mem_per_layer); + fprintf(stderr, "%s: gpu activation size = %7.2f MB\n", __func__, act_mem_per_layer / 1024.0 / 1024.0); +#else ctx->buf_compute.resize(ctx->model.scratchs.eval); ctx->buf_scratch[0].resize(ctx->model.scratchs.scratch0); ctx->buf_scratch[1].resize(ctx->model.scratchs.scratch1); +#endif } return ctx; @@ -1147,7 +1288,7 @@ int model_apply_lora_from_file_internal(struct model_context* ctx, const char* p } ne_tensor* lora_tensor; if (n_dims == 2) { - lora_tensor = ne_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1], NE_SIZE_CALC); + lora_tensor = ne_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1], NE_SIZE_CALC, NE_BACKEND_CPU); } else { fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims); return 1; @@ -1261,7 +1402,7 @@ int model_apply_lora_from_file(struct model_context* ctx, const char* path_lora, } } -struct model_context* model_init_from_gpt_params(const gpt_params& params) { +struct model_context* model_init_from_gpt_params(const gpt_params& params, ne_sycl_context* dev_ctx) { if (params.model_arch == MODEL_UNKNOWN) { fprintf(stderr, "error, please set model_name \n"); exit(0); @@ -1275,6 +1416,10 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) { lparams.seed = params.seed; lparams.kv_type = params.memory_type; lparams.mha_prefer_f32 = params.mha_prefer_f32; + lparams.dev_ctx = dev_ctx; + if (dev_ctx) { + lparams.kv_type = KV_MEM_TYPE_F32; + } // TODO(Yi): MHA FOR LONG TOKENS int32_t long_tokens = 6144; @@ -1325,8 +1470,10 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) { /* .sl_q = */ 1, // Note: make sure that bestla reordered attn supports next token inferencing /* .sl_kv = */ static_cast(lparams.n_ctx), }; - const auto k_cache_example = lctx->model.kv_self.k != nullptr ? lctx->model.kv_self.k // llama.cpp style - : lctx->model.layers[0].k_cache; // chatglm style + auto k_cache_example = lctx->model.kv_self.k != nullptr ? lctx->model.kv_self.k // llama.cpp style + : lctx->model.layers[0].k_cache; // chatglm style + k_cache_example = k_cache_example == nullptr ? 
lctx->model.kv_self.k_d // llama.cpp style + : k_cache_example; // chatglm style NE_ASSERT(k_cache_example->type != NE_TYPE_BTLA || bestla_reordered_attn_fp32_support(&attn_shape)); if (lctx == nullptr) { @@ -1476,11 +1623,13 @@ size_t model_copy_state_data(struct model_context* ctx, uint8_t* dst) { ne_cgraph gf{}; gf.n_threads = 1; - ne_tensor* kout3d = ne_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer, NE_SIZE_CALC); + ne_tensor* kout3d = + ne_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer, NE_SIZE_CALC, NE_BACKEND_CPU); kout3d->data = out; out += ne_nbytes(kout3d); - ne_tensor* vout3d = ne_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer, NE_SIZE_CALC); + ne_tensor* vout3d = + ne_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer, NE_SIZE_CALC, NE_BACKEND_CPU); vout3d->data = out; out += ne_nbytes(vout3d); @@ -1589,11 +1738,13 @@ size_t model_set_state_data(struct model_context* ctx, uint8_t* src) { ne_cgraph gf{}; gf.n_threads = 1; - ne_tensor* kin3d = ne_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer, NE_SIZE_CALC); + ne_tensor* kin3d = + ne_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer, NE_SIZE_CALC, NE_BACKEND_CPU); kin3d->data = reinterpret_cast(inp); inp += ne_nbytes(kin3d); - ne_tensor* vin3d = ne_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer, NE_SIZE_CALC); + ne_tensor* vin3d = + ne_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer, NE_SIZE_CALC, NE_BACKEND_CPU); vin3d->data = reinterpret_cast(inp); inp += ne_nbytes(vin3d); @@ -1943,7 +2094,7 @@ static ne_tensor* ne_model_kv_cache_seq_concat(struct ne_cgraph* cgraph, struct continue; } else { if (dst == nullptr) { - dst = ne_new_tensor_4d(nectx, cache->type, ne0, ne1, ne2, ne3, NE_SIZE_CALC); + dst = ne_new_tensor_4d(nectx, cache->type, ne0, ne1, ne2, ne3, NE_SIZE_CALC, NE_BACKEND_CPU); } struct ne_tensor* dst_i = ne_view_4d(nectx, dst, ne0, ne1, ne2, cont_bs, elem_size * ne0, elem_size * ne0 * ne1, elem_size * ne0 * ne1 * ne2, dst_off); diff --git a/neural_speed/models/model_utils/model_utils.h b/neural_speed/models/model_utils/model_utils.h index e893553dc..ecf4a3b0d 100644 --- a/neural_speed/models/model_utils/model_utils.h +++ b/neural_speed/models/model_utils/model_utils.h @@ -79,6 +79,10 @@ MODEL_API bool model_mlock_supported(); // Call once at the start of the program MODEL_API void model_init_backend(); +MODEL_API ne_sycl_context* model_init_sycl(bool profile); +MODEL_API void model_alloc_sycl_mem(ne_sycl_context*, size_t size); +MODEL_API void model_release_sycl(ne_sycl_context* ctx); + MODEL_API int64_t model_time_us(); // Various functions for loading a ne model model. 
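Taken together, model_init_sycl, model_alloc_sycl_mem and model_release_sycl give the application a single ne_sycl_context that owns the device queue plus every device buffer acquired during weight loading, KV-cache setup and activation planning, and model_init_from_gpt_params now threads that context down into model_init_from_file. A minimal caller-side sketch of the intended flow, assuming gpt_params is populated the same way as in the existing CPU-only examples (the include set, the run_with_sycl helper name and the error handling below are illustrative, not part of this patch):

    #include "models/model_utils/model_config.h"  // model_init_from_gpt_params(params, dev_ctx)
    #include "models/model_utils/model_utils.h"   // model_init_sycl / model_release_sycl

    // Hypothetical driver, not part of this PR: shows how the new SYCL context API is wired up.
    static int run_with_sycl(gpt_params& params) {
      // Returns nullptr when the binary is built without NS_SYCL; the loader then keeps
      // the CPU-only behaviour it had before this change.
      ne_sycl_context* dev_ctx = model_init_sycl(/*profile=*/false);
      struct model_context* ctx = model_init_from_gpt_params(params, dev_ctx);
      if (ctx == nullptr) {
        model_release_sycl(dev_ctx);
        return 1;
      }
      // ... model_eval / sampling loop exactly as in the existing examples ...
      model_free(ctx);
      model_release_sycl(dev_ctx);  // frees every buffer handed out via model_alloc_sycl_mem
      return 0;
    }

Note that a non-null dev_ctx also forces lparams.kv_type to KV_MEM_TYPE_F32 in model_init_from_gpt_params, which is consistent with the NE_ASSERT(wtype != NE_TYPE_BTLA) guard added to kv_cache_init for GPU layers.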
diff --git a/neural_speed/models/model_utils/scheduler.cpp b/neural_speed/models/model_utils/scheduler.cpp index dc6c2459f..674509139 100644 --- a/neural_speed/models/model_utils/scheduler.cpp +++ b/neural_speed/models/model_utils/scheduler.cpp @@ -15,7 +15,7 @@ #include "models/model_utils/scheduler.h" // Iter_level_worker -Iter_level_worker::Iter_level_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_params(params)) { +Iter_level_worker::Iter_level_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_params(params, nullptr)) { if (m_ctx == nullptr) { fprintf(stderr, "%s: error: unable to load model.\n", __func__); exit(0); diff --git a/neural_speed/models/mpt/mpt.cpp b/neural_speed/models/mpt/mpt.cpp index 2731ce33e..7de22894a 100644 --- a/neural_speed/models/mpt/mpt.cpp +++ b/neural_speed/models/mpt/mpt.cpp @@ -163,8 +163,7 @@ static bool mpt_model_eval_internal(model_context* ctx, const model_input* input // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, // 2, 1, 3) [64, N, 12] struct ne_tensor* Q = ne_permute( - ctx0, ne_cpy(ctx0, Qcur, ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_embd / n_head, n_head, N, NE_SIZE_CALC)), 0, 2, - 1, 3); + ctx0, ne_cpy(ctx0, Qcur, d_ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_embd / n_head, n_head, N)), 0, 2, 1, 3); // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, // 3) [64, n_past + N, 12] @@ -201,7 +200,7 @@ static bool mpt_model_eval_internal(model_context* ctx, const model_input* input struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N)); } else { const auto seq_kv = n_past + N; const auto k_size = kv_cache_info.k_bytes; diff --git a/neural_speed/models/opt/opt.cpp b/neural_speed/models/opt/opt.cpp index 64d7980f4..8bfbc5d95 100644 --- a/neural_speed/models/opt/opt.cpp +++ b/neural_speed/models/opt/opt.cpp @@ -224,7 +224,7 @@ static bool opt_model_eval_internal(model_context* ctx, const model_input* input struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // [n_embd, N] - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N)); } // attn out projection diff --git a/neural_speed/models/phi/phi.cpp b/neural_speed/models/phi/phi.cpp index a177bc11b..89a2b6b6d 100644 --- a/neural_speed/models/phi/phi.cpp +++ b/neural_speed/models/phi/phi.cpp @@ -256,7 +256,7 @@ static bool phi2_model_eval_internal(model_context* ctx, const model_input* inpu struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC)); + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size)); } else { const auto seq_kv = n_past + N; const auto k_size = kv_cache_info.k_bytes; diff --git a/neural_speed/models/phi/phi3.cpp b/neural_speed/models/phi/phi3.cpp index 0f3219933..91d0a904d 100644 --- a/neural_speed/models/phi/phi3.cpp +++ b/neural_speed/models/phi/phi3.cpp @@ -117,7 +117,7 @@ static bool phi3_model_eval_internal(model_context* ctx, const model_input* inpu bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info); } struct ne_tensor* embd = 
d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size);
-  struct ne_tensor* factor = ne_new_tensor_1d(ctx0, NE_TYPE_F32, 48, sizeof(float));
+  struct ne_tensor* factor = d_ne_new_tensor_1d(ctx0, NE_TYPE_F32, 48);
   const float longfactor[48] = {
       1.0299999713897705, 1.0499999523162842, 1.0499999523162842, 1.0799999237060547, 1.2299998998641968,
       1.2299998998641968, 1.2999999523162842, 1.4499999284744263, 1.5999999046325684, 1.6499998569488525,
@@ -265,7 +265,7 @@ static bool phi3_model_eval_internal(model_context* ctx, const model_input* inpu
     struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

     // cur = KQV_merged.contiguous().view(n_embd, N)
-    cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));
+    cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size));
   } else {
     const auto seq_kv = n_past + N;
     const auto k_size = kv_cache_info.k_bytes;
diff --git a/neural_speed/models/qwen/qwen.cpp b/neural_speed/models/qwen/qwen.cpp
index f70365997..575ae06ff 100644
--- a/neural_speed/models/qwen/qwen.cpp
+++ b/neural_speed/models/qwen/qwen.cpp
@@ -282,7 +282,7 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
     struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

     // cur = KQV_merged.contiguous().view(n_embd, N)
-    cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));
+    cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size));
   } else {
     const auto seq_kv = n_past + N;
     const auto k_size = kv_cache_info.k_bytes;
diff --git a/neural_speed/models/stablelm/stablelm.cpp b/neural_speed/models/stablelm/stablelm.cpp
index 20c843b37..6a4dc0713 100644
--- a/neural_speed/models/stablelm/stablelm.cpp
+++ b/neural_speed/models/stablelm/stablelm.cpp
@@ -259,7 +259,7 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input*
     struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

     // cur = KQV_merged.contiguous().view(n_embd, N)
-    cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));
+    cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size));
   } else {
     const auto seq_kv = n_past + N;
     const auto k_size = kv_cache_info.k_bytes;
@@ -435,4 +435,4 @@ int model_eval(struct model_context* ctx, const model_input* inputs, const int n
   }

   return 0;
-}
\ No newline at end of file
+}
diff --git a/neural_speed/models/starcoder/starcoder.cpp b/neural_speed/models/starcoder/starcoder.cpp
index 113d762ec..23d0b6d35 100644
--- a/neural_speed/models/starcoder/starcoder.cpp
+++ b/neural_speed/models/starcoder/starcoder.cpp
@@ -248,7 +248,7 @@ static bool starcoder_model_eval_internal(model_context* ctx, const model_input*

       // cur = KQV_merged.contiguous().view(n_embd, N)
       // [768, N]
-      cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC));
+      cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N));
     } else {
       const auto seq_kv = n_past + N;
       const auto k_size = kv_cache_info.k_bytes;
diff --git a/neural_speed/models/whisper/whisper.cpp b/neural_speed/models/whisper/whisper.cpp
index 1865094ac..52a353253 100644
--- a/neural_speed/models/whisper/whisper.cpp
+++ b/neural_speed/models/whisper/whisper.cpp
@@ -565,8 +565,8 @@ static bool kv_cache_init(const struct whisper_hparams_t& hparams, const size_t
   const int n_mem = n_text_layer * n_ctx;
   const int n_elements = n_text_state * n_mem;

-  cache->k = ne_new_tensor_1d(cache->ctx, wtype, n_elements, NE_SIZE_CALC);
-  cache->v = ne_new_tensor_1d(cache->ctx, wtype, n_elements, NE_SIZE_CALC);
+  cache->k = d_ne_new_tensor_1d(cache->ctx, wtype, n_elements);
+  cache->v = d_ne_new_tensor_1d(cache->ctx, wtype, n_elements);

   return true;
 }
@@ -595,8 +595,8 @@ static bool kv_cache_reinit(struct whisper_kv_cache_t* cache) {
     return false;
   }

-  cache->k = ne_new_tensor_1d(cache->ctx, wtype, n_elements, NE_SIZE_CALC);
-  cache->v = ne_new_tensor_1d(cache->ctx, wtype, n_elements, NE_SIZE_CALC);
+  cache->k = d_ne_new_tensor_1d(cache->ctx, wtype, n_elements);
+  cache->v = d_ne_new_tensor_1d(cache->ctx, wtype, n_elements);

   return true;
 }
@@ -965,16 +965,16 @@ static bool whisper_model_load(struct whisper_model_loader* loader, whisper_cont
     // encoder
     {
-      model.e_pe = ne_new_tensor_2d(ctx, NE_TYPE_F32, n_audio_state, n_audio_ctx, NE_SIZE_CALC);
+      model.e_pe = d_ne_new_tensor_2d(ctx, NE_TYPE_F32, n_audio_state, n_audio_ctx);

-      model.e_conv_1_w = ne_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state, NE_SIZE_CALC);
-      model.e_conv_1_b = ne_new_tensor_2d(ctx, NE_TYPE_F32, 1, n_audio_state, NE_SIZE_CALC);
+      model.e_conv_1_w = d_ne_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
+      model.e_conv_1_b = d_ne_new_tensor_2d(ctx, NE_TYPE_F32, 1, n_audio_state);

-      model.e_conv_2_w = ne_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state, NE_SIZE_CALC);
-      model.e_conv_2_b = ne_new_tensor_2d(ctx, NE_TYPE_F32, 1, n_audio_state, NE_SIZE_CALC);
+      model.e_conv_2_w = d_ne_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
+      model.e_conv_2_b = d_ne_new_tensor_2d(ctx, NE_TYPE_F32, 1, n_audio_state);

-      model.e_ln_w = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
-      model.e_ln_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
+      model.e_ln_w = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);
+      model.e_ln_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);

       // map by name
       model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -991,28 +991,28 @@ static bool whisper_model_load(struct whisper_model_loader* loader, whisper_cont
       for (int i = 0; i < n_audio_layer; ++i) {
         auto& layer = model.layers_encoder[i];

-        layer.mlp_ln_w = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
-        layer.mlp_ln_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
+        layer.mlp_ln_w = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);
+        layer.mlp_ln_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);

-        layer.mlp_0_w = ne_new_tensor_2d(ctx, wtype, n_audio_state, 4 * n_audio_state, NE_SIZE_CALC);
-        layer.mlp_0_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, 4 * n_audio_state, NE_SIZE_CALC);
+        layer.mlp_0_w = d_ne_new_tensor_2d(ctx, wtype, n_audio_state, 4 * n_audio_state);
+        layer.mlp_0_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, 4 * n_audio_state);

-        layer.mlp_1_w = ne_new_tensor_2d(ctx, wtype, 4 * n_audio_state, n_audio_state, NE_SIZE_CALC);
-        layer.mlp_1_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
+        layer.mlp_1_w = d_ne_new_tensor_2d(ctx, wtype, 4 * n_audio_state, n_audio_state);
+        layer.mlp_1_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);

-        layer.attn_ln_0_w = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
-        layer.attn_ln_0_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
+        layer.attn_ln_0_w = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);
+        layer.attn_ln_0_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);

-        layer.attn_q_w = ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state, NE_SIZE_CALC);
-        layer.attn_q_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
+        layer.attn_q_w = d_ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+        layer.attn_q_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);

-        layer.attn_k_w = ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state, NE_SIZE_CALC);
+        layer.attn_k_w = d_ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);

-        layer.attn_v_w = ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state, NE_SIZE_CALC);
-        layer.attn_v_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
+        layer.attn_v_w = d_ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+        layer.attn_v_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);

-        layer.attn_ln_1_w = ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state, NE_SIZE_CALC);
-        layer.attn_ln_1_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state, NE_SIZE_CALC);
+        layer.attn_ln_1_w = d_ne_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+        layer.attn_ln_1_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_audio_state);

         // map by name
         model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
@@ -1042,12 +1042,12 @@ static bool whisper_model_load(struct whisper_model_loader* loader, whisper_cont
     // decoder
     {
-      model.d_pe = ne_new_tensor_2d(ctx, NE_TYPE_F32, n_text_state, n_text_ctx, NE_SIZE_CALC);
+      model.d_pe = d_ne_new_tensor_2d(ctx, NE_TYPE_F32, n_text_state, n_text_ctx);

-      model.d_te = ne_new_tensor_2d(ctx, wtype, n_text_state, n_vocab, NE_SIZE_CALC);
+      model.d_te = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);

-      model.d_ln_w = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
-      model.d_ln_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+      model.d_ln_w = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);
+      model.d_ln_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

       // map by name
       model.tensors["decoder.positional_embedding"] = model.d_pe;
@@ -1060,42 +1060,42 @@ static bool whisper_model_load(struct whisper_model_loader* loader, whisper_cont
       for (int i = 0; i < n_text_layer; ++i) {
         auto& layer = model.layers_decoder[i];

-        layer.mlp_ln_w = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
-        layer.mlp_ln_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.mlp_ln_w = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);
+        layer.mlp_ln_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.mlp_0_w = ne_new_tensor_2d(ctx, wtype, n_text_state, 4 * n_text_state, NE_SIZE_CALC);
-        layer.mlp_0_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, 4 * n_text_state, NE_SIZE_CALC);
+        layer.mlp_0_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, 4 * n_text_state);
+        layer.mlp_0_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, 4 * n_text_state);

-        layer.mlp_1_w = ne_new_tensor_2d(ctx, wtype, 4 * n_text_state, n_text_state, NE_SIZE_CALC);
-        layer.mlp_1_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.mlp_1_w = d_ne_new_tensor_2d(ctx, wtype, 4 * n_text_state, n_text_state);
+        layer.mlp_1_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.attn_ln_0_w = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
-        layer.attn_ln_0_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.attn_ln_0_w = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);
+        layer.attn_ln_0_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.attn_q_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
-        layer.attn_q_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.attn_q_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+        layer.attn_q_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.attn_k_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
+        layer.attn_k_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);

-        layer.attn_v_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
-        layer.attn_v_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.attn_v_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+        layer.attn_v_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.attn_ln_1_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
-        layer.attn_ln_1_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.attn_ln_1_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+        layer.attn_ln_1_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.cross_attn_ln_0_w = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
-        layer.cross_attn_ln_0_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.cross_attn_ln_0_w = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);
+        layer.cross_attn_ln_0_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.cross_attn_q_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
-        layer.cross_attn_q_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.cross_attn_q_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+        layer.cross_attn_q_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.cross_attn_k_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
+        layer.cross_attn_k_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);

-        layer.cross_attn_v_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
-        layer.cross_attn_v_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.cross_attn_v_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+        layer.cross_attn_v_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

-        layer.cross_attn_ln_1_w = ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state, NE_SIZE_CALC);
-        layer.cross_attn_ln_1_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state, NE_SIZE_CALC);
+        layer.cross_attn_ln_1_w = d_ne_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+        layer.cross_attn_ln_1_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_text_state);

         // map by name
         model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
@@ -1261,7 +1261,7 @@ static bool whisper_encode_internal(whisper_context* wctx, whisper_state* wstate
   wstate->use_buf(ctx0, 0);

-  struct ne_tensor* mel = ne_new_tensor_2d(ctx0, NE_TYPE_F32, 2 * n_ctx, n_mels, NE_SIZE_CALC);
+  struct ne_tensor* mel = d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, 2 * n_ctx, n_mels);
   assert(mel->type == NE_TYPE_F32);
   {
     float* dst = reinterpret_cast<float*>(mel->data);
@@ -1368,29 +1368,27 @@ static bool whisper_encode_internal(whisper_context* wctx, whisper_state* wstate
      wstate->use_buf(ctx0, 0);

 #ifdef WHISPER_USE_FLASH_ATTN
-      struct ne_tensor* Q = ne_permute(
-          ctx0, ne_cpy(ctx0, Qcur, ne_new_tensor_3d(ctx0, wctx.itype, n_state / n_head, n_head, n_ctx, NE_SIZE_CALC)),
-          0, 2, 1, 3);
+      struct ne_tensor* Q =
+          ne_permute(ctx0, ne_cpy(ctx0, Qcur, d_ne_new_tensor_3d(ctx0, wctx.itype, n_state / n_head, n_head, n_ctx)),
+                     0, 2, 1, 3);

-      struct ne_tensor* K = ne_permute(
-          ctx0, ne_cpy(ctx0, Kcur, ne_new_tensor_3d(ctx0, wctx.itype, n_state / n_head, n_head, n_ctx, NE_SIZE_CALC)),
-          0, 2, 1, 3);
+      struct ne_tensor* K =
+          ne_permute(ctx0, ne_cpy(ctx0, Kcur, d_ne_new_tensor_3d(ctx0, wctx.itype, n_state / n_head, n_head, n_ctx)),
+                     0, 2, 1, 3);

       struct ne_tensor* V =
           ne_cpy(ctx0, ne_permute(ctx0, ne_reshape_3d(ctx0, Vcur, n_state / n_head, n_head, n_ctx), 1, 2, 0, 3),
-                 ne_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state / n_head, n_head, NE_SIZE_CALC));
+                 d_ne_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state / n_head, n_head));

       struct ne_tensor* KQV = ggml_flash_attn(ctx0, Q, K, V, false);
 #else
-      struct ne_tensor* Q = ne_permute(
-          ctx0,
-          ne_cpy(ctx0, Qcur, ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_state / n_head, n_head, n_ctx, NE_SIZE_CALC)), 0,
-          2, 1, 3);
+      struct ne_tensor* Q =
+          ne_permute(ctx0, ne_cpy(ctx0, Qcur, d_ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_state / n_head, n_head, n_ctx)),
+                     0, 2, 1, 3);

-      struct ne_tensor* K = ne_permute(
-          ctx0,
-          ne_cpy(ctx0, Kcur, ne_new_tensor_3d(ctx0, wctx->itype, n_state / n_head, n_head, n_ctx, NE_SIZE_CALC)), 0,
-          2, 1, 3);
+      struct ne_tensor* K =
+          ne_permute(ctx0, ne_cpy(ctx0, Kcur, d_ne_new_tensor_3d(ctx0, wctx->itype, n_state / n_head, n_head, n_ctx)),
+                     0, 2, 1, 3);

       // K * Q
       struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);
@@ -1402,7 +1400,7 @@ static bool whisper_encode_internal(whisper_context* wctx, whisper_state* wstate
       struct ne_tensor* V =
           ne_cpy(ctx0, ne_permute(ctx0, ne_reshape_3d(ctx0, Vcur, n_state / n_head, n_head, n_ctx), 1, 2, 0, 3),
-                 ne_new_tensor_3d(ctx0, wctx->itype, n_ctx, n_state / n_head, n_head, NE_SIZE_CALC));
+                 d_ne_new_tensor_3d(ctx0, wctx->itype, n_ctx, n_state / n_head, n_head));

       struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
 #endif
@@ -1410,7 +1408,7 @@ static bool whisper_encode_internal(whisper_context* wctx, whisper_state* wstate
       wstate->use_buf(ctx0, 1);

-      cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_state, n_ctx, NE_SIZE_CALC));
+      cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_state, n_ctx));
     }

     // projection
@@ -1449,7 +1447,7 @@ static bool whisper_encode_internal(whisper_context* wctx, whisper_state* wstate
 #ifdef WHISPER_USE_FLASH_FF
       wstate.use_buf(ctx0, 0);

-      cur = ggml_flash_ff(ctx0, ne_cpy(ctx0, cur, ne_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx, NE_SIZE_CALC)),
+      cur = ggml_flash_ff(ctx0, ne_cpy(ctx0, cur, d_ne_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
                           layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
       wstate->use_buf(ctx0, 0);
@@ -1630,10 +1628,10 @@ static bool whisper_decode_internal(whisper_context* wctx, whisper_state* wstate
   struct ne_cgraph gf = {};
   gf.n_threads = n_threads;

-  struct ne_tensor* embd = ne_new_tensor_1d(ctx0, NE_TYPE_I32, N, NE_SIZE_CALC);
+  struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
   memcpy(embd->data, tokens, N * ne_element_size(embd));

-  struct ne_tensor* position = ne_new_tensor_1d(ctx0, NE_TYPE_I32, N, NE_SIZE_CALC);
+  struct ne_tensor* position = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
   for (int i = 0; i < N; ++i) {
     (reinterpret_cast<int32_t*>(position->data))[i] = n_past + i;
   }
@@ -1695,8 +1693,7 @@ static bool whisper_decode_internal(whisper_context* wctx, whisper_state* wstate
       wstate->use_buf(ctx0, 0);

       struct ne_tensor* Q = ne_permute(
-          ctx0, ne_cpy(ctx0, Qcur, ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_state / n_head, n_head, N, NE_SIZE_CALC)), 0,
-          2, 1, 3);
+          ctx0, ne_cpy(ctx0, Qcur, d_ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_state / n_head, n_head, N)), 0, 2, 1, 3);

       struct ne_tensor* K =
           ne_permute(ctx0,
                      ne_reshape_3d(ctx0,
@@ -1728,7 +1725,7 @@ static bool whisper_decode_internal(whisper_context* wctx, whisper_state* wstate

       struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

-      cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_state, N, NE_SIZE_CALC));
+      cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_state, N));
     }

     // projection
@@ -1781,7 +1778,7 @@ static bool whisper_decode_internal(whisper_context* wctx, whisper_state* wstate
       // struct ne_tensor * V_trans =
       //   ne_cpy(ctx0,
       //          ne_permute(ctx0, Vcross, 1, 2, 0, 3),
-      //          ne_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head,
+      //          d_ne_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head,
       //          n_head));

       struct ne_tensor* V =
@@ -1792,8 +1789,7 @@ static bool whisper_decode_internal(whisper_context* wctx, whisper_state* wstate
       // ------

       struct ne_tensor* Q = ne_permute(
-          ctx0, ne_cpy(ctx0, Qcur, ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_state / n_head, n_head, N, NE_SIZE_CALC)), 0,
-          2, 1, 3);
+          ctx0, ne_cpy(ctx0, Qcur, d_ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_state / n_head, n_head, N)), 0, 2, 1, 3);

       struct ne_tensor* K = ne_permute(ctx0, Kcross, 0, 2, 1, 3);
@@ -1817,7 +1813,7 @@ static bool whisper_decode_internal(whisper_context* wctx, whisper_state* wstate
       struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

       // cur = KQV_merged.contiguous().view(n_state, N)
-      cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_state, N, NE_SIZE_CALC));
+      cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_state, N));
     }

     // projection
diff --git a/neural_speed/vectors/cpu/quantize.h b/neural_speed/vectors/cpu/quantize.h
index 73100e423..6cac22324 100644
--- a/neural_speed/vectors/cpu/quantize.h
+++ b/neural_speed/vectors/cpu/quantize.h
@@ -100,7 +100,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
 static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
 #if __AVXVNNI__
   const __m256i zero = _mm256_setzero_si256();
-  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+  const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
   return _mm256_cvtepi32_ps(summed_pairs);
 #else
   // Perform multiplication and create 16-bit values
diff --git a/neural_speed/vectors/cpu/vec_arithmetic.cpp b/neural_speed/vectors/cpu/vec_arithmetic.cpp
index 1e47ef8b2..1e517dc70 100644
--- a/neural_speed/vectors/cpu/vec_arithmetic.cpp
+++ b/neural_speed/vectors/cpu/vec_arithmetic.cpp
@@ -15,7 +15,7 @@
 #include "vec_load.hpp"
#include "vec_store.hpp" #include "vec_arithmetic.hpp" -#include "cmath" +#include fp32x16 sub_fp32x16(fp32x16 x, fp32x16 y) { #if __AVX512F__ diff --git a/neural_speed/vectors/cpu/vec_base.hpp b/neural_speed/vectors/cpu/vec_base.hpp index 2051f9bb9..2e9709b4b 100644 --- a/neural_speed/vectors/cpu/vec_base.hpp +++ b/neural_speed/vectors/cpu/vec_base.hpp @@ -16,7 +16,7 @@ #define ENGINE_EXECUTOR_INCLUDE_VEC_BASE_HPP_ #include -#include +#include #if __AVX512F__ struct fp32x16 { diff --git a/neural_speed/vectors/ele_reduce.cpp b/neural_speed/vectors/ele_reduce.cpp index 655706d47..7a851a4cb 100644 --- a/neural_speed/vectors/ele_reduce.cpp +++ b/neural_speed/vectors/ele_reduce.cpp @@ -13,7 +13,7 @@ // limitations under the License. #include "vectors/cpu/vec.hpp" #include "vectors/ele_reduce.h" -#include "cmath" +#include void ne_vec_norm_f32_(const int n, float* s, const float* x) { float sum = 0.0;