Skip to content

Commit

Permalink
llama : add phi3 128K model support (ggerganov#7225)
Browse files Browse the repository at this point in the history
* add phi3 128k support in convert-hf-to-gguf

* add phi3 128k support in cuda

* address build warnings on llama.cpp

* adjust index value in cuda long rope freq factors

* add long rope support in ggml cpu backend

* make freq factors only depend on ctx size

* remove unused rope scaling type 'su' frin gguf converter

* fix flint warnings on convert-hf-to-gguf.py

* set to the short freq factor when context size is small than trained context size

* add one line of comments

* metal : support rope freq_factors

* ggml : update ggml_rope_ext API to support freq. factors

* backends : add dev messages to support rope freq. factors

* minor : style

* tests : update to use new rope API

* backends : fix pragma semicolons

* minor : cleanup

* llama : move rope factors from KV header to tensors

* llama : remove tmp assert

* cuda : fix compile warning

* convert : read/write n_head_kv

* llama : fix uninitialized tensors

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
liuwei-git and ggerganov authored May 21, 2024
1 parent 6369bf0 commit 201cc11
Show file tree
Hide file tree
Showing 15 changed files with 478 additions and 227 deletions.
49 changes: 43 additions & 6 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast

import math
import numpy as np
import torch

Expand Down Expand Up @@ -1784,23 +1785,59 @@ def set_vocab(self):
def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])

rot_pct = 1.0
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
rms_eps = self.find_hparam(["rms_norm_eps"])
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rope_dims = n_embd // n_head

self.gguf_writer.add_name("Phi3")
self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))

self.gguf_writer.add_context_length(max_pos_embds)
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(8192)
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
self.gguf_writer.add_rope_dimension_count(rope_dims)
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
self.gguf_writer.add_file_type(self.ftype)

# write rope scaling for long context (128k) model
rope_scaling = self.find_hparam(['rope_scaling'], True)
if (rope_scaling is None):
return

scale = max_pos_embds / orig_max_pos_embds

rope_scaling_type = rope_scaling.get('type', '').lower()
if len(rope_scaling_type) == 0:
raise KeyError('Missing the required key rope_scaling.type')

if rope_scaling_type == 'su':
attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0
elif rope_scaling_type == 'yarn':
attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0
else:
raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet')

self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)

long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)

if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')

if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')

self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))


@Model.register("PlamoForCausalLM")
class PlamoModel(Model):
Expand Down
4 changes: 2 additions & 2 deletions examples/finetune/finetune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
return ggml_rope_ext(ctx,
t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};
Expand Down
4 changes: 2 additions & 2 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ static struct ggml_tensor * llama_build_train_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;

return ggml_rope_custom(
ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
return ggml_rope_ext(
ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};

Expand Down
72 changes: 48 additions & 24 deletions ggml-cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ static __global__ void rope(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}

template<typename T, bool has_pos>
template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

Expand All @@ -88,7 +88,9 @@ static __global__ void rope_neox(
float cur_rot = inv_ndims * ic - ib;

const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;

float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
Expand Down Expand Up @@ -164,7 +166,7 @@ static void rope_cuda(
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
Expand All @@ -175,15 +177,29 @@ static void rope_neox_cuda(
const float inv_ndims = -1.0f / n_dims;

if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
if (freq_factors == nullptr) {
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
} else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
} else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
}
}

Expand Down Expand Up @@ -214,24 +230,27 @@ static void rope_cuda_f32(

static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {

rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;

float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

Expand All @@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);

//const int n_past = ((int32_t *) dst->op_params)[0];
Expand All @@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

const float * freq_factors = nullptr;
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_d;
}

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

if (is_neox) {
pos = (const int32_t *) src1_d;

if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
} else {
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
}

rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);

Expand All @@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false);
Expand Down
4 changes: 4 additions & 0 deletions ggml-kompute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
} break;
case GGML_OP_ROPE:
{
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");

GGML_ASSERT(ne10 == ne02);
GGML_ASSERT(src0t == dstt);
// const int n_past = ((int32_t *) dst->op_params)[0];
Expand Down
Loading

0 comments on commit 201cc11

Please sign in to comment.