From d13da05cbf20b78222907987985439b2c3a4a0a8 Mon Sep 17 00:00:00 2001 From: Reinforce-II Date: Thu, 23 May 2024 01:17:51 +0800 Subject: [PATCH] use larger block size --- ggml.c | 55 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/ggml.c b/ggml.c index b7db6dd6318fe5..b4b68d89013aff 100644 --- a/ggml.c +++ b/ggml.c @@ -46,6 +46,7 @@ #undef GGML_USE_LLAMAFILE #define AMX_TILE_MN 16 #define AMX_TILE_K 16 +#define AMX_BLCK_SIZE 64 #endif #ifdef GGML_USE_LLAMAFILE @@ -12433,8 +12434,13 @@ static void ggml_compute_forward_mul_mat_one_chunk( assert(ne13 % ne03 == 0); // block-tiling attempt +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + const int64_t blck_0 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16; + const int64_t blck_1 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16; +#else const int64_t blck_0 = 16; const int64_t blck_1 = 16; +#endif const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; @@ -12448,12 +12454,12 @@ static void ggml_compute_forward_mul_mat_one_chunk( assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0); assert(src1->type == GGML_TYPE_F32); } - // 16 * AMX_TILE_MN, accounting for amx kernels - float tmp[16*AMX_TILE_MN]; - uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096); + // AMX_BLCK_SIZE * AMX_TILE_MN, accounting for amx kernels + float tmp[AMX_BLCK_SIZE*AMX_TILE_MN]; + uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096); ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase); - ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)); - float * xf32 = (float *) (wbase + 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)); + ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)); + float * xf32 = (float *) (wbase + 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)); xf32 = (float *) (((size_t)xf32 + 4095) & ~4095); #else float tmp[16]; @@ -12461,6 +12467,27 @@ static void ggml_compute_forward_mul_mat_one_chunk( for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (num_rows_per_vec_dot == AMX_TILE_MN) { + const int64_t ii13 = (iir1 / (ne12 * ne1)); + const int64_t ii12 = (iir1 - ii13 * ne12 * ne1) / ne1; + const int64_t ii11 = (iir1 - ii13 * ne12 * ne1 - ii12 * ne1); + + // broadcast src0 into src1 + const int64_t ii03 = ii13 / r3; + const int64_t ii02 = ii12 / r2; + + const char * src0_row = (const char*)src0->data + (0 + ii02 * nb02 + ii03 * nb03); + const uint8_t * src1_col = (const uint8_t *)src1->data + ii11 * nb11 + ii12 * nb12 + ii13 * nb13; + for (int i = 0; i < blck_0 && iir0 + i < ir0_end; ++i) { + to_float((const uint8_t *)src0_row + iir0*nb01 + i*nb01, xf32, ne00); + ggml_fp32_to_bf16_row(xf32, xbf16 + i*ne00, ne00); + } + for (int i = 0; i < blck_1 && iir1 + i < ir1_end; ++i) { + ggml_fp32_to_bf16_row((const float *)(src1_col + i*nb11), ybf16 + i*ne00, ne00); + } + } +#endif for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { const int64_t i13 = (ir1 / (ne12 * ne1)); const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; @@ -12478,14 +12505,8 @@ static void ggml_compute_forward_mul_mat_one_chunk( float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); #if defined(__AMX_TILE__) && defined(__AMX_BF16__) if (num_rows_per_vec_dot == AMX_TILE_MN) { - const uint8_t * src1_col = (const uint8_t *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13; for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - for (int cn = 0; cn < AMX_TILE_MN; ++cn) { - to_float((const uint8_t *)src0_row + ir0*nb01 + cn*nb01, xf32, ne00); - ggml_fp32_to_bf16_row(xf32, xbf16 + cn*ne00, ne00); - ggml_fp32_to_bf16_row((const float *)(src1_col + cn*nb11), ybf16 + cn*ne00, ne00); - } - ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], 16, xbf16, ne00*sizeof(ggml_bf16_t), ybf16, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN); + ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], blck_0, xbf16 + (ir0-iir0)*ne00, ne00*sizeof(ggml_bf16_t), ybf16 + (ir1-iir1)*ne00, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN); } } else @@ -12505,12 +12526,12 @@ static void ggml_compute_forward_mul_mat_one_chunk( //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? blck_0 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); } } for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { - memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); + memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } } } @@ -12748,7 +12769,11 @@ UseGgmlGemm2:; #endif // Now select a reasonable chunk size. +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + int chunk_size = AMX_BLCK_SIZE; +#else int chunk_size = 16; +#endif // We need to step up the size if it's small if (nr0 == 1 || nr1 == 1) { @@ -20314,7 +20339,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa #endif #if defined(__AMX_TILE__) && defined(__AMX_BF16__) if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) { - cur = n_threads*(2*AMX_TILE_MN*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096); + cur = n_threads*(2*AMX_BLCK_SIZE*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096); } else #endif if (node->src[1]->type != vec_dot_type) {