Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
new api to assign sycl buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
ThanatosShinji committed Jun 1, 2024
1 parent adaec6b commit d2ae4ac
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 38 deletions.
95 changes: 66 additions & 29 deletions bestla/bestla/sycl/sycl_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ class StorageWeightKBlockNInteger {
int mN = 0, mK = 0;
int mBlockSize = 1;
int mDqBlockSize = 0;
size_t mCSize = 0;
int mCStep = 0;
sycl_utils::sycl_vector<int8_t> mQBuf;
sycl_utils::sycl_vector<int8_t> mScaleBuf;
sycl_utils::sycl_vector<int8_t> mZpBuf, mRedBuf;
sycl_utils::sycl_vector<int8_t> mDQCorrectionBuf;
sycl_utils::sycl_vector<int8_t> mShuffleIndices;
int8_t* mQBuf = nullptr;
size_t mWSize = 0;
int8_t* mSBuf = nullptr;
size_t mCSize = 0;
int8_t *mZpBuf = nullptr, *mRedBuf = nullptr;
size_t mZpSize = 0, mRedSize = 0;
int8_t* mDQCorrectionBuf = nullptr;
int8_t* mShuffleIndices = nullptr;
size_t mDQCorSize = 0, mShufSize = 0;

StorageWeightKBlockNInteger(bestla::storage::gemm::StorageWeightKBlockNInteger& _hoststor, sycl::queue* queue) {
StorageWeightKBlockNInteger(bestla::storage::gemm::StorageWeightKBlockNInteger& _hoststor) {
mPrologueID = BTLA_PROLOGUEB_IDS::WeightKBlockNInteger;
mCoreId = 0;
mDType = _hoststor.mDType;
Expand All @@ -48,40 +51,74 @@ class StorageWeightKBlockNInteger {
mK = _hoststor.mK;
mBlockSize = _hoststor.mBlockSize;
mDqBlockSize = _hoststor.mDqBlockSize;
mCSize = _hoststor.CSize();
mWSize = _hoststor.template WSize<int8_t>();
mCSize = _hoststor.CSize() * utils::bestla_dtype_size(mScaT);
mCStep = _hoststor.CStep();
if (_hoststor.template WPtr<void>()) {
mQBuf.resize(_hoststor.template WSize<int8_t>(), queue);
queue->memcpy(mQBuf.data(), _hoststor.template WPtr<void>(), mQBuf.size()).wait();
}
size_t csize = _hoststor.CSize();
if (_hoststor.template SPtr<void>()) {
mScaleBuf.resize(csize * _hoststor.mCorrection.mScaEleSize, queue);
queue->memcpy(mScaleBuf.data(), _hoststor.template SPtr<void>(), mScaleBuf.size()).wait();
}

if (_hoststor.template ZPtr<void>()) {
mZpBuf.resize(csize * _hoststor.mCorrection.mZpEleSize, queue);
queue->memcpy(mZpBuf.data(), _hoststor.template ZPtr<void>(), mZpBuf.size()).wait();
mZpSize = mCSize * utils::bestla_dtype_size(mZpT);
}
if (_hoststor.template RPtr<void>()) {
mRedBuf.resize(csize * _hoststor.mCorrection.mRedEleSize, queue);
queue->memcpy(mRedBuf.data(), _hoststor.template RPtr<void>(), mRedBuf.size()).wait();
mRedSize = mCSize * utils::bestla_dtype_size(mRedT);
}
// TODO DQ,shuffle support
}

/// Returns the number of bytes needed to hold all component buffers back to back:
/// weights, scales, zero points, reductions, DQ correction and shuffle indices.
/// This is the sum of the same size fields that assign() advances through.
size_t getDeviceSize() {
  size_t totalBytes = mWSize;
  totalBytes += mCSize;
  totalBytes += mZpSize;
  totalBytes += mRedSize;
  totalBytes += mDQCorSize;
  totalBytes += mShufSize;
  return totalBytes;
}

/// Carves one contiguous allocation into the per-component sub-buffers.
///
/// Layout, in order: quantized weights (mWSize bytes), scales (mCSize bytes),
/// then the optional zero-point, reduction, DQ-correction and shuffle-index
/// regions. An optional region is assigned only when its byte size is non-zero;
/// otherwise its pointer is left untouched.
///
/// @param dptr Start of a buffer holding at least getDeviceSize() bytes
///             (presumably device memory -- confirm with the caller).
void assign(int8_t* dptr) {
  mQBuf = dptr;
  dptr += mWSize;
  mSBuf = dptr;
  dptr += mCSize;
  if (mZpSize) {
    mZpBuf = dptr;
    dptr += mZpSize;
  }
  if (mRedSize) {
    mRedBuf = dptr;
    dptr += mRedSize;
  }
  if (mDQCorSize) {
    mDQCorrectionBuf = dptr;
    dptr += mDQCorSize;
  }
  // Gate on the region size (not on the still-null pointer) and store into the
  // shuffle pointer, so the DQ-correction pointer set above is not overwritten.
  if (mShufSize) {
    mShuffleIndices = dptr;
    dptr += mShufSize;
  }
}

/// Copies the weight, scale, zero-point and reduction data from _hoststor into the
/// buffers previously set by assign().
///
/// A component is copied only when both the host pointer and the local buffer
/// pointer are non-null. All copies are enqueued first; the call then blocks on
/// queue->wait() before returning.
///
/// @param _hoststor Host-side storage supplying the source pointers.
/// @param queue     SYCL queue used for the memcpy operations and the final wait.
void fromHost(bestla::storage::gemm::StorageWeightKBlockNInteger& _hoststor, sycl::queue* queue) {
  // Enqueue one copy, skipping components that are missing on either side.
  auto upload = [queue](int8_t* dst, const void* src, size_t bytes) {
    if (src && dst) {
      queue->memcpy(dst, src, bytes);
    }
  };
  upload(mQBuf, _hoststor.template WPtr<void>(), mWSize);
  upload(mSBuf, _hoststor.template SPtr<void>(), mCSize);
  upload(mZpBuf, _hoststor.template ZPtr<void>(), mZpSize);
  upload(mRedBuf, _hoststor.template RPtr<void>(), mRedSize);
  queue->wait();
}

// Copies every assigned buffer (weights, scales, zero points, reductions) back into the
// matching pointers of _hoststor via queue->memcpy, skipping buffers whose local pointer
// is null, then blocks on queue->wait() until the enqueued copies have finished.
// NOTE(review): unlike fromHost, the host-side destination pointers are not null-checked
// here -- confirm callers always pass a host storage whose buffers are assigned.
void toHost(bestla::storage::gemm::StorageWeightKBlockNInteger& _hoststor, sycl::queue* queue) {
if (mQBuf.data()) {
queue->memcpy(_hoststor.template WPtr<void>(), mQBuf.data(), mQBuf.size()).wait();
if (mQBuf) {
queue->memcpy(_hoststor.template WPtr<void>(), mQBuf, mWSize);
}
if (mScaleBuf.data()) {
queue->memcpy(_hoststor.template SPtr<void>(), mScaleBuf.data(), mScaleBuf.size()).wait();
if (mSBuf) {
queue->memcpy(_hoststor.template SPtr<void>(), mSBuf, mCSize);
}
if (mZpBuf.data()) {
queue->memcpy(_hoststor.template ZPtr<void>(), mZpBuf.data(), mZpBuf.size()).wait();
if (mZpBuf) {
queue->memcpy(_hoststor.template ZPtr<void>(), mZpBuf, mZpSize);
}
if (mRedBuf.data()) {
queue->memcpy(_hoststor.template RPtr<void>(), mRedBuf.data(), mRedBuf.size()).wait();
if (mRedBuf) {
queue->memcpy(_hoststor.template RPtr<void>(), mRedBuf, mRedSize);
}
// Single synchronization point for all copies enqueued above.
queue->wait();
}
};
} // namespace sycl_storage
Expand Down
25 changes: 16 additions & 9 deletions bestla/bestla/ut/sycl_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class UT_SyclDevice {
dev->print();
}
};
static UT_SyclDevice sUT_SyclDevice;
static UT_SyclDevice sUT_SyclDevice;

class UT_SyclVector {
public:
Expand Down Expand Up @@ -56,7 +56,10 @@ class UT_StorageMemCheck {
packedW.assign(buf0.data());
auto dev = UT_Device::get();
auto q = dev->getQueue();
sycl_storage::StorageWeightKBlockNInteger sycl_stor(packedW, q);
sycl_storage::StorageWeightKBlockNInteger sycl_stor(packedW);
sycl_utils::sycl_vector<int8_t> dbuf(sycl_stor.getDeviceSize(), q);
sycl_stor.assign(dbuf.data());
sycl_stor.fromHost(packedW, q);
storage::gemm::StorageWeightKBlockNInteger tmp = packedW;
tmp.assign(buf1.data());
sycl_stor.toHost(tmp, q);
Expand Down Expand Up @@ -92,14 +95,16 @@ class UT_BlockQunatize_S3S4 {
avector<int8_t> buffer1(transtor.mSize);
transtor.assign(buffer1.data());
kernel.convertTransStorage(ptr, transtor, UT_Threading::get());
sycl_storage::StorageWeightKBlockNInteger sycl_stor(transtor, q);
sycl_storage::StorageWeightKBlockNInteger sycl_stor(transtor);
sycl_utils::sycl_vector<int8_t> dbuf(sycl_stor.getDeviceSize(), q);
sycl_stor.assign(dbuf.data());
sycl_stor.fromHost(transtor, q);
avector<float> dequant(n * k, 0);
using ProB = sycl_prologue_b::WeightS4Trans<GemmCore, float>;
sycl_utils::sycl_vector<float> dequantB(n * k, q);
int blks = updiv(k, blocksize);
auto evt = ProB::dequant_s4<sycl_prologue_b::KernelConfigTrans>(
n, k, blocksize, {(uint8_t*)sycl_stor.mQBuf.data(), (float*)sycl_stor.mScaleBuf.data(), blks}, dequantB.data(),
q);
n, k, blocksize, {(uint8_t*)sycl_stor.mQBuf, (float*)sycl_stor.mSBuf, blks}, dequantB.data(), q);
evt.wait();
q->memcpy(dequant.data(), dequantB.data(), dequantB.size() * 4).wait();
ut::buffer_error(raw.data(), dequant.data(), dequant.size(), 0.01f);
Expand Down Expand Up @@ -149,11 +154,13 @@ class UT_CompFp32 {
avector<int8_t> buffer1(transtor.mSize);
transtor.assign(buffer1.data());
proB.convertTransStorage(packedw, transtor, UT_Threading::get());
sycl_storage::StorageWeightKBlockNInteger sycl_stor(transtor, q);
sycl_storage::StorageWeightKBlockNInteger sycl_stor(transtor);
sycl_utils::sycl_vector<int8_t> dbuf(sycl_stor.getDeviceSize(), q);
sycl_stor.assign(dbuf.data());
sycl_stor.fromHost(transtor, q);
int blks = updiv(k, blocksize);
auto e_esimd =
ProBTransT::gemv(dA.data(), {(uint8_t*)sycl_stor.mQBuf.data(), (float*)sycl_stor.mScaleBuf.data(), blks},
dC.data(), n, k, blocksize, q);
auto e_esimd = ProBTransT::gemv(dA.data(), {(uint8_t*)sycl_stor.mQBuf, (float*)sycl_stor.mSBuf, blks}, dC.data(), n,
k, blocksize, q);
e_esimd.wait();
q->memcpy(matC.data(), dC.data(), matC.size() * 4).wait();

Expand Down

0 comments on commit d2ae4ac

Please sign in to comment.