This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Implement the YaRN RoPE scaling feature
Interpolate the rotary position embedding.
Only inference is implemented; training is not implemented.
xigui wang committed Feb 1, 2024
1 parent 26c68c7 commit 9945a7a
Showing 2 changed files with 112 additions and 18 deletions.
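
For orientation, a minimal caller-side sketch of the new API introduced in this commit. The function name and parameter order come from the diff below; the surrounding names (ctx, q, n_past, n_rot) and the numeric values (a model trained on a 4096-token context extended to 16384 tokens, with common YaRN betas) are illustrative assumptions, not part of the change:

#include "ne_layers.h"  /* for ne_rope_custom_inplace */

/* Hypothetical helper showing how a caller might request YaRN-scaled RoPE.
 * ctx/q/n_past/n_rot are assumed to come from the surrounding model code;
 * the numeric values are illustrative only. */
static struct ne_tensor* apply_yarn_rope(struct ne_context* ctx, struct ne_tensor* q, int n_past, int n_rot) {
  return ne_rope_custom_inplace(
      ctx, q, n_past, n_rot, /*mode=*/0, /*prompt_size=*/0,
      /*freq_base=*/10000.0f,
      /*freq_scale=*/16384.0f / 4096.0f,  /* the kernel takes the reciprocal as the interpolation factor */
      /*yarn_orig_ctx=*/4096,
      /*ext_factor=*/1.0f,                /* 0.0f disables the YaRN extrapolation/interpolation mix */
      /*attn_factor=*/1.0f,
      /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);
}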
116 changes: 98 additions & 18 deletions neural_speed/core/ne_layers.c
@@ -2996,7 +2996,9 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*

struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
bool padding_left, float freq_base, float freq_scale) {
bool padding_left, float freq_base, float freq_scale,
int yarn_orig_ctx, float ext_factor, float attn_factor,
float beta_fast, float beta_slow) {
NE_ASSERT(n_past >= 0 || n_keep >= 0);
NE_ASSERT(padding_left);
bool is_node = false;
@@ -3036,7 +3038,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int

ne_scratch_load(ctx);

float params[] = {freq_base, freq_scale};
/* TODO: clarify the difference between passing parameters via b->data and via op_params, */
/* and whether float and int parameters need to be stored in separate buffers. */
float params[] = {freq_base, freq_scale, (float)yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow};
ne_set_op_params(result, &params, sizeof(params));

result->op = NE_OP_ROPE;
@@ -3050,19 +3054,36 @@

struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, float freq_base, float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale);
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale,
0, 0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, float freq_base, float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale);
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale,
0, 0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
float freq_scale) {
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
freq_scale);
freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, float freq_base, float freq_scale,
int yarn_orig_ctx, float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale,
yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
}

struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
float freq_scale,
int yarn_orig_ctx, float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base, freq_scale,
yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
}

// ne_rope_back
@@ -3100,14 +3121,14 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, int* n_padding, float freq_base, float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base,
freq_scale);
freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
int mode, int prompt_size, int* n_padding, float freq_base,
float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base,
freq_scale);
freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
}

// ne_alibi
@@ -7902,6 +7923,45 @@ static void ne_compute_forward_clamp(const struct ne_compute_params* params, con
}
}

static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
return 1.0 - MIN(1.0, MAX(0.0, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}

// Solving `n_rot = (max_pos_emb / 2pi) * base^(-2x / n_dims)` for x gives the dimension index at which
// a given rotation count occurs: `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}

void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
}

// ne_compute_forward_rope
#define NE_TENSOR_UNARY_OP_LOCALS \
NE_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
@@ -7914,12 +7974,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
return;
}

const int bs = src0->ne[3];
NE_ASSERT(src1->type == NE_TYPE_I32);
NE_ASSERT(ne_nelements(src1) == 5 + bs); // 5 + bs params

const float freq_base = ((float*)(dst->op_params))[0];
const float freq_scale = 1 / ((float*)(dst->op_params))[1];
const int n_orig_ctx = (int)((float*)(dst->op_params))[2];
const float ext_factor = ((float*)(dst->op_params))[3];
const float attn_factor = ((float*)(dst->op_params))[4];
const float beta_fast = ((float*)(dst->op_params))[5];
const float beta_slow = ((float*)(dst->op_params))[6];

const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX];
@@ -7952,11 +8018,15 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
int ir = 0;

const float theta_scale = powf(freq_base, -2.0f / n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

const bool skip = mode & 1;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const bool is_shift = n_keep >= 0;
const bool use_yarn = ((mode & 0x8) != 0);
NE_ASSERT(("RoPE shift not supported!", !is_shift));

NE_ASSERT(ne3 == bs);
@@ -7967,21 +8037,21 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
if (ir++ < ir0) continue;
if (ir > ir1) break;

float theta = freq_scale * (float)p;
float theta_base = (float)p;

// only for glm when mode == 4
if (is_glm) {
const int64_t n_padding = ((int32_t*)src1->data)[ROPE_PARAMS_NUM + i3];
// position ids
theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
theta_base = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
float block_theta = MAX(p - (prompt_size - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base);
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta);

theta *= theta_scale;
theta_base *= theta_scale;
block_theta *= theta_scale;

const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
@@ -7998,11 +8068,14 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta;
}
} else if (!is_neox) {
//printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
);

theta *= theta_scale; // theta = i2 * theta_scale^(i0/2)
theta_base *= theta_scale;

const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
@@ -8017,12 +8090,19 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
// TODO: this is probably wrong, but I can't figure it out ..
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
theta_base = theta_base * freq_scale;

for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

theta *= theta_scale;
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta
);

theta_base *= theta_scale;

const int64_t i0 = ib * n_dims + ic / 2;

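For reference, a small standalone sketch of the math the new helpers implement: ggml_rope_yarn_corr_dims turns beta_fast/beta_slow into a dimension range [low, high]; dimensions below low keep their extrapolated angle, dimensions above high are fully interpolated, rope_yarn_ramp blends linearly in between, and mscale = 1 + 0.1*ln(1/freq_scale) rescales cos/sin. The example values (head dimension 128, 4096-token original context, 4x extension) are assumptions for illustration only and are not part of the commit:

#include <math.h>
#include <stdio.h>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

/* Same formula as ggml_rope_yarn_corr_dim in the diff above: the dimension
 * index at which a given rotation count (n_rot) occurs over n_orig_ctx positions. */
static float corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
  return n_dims * logf(n_orig_ctx / (n_rot * 2.0f * (float)M_PI)) / (2.0f * logf(base));
}

int main(void) {
  const int   n_dims     = 128;                 /* assumed head dimension   */
  const int   n_orig_ctx = 4096;                /* assumed original context */
  const float freq_base  = 10000.0f;
  const float freq_scale = 4096.0f / 16384.0f;  /* 4x extension -> 0.25     */

  /* Correction range, clamped to [0, n_dims - 1] as in the diff. */
  const float low  = fmaxf(0.0f, floorf(corr_dim(n_dims, n_orig_ctx, /*beta_fast=*/32.0f, freq_base)));
  const float high = fminf((float)(n_dims - 1), ceilf(corr_dim(n_dims, n_orig_ctx, /*beta_slow=*/1.0f, freq_base)));

  /* Magnitude correction applied to cos/sin when ext_factor != 0. */
  const float mscale = 1.0f + 0.1f * logf(1.0f / freq_scale);

  printf("corr dims = [%.0f, %.0f], mscale = %.4f\n", low, high, mscale);
  return 0;
}
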
14 changes: 14 additions & 0 deletions neural_speed/core/ne_layers.h
@@ -405,6 +405,20 @@ NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
float freq_base, float freq_scale);

// in-place, returns view(a)
NE_API struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, float freq_base, float freq_scale,
int yarn_orig_ctx, float ext_factor, float attn_factor,
float beta_fast, float beta_slow);

// shift all tokens by a given p (n_shift)
// optionally takes a 1d tensor of precomputed interleaved cos/sin values of n_shift*scale^k for k \in [0, n_dims)
NE_API struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
float freq_base, float freq_scale,
int yarn_orig_ctx, float ext_factor, float attn_factor,
float beta_fast, float beta_slow);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);
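A note on the declarations above: the pre-existing wrappers (ne_rope, ne_rope_inplace, ne_rope_shift_inplace, ne_rope_with_padding*) now forward to ne_rope_impl with yarn_orig_ctx = 0, ext_factor = 0.0f, attn_factor = 1.0f and zero betas. With ext_factor == 0, rope_yarn skips the extrapolation mix and the magnitude correction and simply rotates by freq_scale * theta, as in the sketch below (a standalone restatement with assumed names, not part of the commit):

#include <math.h>

/* With ext_factor == 0 and attn_factor == 1, the YaRN rotation reduces to
 * plain RoPE with a linear frequency scale: theta -> freq_scale * theta. */
static void rope_plain(float theta_extrap, float freq_scale, float* cos_out, float* sin_out) {
  const float theta = freq_scale * theta_extrap;
  *cos_out = cosf(theta);
  *sin_out = sinf(theta);
}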
