Skip to content

Commit

Permalink
use larger block size
Browse files Browse the repository at this point in the history
  • Loading branch information
ReinForce-II committed May 23, 2024
1 parent d83ca59 commit d13da05
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#undef GGML_USE_LLAMAFILE
#define AMX_TILE_MN 16
#define AMX_TILE_K 16
#define AMX_BLCK_SIZE 64
#endif

#ifdef GGML_USE_LLAMAFILE
Expand Down Expand Up @@ -12433,8 +12434,13 @@ static void ggml_compute_forward_mul_mat_one_chunk(
assert(ne13 % ne03 == 0);

// block-tiling attempt
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
const int64_t blck_0 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
const int64_t blck_1 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
#else
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
#endif

const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;

Expand All @@ -12448,19 +12454,40 @@ static void ggml_compute_forward_mul_mat_one_chunk(
assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0);
assert(src1->type == GGML_TYPE_F32);
}
// 16 * AMX_TILE_MN, accounting for amx kernels
float tmp[16*AMX_TILE_MN];
uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
// AMX_BLCK_SIZE * AMX_TILE_MN, accounting for amx kernels
float tmp[AMX_BLCK_SIZE*AMX_TILE_MN];
uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase);
ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
float * xf32 = (float *) (wbase + 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
float * xf32 = (float *) (wbase + 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
xf32 = (float *) (((size_t)xf32 + 4095) & ~4095);
#else
float tmp[16];
#endif

for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (num_rows_per_vec_dot == AMX_TILE_MN) {
const int64_t ii13 = (iir1 / (ne12 * ne1));
const int64_t ii12 = (iir1 - ii13 * ne12 * ne1) / ne1;
const int64_t ii11 = (iir1 - ii13 * ne12 * ne1 - ii12 * ne1);

// broadcast src0 into src1
const int64_t ii03 = ii13 / r3;
const int64_t ii02 = ii12 / r2;

const char * src0_row = (const char*)src0->data + (0 + ii02 * nb02 + ii03 * nb03);
const uint8_t * src1_col = (const uint8_t *)src1->data + ii11 * nb11 + ii12 * nb12 + ii13 * nb13;
for (int i = 0; i < blck_0 && iir0 + i < ir0_end; ++i) {
to_float((const uint8_t *)src0_row + iir0*nb01 + i*nb01, xf32, ne00);
ggml_fp32_to_bf16_row(xf32, xbf16 + i*ne00, ne00);
}
for (int i = 0; i < blck_1 && iir1 + i < ir1_end; ++i) {
ggml_fp32_to_bf16_row((const float *)(src1_col + i*nb11), ybf16 + i*ne00, ne00);
}
}
#endif
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
const int64_t i13 = (ir1 / (ne12 * ne1));
const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
Expand All @@ -12478,14 +12505,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (num_rows_per_vec_dot == AMX_TILE_MN) {
const uint8_t * src1_col = (const uint8_t *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13;
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
for (int cn = 0; cn < AMX_TILE_MN; ++cn) {
to_float((const uint8_t *)src0_row + ir0*nb01 + cn*nb01, xf32, ne00);
ggml_fp32_to_bf16_row(xf32, xbf16 + cn*ne00, ne00);
ggml_fp32_to_bf16_row((const float *)(src1_col + cn*nb11), ybf16 + cn*ne00, ne00);
}
ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], 16, xbf16, ne00*sizeof(ggml_bf16_t), ybf16, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], blck_0, xbf16 + (ir0-iir0)*ne00, ne00*sizeof(ggml_bf16_t), ybf16 + (ir1-iir1)*ne00, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
}
}
else
Expand All @@ -12505,12 +12526,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
//}

for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? blck_0 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
}

for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
}
}
}
Expand Down Expand Up @@ -12748,7 +12769,11 @@ UseGgmlGemm2:;
#endif

// Now select a reasonable chunk size.
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
int chunk_size = AMX_BLCK_SIZE;
#else
int chunk_size = 16;
#endif

// We need to step up the size if it's small
if (nr0 == 1 || nr1 == 1) {
Expand Down Expand Up @@ -20314,7 +20339,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
#endif
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) {
cur = n_threads*(2*AMX_TILE_MN*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
cur = n_threads*(2*AMX_BLCK_SIZE*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
} else
#endif
if (node->src[1]->type != vec_dot_type) {
Expand Down

0 comments on commit d13da05

Please sign in to comment.