Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
fix packrow
Browse files: browse the repository at this point in the history
  • Loading branch information
zhewang1-intc committed Jan 23, 2024
1 parent 75c1792 commit edfe4ee
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions bestla/bestla/bestla_prologue_b.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -122,7 +122,7 @@ class WeightKBlockNInteger {
bool is_asym) {
int KPad = utils::padto(k, _GemmCore_T::KTILE);
int NPad = utils::padto(n, _GemmCore_T::NTILE);
if (qtype == BTLA_DTYPE::S3_CLIP) NPad = utils::padto(n, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW);
// if (qtype == BTLA_DTYPE::S3_CLIP) NPad = utils::padto(n, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW);
StorageWeight tmp(_GemmCore_T::ID);
tmp.resize(NPad, KPad, blocksize <= 0 ? KPad : blocksize, n, k, qtype, scat, redt, is_asym);
return tmp;
Expand Down Expand Up @@ -563,10 +563,9 @@ class WeightKBlockNInteger {
parallel::IThreading* threading) {
// TODO(zhe): 1D parallel compress
// N==NPad, K==Kpad, ldb==Kpad
auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(K, 64);
auto col = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * K;
assert(N % (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE) == 0); // TODO(zhe): consider N pad to packrow*ntile?
auto row = N / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE);
auto ld_dst = _GemmCore_T::NTILE * utils::padto(K, 64);
auto col = _GemmCore_T::NTILE * K;
auto row = N / _GemmCore_T::NTILE;
auto pad_64_buf = utils::avector<int8_t>(row * ld_dst, 0);
kernel::wrapper::Memcpy2D::forward<BTLA_ISA::NoSIMD>(B, pad_64_buf.data(), row, col, col, ld_dst);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(dstptr);
Expand Down Expand Up @@ -750,8 +749,8 @@ class WeightKBlockNInteger {
int8_t* bit3_ptr = wptr->template WPtr<int8_t>();
auto elt_offset =
n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64);
auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE);
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto row = NPad / _GemmCore_T::NTILE;
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
Expand Down Expand Up @@ -831,8 +830,8 @@ class WeightKBlockNInteger {
int8_t* bit3_ptr = wptr->template WPtr<int8_t>();
auto elt_offset =
n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64);
auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE);
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto row = NPad / _GemmCore_T::NTILE;
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
Expand Down Expand Up @@ -904,8 +903,8 @@ class WeightKBlockNInteger {
auto KPad = wptr->mKPad;
auto NPad = wptr->mNPad;
int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW;
auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE);
auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto row = NPad / _GemmCore_T::NTILE;
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto base_offset = n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE;
for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
auto elt_offset = base_offset + i * utils::padto(KPad, 64);
Expand Down

0 comments on commit edfe4ee

Please sign in to comment.