Skip to content

Commit

Permalink
4x4 -> 4x
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Nov 28, 2024
1 parent d04731f commit 0e15a08
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 19 deletions.
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1963,7 +1963,7 @@ static void ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;

const int nsg = 2;
const int r0pt = 2;
const int r0pt = 1;
const int r1pt = 1;
const int nxpsg = ne11 > 1 ? 8 : 32;
const int nypsg = 32/nxpsg;
Expand Down
51 changes: 33 additions & 18 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,26 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f;
}

template <typename type4>
void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;

for (int i = 0; i < 4; i++) {
reg[i] = qs[0];
}
}

template <typename type4>
void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;

for (int i = 0; i < 4; i++) {
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
}
}

template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
Expand Down Expand Up @@ -1762,8 +1782,8 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short chpt = 1;
const short r0pt = 2;
const short chpt = 4;
const short r0pt = 1;

//const short nxpsg = (32);
const short nypsg = (32/nxpsg)*r0pt;
Expand All @@ -1784,36 +1804,31 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
device const block_q8_0 * xq[r0pt];

for (short ir0 = 0; ir0 < r0pt; ++ir0) {
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0;
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
}

device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx;
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;

float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };

for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
float4x4 lx;

#pragma unroll(2)
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
#pragma unroll
#pragma unroll(4)
for (short ch = 0; ch < chpt; ++ch) {
dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx);
float4 lx;

const float4x4 ly = y4x4[ch];
dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);

sumf[ir0] +=
dot(lx[0], ly[0]) +
dot(lx[1], ly[1]) +
dot(lx[2], ly[2]) +
dot(lx[3], ly[3]);
sumf[ir0] += dot(lx, y4[ch*nxpsg]);
}
}

y4x4 += ((16*chpt)*nxpsg)/16;
y4 += ((4*chpt)*nxpsg)/4;

for (short ir0 = 0; ir0 < r0pt; ++ir0) {
xq[ir0] += ((16*chpt)*nxpsg)/32;
xq[ir0] += ((4*chpt)*nxpsg)/32;
}
}

Expand Down

0 comments on commit 0e15a08

Please sign in to comment.