From 9945a7a29be8c7e83a712cb82e07adbb24861c7f Mon Sep 17 00:00:00 2001 From: xigui wang Date: Thu, 1 Feb 2024 01:33:00 -0800 Subject: [PATCH] Implement the YaRN rop scaling feature Interpolate the rotary postion embedding Only inference is implemented, training is not implemetned. --- neural_speed/core/ne_layers.c | 116 ++++++++++++++++++++++++++++------ neural_speed/core/ne_layers.h | 14 ++++ 2 files changed, 112 insertions(+), 18 deletions(-) diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index 89b40321e..d5f571a2f 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -2996,7 +2996,9 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor* struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding, - bool padding_left, float freq_base, float freq_scale) { + bool padding_left, float freq_base, float freq_scale, + int yarn_orig_ctx, float ext_factor, float attn_factor, + float beta_fast, float beta_slow) { NE_ASSERT(n_past >= 0 || n_keep >= 0); NE_ASSERT(padding_left); bool is_node = false; @@ -3036,7 +3038,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int ne_scratch_load(ctx); - float params[] = {freq_base, freq_scale}; + /* what the diffrence of setting parameters in b->data and in op_parameters */ + /* float and int are in different data ?? */ + float params[] = {freq_base, freq_scale, (float)yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow}; ne_set_op_params(result, ¶ms, sizeof(params)); result->op = NE_OP_ROPE; @@ -3050,19 +3054,36 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int prompt_size, float freq_base, float freq_scale) { - return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale); + return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale, + 0, 0.0f, 1.0f, 0.0f, 0.0f); } struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int prompt_size, float freq_base, float freq_scale) { - return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale); + return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale, + 0, 0.0f, 1.0f, 0.0f, 0.0f); } struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode, int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base, float freq_scale) { return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base, - freq_scale); + freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f); +} + +struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, + int prompt_size, float freq_base, float freq_scale, + int yarn_orig_ctx, float ext_factor, float attn_factor, float beta_fast, float beta_slow) { + return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale, + yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow); +} + +struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode, + int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base, + float freq_scale, + int yarn_orig_ctx, float ext_factor, float attn_factor, float beta_fast, float beta_slow) { + return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base, freq_scale, + yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow); } // ne_rope_back @@ -3100,14 +3121,14 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int prompt_size, int* n_padding, float freq_base, float freq_scale) { return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base, - freq_scale); + freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f); } struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int prompt_size, int* n_padding, float freq_base, float freq_scale) { return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base, - freq_scale); + freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f); } // ne_alibi @@ -7902,6 +7923,45 @@ static void ne_compute_forward_clamp(const struct ne_compute_params* params, con } } +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + return 1.0 - MIN(1.0, MAX(0.0, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base))); +} + // ne_compute_forward_rope #define NE_TENSOR_UNARY_OP_LOCALS \ NE_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ @@ -7914,12 +7974,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; } + const int bs = src0->ne[3]; NE_ASSERT(src1->type == NE_TYPE_I32); NE_ASSERT(ne_nelements(src1) == 5 + bs); // 5 + bs params const float freq_base = ((float*)(dst->op_params))[0]; const float freq_scale = 1 / ((float*)(dst->op_params))[1]; + const int n_orig_ctx = (int)((float*)(dst->op_params))[2]; + const float ext_factor = ((float*)(dst->op_params))[3]; + const float attn_factor = ((float*)(dst->op_params))[4]; + const float beta_fast = ((float*)(dst->op_params))[5]; + const float beta_slow = ((float*)(dst->op_params))[6]; const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX]; const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX]; @@ -7952,11 +8018,15 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, int ir = 0; const float theta_scale = powf(freq_base, -2.0f / n_dims); + const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); const bool skip = mode & 1; const bool is_neox = mode & 2; const bool is_glm = mode & 4; const bool is_shift = n_keep >= 0; + const bool use_yarn = ((mode & 0x8) != 0); NE_ASSERT(("RoPE shift not supported!", !is_shift)); NE_ASSERT(ne3 == bs); @@ -7967,21 +8037,21 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = freq_scale * (float)p; + float theta_base = (float)p; // only for glm when mode == 4 if (is_glm) { const int64_t n_padding = ((int32_t*)src1->data)[ROPE_PARAMS_NUM + i3]; // position ids - theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding); + theta_base = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding); float block_theta = MAX(p - (prompt_size - 2), 0); for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); const float cos_block_theta = cosf(block_theta); const float sin_block_theta = sinf(block_theta); - theta *= theta_scale; + theta_base *= theta_scale; block_theta *= theta_scale; const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00); @@ -7998,11 +8068,14 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta; } } else if (!is_neox) { + //printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0); for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); - theta *= theta_scale; // theta = i2 * theta_scale^(i0/2) + theta_base *= theta_scale; const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00); float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); @@ -8017,12 +8090,19 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, // TODO: this is probably wrong, but I can't figure it out .. // ref: // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + theta_base = theta_base * freq_scale; + for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + // simplified from `(ib * n_dims + ic) * inv_ndims` + float cur_rot = inv_ndims * ic - ib; - theta *= theta_scale; + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta + ); + + theta_base *= theta_scale; const int64_t i0 = ib * n_dims + ic / 2; diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h index 032283696..ab3c819ef 100644 --- a/neural_speed/core/ne_layers.h +++ b/neural_speed/core/ne_layers.h @@ -405,6 +405,20 @@ NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne int mode, int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base, float freq_scale); +// in-place, returns view(a) +NE_API struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, + int prompt_size, float freq_base, float freq_scale, + int yarn_orig_ctx, float ext_factor, float attn_factor, + float beta_fast, float beta_slow); + +// shift all tokens by a give p (n_shift) +// Optionally give a 1d tensor of precomputed interleaved cos/sin value of n_shift*scale^k for k \in [0, n_dims) +NE_API struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, + int mode, int prompt_size, int n_keep, struct ne_tensor* cossin, + float freq_base, float freq_scale, + int yarn_orig_ctx, float ext_factor, float attn_factor, + float beta_fast, float beta_slow); + // rotary position embedding backward, i.e compute dx from dy // a - dy NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);