Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
add sycl int4 to graph compute
Browse files Browse the repository at this point in the history
  • Loading branch information
ThanatosShinji committed Jun 1, 2024
1 parent dba3905 commit 9ccbe5a
Show file tree
Hide file tree
Showing 15 changed files with 198 additions and 50 deletions.
2 changes: 1 addition & 1 deletion bestla/bestla/sycl/sycl_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#ifdef BTLA_SYCL
#include <array>

#include "bestla_utils.h"
#include "bestla/bestla_utils.h"
#include <sycl/sycl.hpp>

namespace bestla {
Expand Down
2 changes: 1 addition & 1 deletion bestla/bestla/sycl/sycl_prologue_a.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#ifdef BTLA_SYCL
#include <array>

#include "bestla_utils.h"
#include "bestla/bestla_utils.h"
#include <sycl/sycl.hpp>

namespace bestla {
Expand Down
2 changes: 1 addition & 1 deletion bestla/bestla/sycl/sycl_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#ifdef BTLA_SYCL
#include <array>

#include "bestla_utils.h"
#include "bestla/bestla_utils.h"
#include <sycl/sycl.hpp>

namespace bestla {
Expand Down
4 changes: 2 additions & 2 deletions bestla/bestla/sycl/sycl_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ class StorageWeightKBlockNInteger {
mCStep = _hoststor.CStep();

if (_hoststor.template ZPtr<void>()) {
mZpSize = mCSize * utils::bestla_dtype_size(mZpT);
mZpSize = _hoststor.CSize() * utils::bestla_dtype_size(mZpT);
}
if (_hoststor.template RPtr<void>()) {
mRedSize = mCSize * utils::bestla_dtype_size(mRedT);
mRedSize = _hoststor.CSize() * utils::bestla_dtype_size(mRedT);
}
// TODO DQ,shuffle support
}
Expand Down
2 changes: 1 addition & 1 deletion bestla/bestla/sycl/sycl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "sycl_device.h"
#include "bestla_utils.h"
#include "bestla/bestla_utils.h"

namespace bestla {
namespace sycl_utils {
Expand Down
4 changes: 2 additions & 2 deletions bestla/bestla/sycl/sycl_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#ifdef BTLA_SYCL
#include <sycl/sycl.hpp>

#include "bestla_utils.h"
#include "bestla/bestla_utils.h"
#include "sycl_utils.h"
#include "sycl_device.h"
#include "sycl_gemm.h"
Expand Down Expand Up @@ -156,7 +156,7 @@ class LauncherWOQ {
[=](sycl::nd_item<2> it) [[cl::reqd_work_group_size(
1, GemmCore::WgM,
GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] {
nd_item_helper<GemmCore> helper(it);
sycl_utils::nd_item_helper<GemmCore> helper(it);
if constexpr (debug) {
compute_tile(k, blocksize, B, B_scale, ldb, slm_b, A, lda, C, ldc, it);
} else {
Expand Down
4 changes: 3 additions & 1 deletion neural_speed/application/main_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
#include "models/model_utils/model_config.h"
#include "models/model_utils/model_utils.h"

#include "core/ne_bestla.h"

#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
#include <signal.h>
#include <unistd.h>
Expand Down Expand Up @@ -134,7 +136,7 @@ int main(int argc, char** argv) { // NOLINT
}

model_init_backend();

bestla_set_threads(params.n_threads);
model_context* ctx;
g_ctx = &ctx;

Expand Down
1 change: 1 addition & 0 deletions neural_speed/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ else()
endif()

if(NS_SYCL)
target_compile_definitions(ne_layers PRIVATE BTLA_SYCL)
target_link_libraries(ne_layers PRIVATE IntelSYCL::SYCL_CXX)
endif()

Expand Down
78 changes: 78 additions & 0 deletions neural_speed/core/layers/ne_bestla.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ void bestla_add(int batch, int vsize, const float* tensor, const float* vector,

#ifdef NS_SYCL
#include "bestla/sycl/sycl_device.h"
#include "bestla/sycl/sycl_storage.h"
void* bestla_create_device(bool profile) {
auto ptr = new sycl_device::SyclDevice(profile);
ptr->print();
Expand Down Expand Up @@ -231,4 +232,81 @@ void bestla_device_sync(void* queue) {
ptr->wait();
}
}

size_t bestla_device_storage_size() { return sizeof(sycl_storage::StorageWeightKBlockNInteger); }

void bestla_device_load_storage(void* hoststor, void* devstor, void* deviceptr, void* device_queue) {
auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(const_cast<void*>(hoststor));
GetCPUDevice();
if (ptr && devstor && deviceptr) {
auto dstor = (sycl_storage::StorageWeightKBlockNInteger*)devstor;
if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
auto sptr = reinterpret_cast<storage::gemm::StorageWeightKBlockNInteger*>(ptr);
auto transtor = sptr->toTrans();
utils::avector<int8_t> buffer1(transtor.mSize);
transtor.assign(buffer1.data());
auto coretype = sptr->mCoreId;
auto NTile = gemm::CoreAttr::get_mask_val(sptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto PackRow = gemm::CoreAttr::get_packrow(sptr->mCoreId);
auto CType = gemm::CoreAttr::get_comp(sptr->mCoreId);
auto btype = static_cast<gemm::CompType>(gemm::CompTypeHelper::get_B(CType));
if (btype == gemm::CompType::tFP32 && PackRow == 1) {
if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
static prologue_b::gemm::WeightKBlockNInteger<tAVX512F, tAVX512F::ISA> proB;
proB.convertTransStorage(*sptr, transtor, ne_bestla::ne_threading::get());
} else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
static prologue_b::gemm::WeightKBlockNInteger<tAVX2, tAVX2::ISA> proB;
proB.convertTransStorage(*sptr, transtor, ne_bestla::ne_threading::get());
}
}
if (btype == gemm::CompType::tS8 && PackRow == 4) {
if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) {
static prologue_b::gemm::WeightKBlockNInteger<tAMX_INT8_SS_KBlock, tAMX_INT8_SS_KBlock::ISA> proB;
proB.convertTransStorage(*sptr, transtor, ne_bestla::ne_threading::get());
} else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) {
static prologue_b::gemm::WeightKBlockNInteger<tAVX512_VNNI_KBlock, tAVX512_VNNI_KBlock::ISA> proB;
proB.convertTransStorage(*sptr, transtor, ne_bestla::ne_threading::get());
} else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) {
static prologue_b::gemm::WeightKBlockNInteger<tAVX_VNNI_KBlock, tAVX_VNNI_KBlock::ISA> proB;
proB.convertTransStorage(*sptr, transtor, ne_bestla::ne_threading::get());
}
}
if (btype == gemm::CompType::tBF16 && PackRow == 2) {
if (NTile == tAMX_BF16::NTILE && _cd->AMX_BF16()) {
static prologue_b::gemm::WeightKBlockNInteger<tAMX_BF16, tAMX_BF16::ISA> proB;
proB.convertTransStorage(*sptr, transtor, ne_bestla::ne_threading::get());
}
}
*dstor = sycl_storage::StorageWeightKBlockNInteger(transtor);
dstor->assign((int8_t*)deviceptr);
dstor->fromHost(transtor, (sycl::queue*)device_queue);
}
}
}

#include "bestla/sycl/sycl_gemm.h"
#include "bestla/sycl/sycl_prologue_b.h"
#include "bestla/sycl/sycl_wrapper.h"
template <class GCT>
using ProAT = sycl_prologue_a::ActivationBase<GCT, float>;
template <class GCT>
using ProBTransT = sycl_prologue_b::WeightS4Trans<GCT, float>;
template <class GCT>
using EpiT = sycl_epilogue::OutputBase<GCT, float>;
void bestla_device_f32f32_forward(float* activation, void* weiptr, float* output, int _m, int _n, int _k, int lda,
int ldo, void* workspace, void* queue) {
using GemmCore = sycl_gemm::xve::DefaultSGemmCore;
auto dstor = (sycl_storage::StorageWeightKBlockNInteger*)weiptr;
if (_m == 1) {
using ProB = ProBTransT<GemmCore>;
auto e_esimd = ProB::gemv(activation, {(uint8_t*)dstor->mQBuf, (float*)dstor->mSBuf, dstor->mCStep}, output, _n, _k,
dstor->mBlockSize, (sycl::queue*)queue);
} else {
using KernelTLauncher = sycl_wrapper::LauncherWOQ<ProAT, ProBTransT, EpiT, GemmCore>;
utils::GemmProblem gp(1, _m, _n, _k);
auto e_esimd = KernelTLauncher::compute(
(sycl::queue*)queue, _m, _n, _k, dstor->mBlockSize,
{{activation, lda}, {(uint8_t*)dstor->mQBuf, (float*)dstor->mSBuf, dstor->mCStep}, {output, ldo}});
}
}
#endif
8 changes: 8 additions & 0 deletions neural_speed/core/ne.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ struct ne_cgraph {
size_t work_size;
struct ne_tensor* work;

size_t dev_work_size;
struct ne_tensor* dev_work;

struct ne_tensor* nodes[NE_MAX_NODES];
struct ne_tensor* grads[NE_MAX_NODES];
struct ne_tensor* leafs[NE_MAX_NODES];
Expand Down Expand Up @@ -239,6 +242,11 @@ struct ne_compute_params {
// work buffer for all threads
size_t wsize;
void* wdata;

size_t dev_wsize;
void* dev_wdata;

void* dev_queue;
};

#ifdef __cplusplus
Expand Down
4 changes: 4 additions & 0 deletions neural_speed/core/ne_bestla.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ void bestla_device_free(void* ptr, void* queue);
void bestla_device_memcpy(void* dstptr, const void* srcptr, size_t size, void* queue);
void bestla_device_memcpy_sync(void* dstptr, const void* srcptr, size_t size, void* queue);
void bestla_device_sync(void* queue);
size_t bestla_device_storage_size();
void bestla_device_load_storage(void* hoststor, void* devstor, void* deviceptr, void* queue);
void bestla_device_f32f32_forward(float* activation, void* weiptr, float* output, int _m, int _n, int _k, int lda,
int ldo, void* workspace, void* queue);
#endif
#ifdef __cplusplus
}
Expand Down
Loading

0 comments on commit 9ccbe5a

Please sign in to comment.