Skip to content

Commit

Permalink
vulkan: optimize coopmat2 q2_k dequant function
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbolznv committed Jan 7, 2025
1 parent a3d50bc commit 30645aa
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,25 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_
block_q2_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
block_q2_K_packed16 block;
};

float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31
const uint scalesi = iqs / 16; // 0..15
const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6

uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qs = (qs >> qsshift) & 0x0303;
qs = unpack8(qs)[idx & 1];

uint32_t qs = bl.block.qs[qsi];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4);
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
return ret;
}

Expand Down

0 comments on commit 30645aa

Please sign in to comment.