From 1f0fea70fb761d10e2264cbdcf4852ed32706c89 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 Aug 2024 10:43:42 -0400 Subject: [PATCH 01/18] llama : initial Mamba-2 support --- convert_hf_to_gguf.py | 67 ++++++++ ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 193 ++++++++++++++-------- gguf-py/gguf/constants.py | 19 +++ gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 6 +- src/llama.cpp | 291 +++++++++++++++++++++++++++++++-- 7 files changed, 495 insertions(+), 87 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 108c822cff5d2..0ac64574a3043 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2788,6 +2788,73 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@Model.register("Mamba2ForCausalLM") +class Mamba2Model(Model): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + elif (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", 
"d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b8a21a2ccc3f0..59e0022dd4286 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1787,7 +1787,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * 
A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * D); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d63c917a5705a..6668209081b6c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * D) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; - - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; + + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); + GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + 
GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(D->ne[0] == n_head); + GGML_ASSERT(ggml_is_vector(D)); + + if (ggml_is_vector(A)) { + // Mamba-2 + GGML_ASSERT(A->ne[0] == n_head); + } else { + // Mamba-1 + GGML_ASSERT(A->ne[0] == d_state); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + } } bool is_node = false; @@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = D; return result; } @@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // dt - const struct ggml_tensor * src3 = dst->src[3]; // A - const struct ggml_tensor * src4 = dst->src[4]; // B - const struct ggml_tensor * src5 = dst->src[5]; // C + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = 
src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src0->ne[3]; // number of sequences in the batch + + const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); 
// {d_state, n_t, n_s} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations + GGML_ASSERT(src6->nb[0] == sizeof(float)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); + + // heads per thread + const int dh = (nh + nth - 1)/nth; + + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); + + for (int i3 = 0; i3 < ns; ++i3) { + for (int i2 = 0; i2 < nt; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + const float * D = (const float *) ((const char *) src6->data); // {nh} + float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} + float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + + // use the output as the source when it's not the first token-wise iteration if (i2 > 0) { s0 = s; } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: 
https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + if (ggml_is_vector(src3)) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); + + // TODO: SIMD implementation + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } + } + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? 
log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } } - y[i1] = sumf; } } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b55effa9907b1..32a2fb20f84b9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -130,6 +130,7 @@ class SSM: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class Tokenizer: @@ -208,6 +209,7 @@ class MODEL_ARCH(IntEnum): GEMMA2 = auto() STARCODER2 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -269,6 +271,7 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() ATTN_Q_A = auto() ATTN_Q_B = auto() @@ -338,6 +341,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -399,6 +403,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", @@ -869,6 +874,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, 
], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1373,6 +1391,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af3b98c679b0b..ea788918dbf2c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a4f185c0658a3..8593a80a5ab8f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -396,7 +396,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.SSM_IN: ( @@ -429,6 +429,10 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.D", ), + MODEL_TENSOR.SSM_NORM: ( + "backbone.layers.{bid}.mixer.norm", # mamba2 + ), + MODEL_TENSOR.SSM_OUT: ( 
"model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", diff --git a/src/llama.cpp b/src/llama.cpp index bd7f1508b2644..5be0ef7a2ac7a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -198,6 +198,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -245,6 +246,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -328,6 +330,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_TOKENIZER_MODEL, @@ -427,7 +430,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -517,6 +521,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, @@ -1068,6 +1073,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { 
LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -2239,6 +2260,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -2289,6 +2311,7 @@ struct llama_hparams { if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->ssm_n_group != other.ssm_n_group) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true; @@ -2357,7 +2380,7 @@ struct llama_hparams { // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + return (ssm_d_conv > 0 ? 
ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings @@ -2419,6 +2442,7 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm; struct ggml_tensor * attn_norm_cross; struct ggml_tensor * attn_norm_enc; + struct ggml_tensor * ssm_norm; // attention struct ggml_tensor * wq; @@ -5573,6 +5597,38 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: model.type = e_model::MODEL_MEDIUM; break; + case 1536: model.type = e_model::MODEL_LARGE; break; + case 2048: model.type = e_model::MODEL_XL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; /* stop here: falling through into LLM_ARCH_XVERSE would overwrite model.type */ case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6404,6 +6460,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + 
LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } @@ -7639,7 +7696,7 @@ static bool llm_load_tensors( layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); - layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); @@ -7648,9 +7705,61 @@ static bool llm_load_tensors( layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); // no "weight" suffix for these - layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = n_embd / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + 
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}); + + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}); + + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {n_head}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {n_head}); + + layer.ssm_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}); + // out_proj layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } @@ -9041,6 +9150,8 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = batch.n_seqs; // Some variants of Mamba arch (e.g. 
FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -9064,7 +9175,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, graph, ssm_states_all, state_copy, state_mask, hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9113,8 +9224,8 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x); // split struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * B = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -9127,23 +9238,23 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. 
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); // store last states ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} - y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -9157,6 +9268,136 @@ static struct ggml_tensor * llm_build_mamba( return cur; } +static struct ggml_tensor * llm_build_mamba2( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_ubatch & batch, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t kv_head, + int32_t n_kv, + const llm_build_cb & cb, + int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = model.hparams; + const llama_kv_cache & kv = lctx.kv_self; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t n_group = 
hparams.ssm_n_group; + const int64_t n_seqs = batch.n_seqs; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * conv_states_all = kv.k_l[il]; + struct ggml_tensor * ssm_states_all = kv.v_l[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, + graph, conv_states_all, state_copy, state_mask, + hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, + graph, ssm_states_all, state_copy, state_mask, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); + + // split the above in three + struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, xBC), 0); + + // copy last (d_conv - 1) columns back into 
the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + struct ggml_tensor * x = ggml_view_4d(ctx, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + struct ggml_tensor * B = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + struct ggml_tensor * C = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, 
model.layers[il].ssm_d); + + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); + + // grouped RMS norm + y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = llm_build_norm(ctx, y, hparams, + model.layers[il].ssm_norm, NULL, + LLM_NORM_RMS, cb, il); + y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -12788,7 +13029,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_mamba() { + struct ggml_cgraph * build_mamba(int32_t version = 1) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_tensor * cur; @@ -12807,9 +13048,19 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + switch (version) { + case 2: + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); + break; + case 1: + default: + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); 
+ break; + } if (il == n_layer - 1) { // skip computing output for unused tokens @@ -14858,7 +15109,11 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_MAMBA: { - result = llm.build_mamba(); + result = llm.build_mamba(/* version */ 1); + } break; + case LLM_ARCH_MAMBA2: + { + result = llm.build_mamba(/* version */ 2); } break; case LLM_ARCH_XVERSE: { @@ -17954,6 +18209,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -18125,6 +18381,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { + case LLM_ARCH_MAMBA2: case LLM_ARCH_MAMBA: return true; default: return false; } From dceff23faec99945d3161d24ea209a0c433546db Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 21:49:39 -0400 Subject: [PATCH 02/18] ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states --- ggml/src/ggml.c | 95 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6668209081b6c..f8e708088b357 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (ne00 > 1 && ne10 == 1) { + // fast broadcast path + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + 
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + const float scale = src1_ptr[0]; + + if (scale == 0.0f) { + // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, + // but it is useful when resetting the state of recurrent models. + memset((char *)dst->data + ir*nb1, 0, nb1); + } else { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + } + if (scale != 1.0f) { + ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); + } + } + } + } else if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); @@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32( const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; const float dA = expf(dt_soft_plus * A[h]); - // TODO: SIMD implementation // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; +#if defined(GGML_SIMD) + const int np = (nc & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = 
GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); +#else + const int np = 0; +#endif // d_state - for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * dA) + (B[ig] * x_dt); + const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } else { @@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32( // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. 
// d_state for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } From 2bfe9de6d3a3598d4b778f9b144bb8ac33c2797b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 22:43:39 -0400 Subject: [PATCH 03/18] llama : support running Mamba-Codestral-7B-v0.1 --- convert_hf_to_gguf.py | 4 ++++ src/llama.cpp | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0ac64574a3043..a5bdd5def2029 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2843,6 +2843,10 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + if name.endswith(".dt_bias"): name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" diff --git a/src/llama.cpp b/src/llama.cpp index 5be0ef7a2ac7a..fd80361bd7605 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9383,7 +9383,7 @@ static struct ggml_tensor * llm_build_mamba2( // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); y = llm_build_norm(ctx, y, hparams, - model.layers[il].ssm_norm, NULL, + ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); From 
aff96920f972d8e042dfdef6dc08644cd8df0234 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 16:28:07 -0400 Subject: [PATCH 04/18] llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted --- ggml/src/ggml.c | 4 ++-- src/llama.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f8e708088b357..415fa6901304a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10226,11 +10226,11 @@ static void ggml_compute_forward_mul_f32( if (scale == 0.0f) { // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, // but it is useful when resetting the state of recurrent models. - memset((char *)dst->data + ir*nb1, 0, nb1); + memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); } else { if (dst->data != src0->data) { // src0 is same shape as dst => same indices - memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); } if (scale != 1.0f) { ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); diff --git a/src/llama.cpp b/src/llama.cpp index fd80361bd7605..03f93164a89e8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9335,7 +9335,7 @@ static struct ggml_tensor * llm_build_mamba2( ggml_cpy(ctx, last_conv, ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); // 1D convolution // The equivalent is to make a self-overlapping view of conv_x From e04910dc48966f1cbc7309d12b8e1b55bdd33df2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 23:06:22 -0400 Subject: [PATCH 05/18] llama : remove unused variable --- src/llama.cpp | 3 +-- 1 file changed, 1 
insertion(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 03f93164a89e8..dda3d51b017d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7718,7 +7718,6 @@ static bool llm_load_tensors( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = n_embd / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; @@ -9287,7 +9286,7 @@ static struct ggml_tensor * llm_build_mamba2( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t head_dim = d_inner / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t n_seqs = batch.n_seqs; From fa358e707132ace9012cb90880abe86fd32464a6 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 01:13:43 -0400 Subject: [PATCH 06/18] llama : add missing break --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index dda3d51b017d6..5b6b6707a1c95 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5628,7 +5628,7 @@ static void llm_load_hparams( } break; default: model.type = e_model::MODEL_UNKNOWN; } - } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); From 38913dc8ddd1e119df0e0cfcacfb260b9b1f5c02 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 14:31:12 -0400 Subject: [PATCH 07/18] convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. 
--- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a5bdd5def2029..4851926b7b98f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2801,13 +2801,13 @@ def set_vocab(self): vocab_size = -(vocab_size // -pad_vocab) * pad_vocab self.hparams["vocab_size"] = vocab_size - if (self.dir_model / "tokenizer.json").is_file(): - self._set_vocab_gpt2() - elif (self.dir_model / "tokenizer.model").is_file(): + if (self.dir_model / "tokenizer.model").is_file(): self._set_vocab_sentencepiece() elif (self.dir_model / "tokenizer.model.v3").is_file(): # mamba-codestral raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() else: # Use the GPT-NeoX tokenizer when no tokenizer files are present self._set_vocab_builtin("gpt-neox", vocab_size) From 273e7a495ad8c93bb9ba8123c1a3de3c68f93cf9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 30 Sep 2024 15:52:42 -0400 Subject: [PATCH 08/18] llama : avoid redundant state copy for Mamba 1 and 2 --- ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 50 ++++++------ src/llama.cpp | 154 +++++++++++++++++-------------------- tests/test-backend-ops.cpp | 54 ++++++++++--- 4 files changed, 142 insertions(+), 119 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fec6798ff6d06..1fc53bebebf30 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1833,7 +1833,8 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D); + struct ggml_tensor * D, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 12e4f26942f86..1c4c393e55d06 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7598,7 +7598,8 @@ 
struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D) { + struct ggml_tensor * D, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); @@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; @@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(ggml_is_3d(dt)); GGML_ASSERT(s->ne[1] == head_dim); GGML_ASSERT(s->ne[2] == n_head); - GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); GGML_ASSERT(D->ne[0] == n_head); GGML_ASSERT(ggml_is_vector(D)); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); - if (ggml_is_vector(A)) { - // Mamba-2 - GGML_ASSERT(A->ne[0] == n_head); - } else { - // Mamba-1 + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == n_head); - GGML_ASSERT(ggml_is_matrix(A)); } } @@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan( } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; @@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[4] = B; result->src[5] = C; result->src[6] = D; + result->src[7] = ids; return result; } @@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} - const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} + const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nh = src1->ne[1]; // n_head const int64_t ng = src4->ne[1]; const int64_t nt = src1->ne[2]; // number of tokens per sequence - const int64_t ns = src0->ne[3]; // number of sequences in the batch + const int64_t ns = src1->ne[3]; // number of sequences in the batch - const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); 
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(float)); + GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); + const int32_t * ids = (const int32_t *) src7->data; + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + for (int i2 = 0; i2 < nt; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} - const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} - float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + 
s_off); // {d_state, dim, nh, ns} - - // use the output as the source when it's not the first token-wise iteration - if (i2 > 0) { s0 = s; } - if (ggml_is_vector(src3)) { + if (src3->ne[0] == 1) { // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop // n_head @@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32( } } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } } } diff --git a/src/llama.cpp b/src/llama.cpp index c11472112f8fb..3e1f8755ffb85 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2801,6 +2801,10 @@ struct llama_kv_cache { // computed before each graph build uint32_t n = 0; + // first zero-ed state + // NOTE: only used by recurrent models + int32_t rs_z = -1; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; @@ -3381,8 +3385,6 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] @@ -3813,6 +3815,15 @@ static bool llama_kv_cache_find_slot( } } + // Find first to-be-cleared cell + cache.rs_z = -1; + for (int i = min; i <= max; ++i) { + if (cache.cells[i].src == -1) { + cache.rs_z = i; + break; + } + } + // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; @@ -9569,36 +9580,42 @@ static struct ggml_tensor * llm_build_kv( return cur; } -static struct ggml_tensor * llm_build_copy_mask_state( +static struct ggml_tensor * llm_build_rs( struct ggml_context * ctx, struct ggml_cgraph * graph, struct ggml_tensor * s, struct ggml_tensor * state_copy, - struct 
ggml_tensor * state_mask, + int32_t rs_zero, int32_t n_state, int32_t kv_size, int32_t kv_head, int32_t n_kv, - int32_t n_seqs) { + int32_t n_seqs, + bool avoid_copies = false) { struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx, states, state_copy); - - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? - states = ggml_mul(ctx, states, state_mask); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, state_zero, 0)); // copy states which won't be changed further (between n_seqs and n_kv) + struct ggml_tensor * states_extra = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), + states_extra, ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); - // the part of the states that will be used and modified - return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_seqs, 0)); + // the part of the states that will be used and modified + states = ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + } + + return states; } // TODO: split @@ 
-9609,7 +9626,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9639,14 +9656,14 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9711,10 +9728,11 @@ static struct ggml_tensor * llm_build_mamba( x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. 
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -9746,7 +9764,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9772,14 +9790,14 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9835,9 +9853,12 @@ static struct ggml_tensor * llm_build_mamba2( // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); + // Use the same shape semantics for A as 
Mamba-1 + struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10069,6 +10090,7 @@ struct llm_build_context { const int32_t n_outputs; const int32_t n_outputs_enc; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_zero; // the first zero-ed recurrent state const int32_t n_ctx_orig; const bool flash_attn; @@ -10119,6 +10141,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 
0 : kv_self.size - n_tokens) : kv_self.head), + rs_zero (kv_self.rs_z), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -10147,8 +10170,6 @@ struct llm_build_context { lctx.inp_mean = nullptr; lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -10332,13 +10353,6 @@ struct llm_build_context { return lctx.inp_s_copy; } - struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - cb(lctx.inp_s_mask, "inp_s_mask", -1); - ggml_set_input(lctx.inp_s_mask); - return lctx.inp_s_mask; - } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; @@ -13901,7 +13915,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -13912,15 +13925,13 @@ struct llm_build_context { switch (version) { case 2: - cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; case 1: default: - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; } @@ -15946,7 +15957,6 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = 
llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); @@ -15955,11 +15965,11 @@ struct llm_build_context { const llama_layer * layer = &model.layers[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, - gf, kv_self.k_l[il], state_copy, state_mask, + struct ggml_tensor * token_shift = llm_build_rs(ctx0, + gf, kv_self.k_l[il], state_copy, rs_zero, hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); - struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, - gf, kv_self.v_l[il], state_copy, state_mask, + struct ggml_tensor * wkv_states = llm_build_rs(ctx0, + gf, kv_self.v_l[il], state_copy, rs_zero, hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); @@ -16329,18 +16339,6 @@ static void llama_set_k_shift(llama_context & lctx) { } } -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; - } -} - static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; @@ -16656,24 +16654,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; - if (lctx.inp_s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear 
once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } - } - } - if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); int32_t * data = (int32_t *) lctx.inp_s_copy->data; @@ -16683,8 +16663,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const uint32_t cell_id = i + kv_self.head; llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + if (kv_cell.src < 0) { + GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source + kv_cell.src = kv_self.rs_z; + } + if ((uint32_t) kv_cell.src >= kv_self.size) { + // ignore out-of-bound sources kv_cell.src = cell_id; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aa7896defdad0..092639eed42e1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case { const int64_t d_state; const int64_t d_inner; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, 
d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + 
init_tensor_uniform(t); + } + } + } }; // GGML_OP_MUL_MAT @@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 #if 1 for (ggml_type type_a : base_types) { From 2c77d799f9387f5971289139aaca23b4ce37c435 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:36:22 -0400 Subject: [PATCH 09/18] metal : attempt to adapt SSM_SCAN for Mamba-2 --- ggml/src/ggml-metal.m | 107 ++++++++++++++++++++-------- ggml/src/ggml-metal.metal | 146 ++++++++++++++++++++++++++++++++------ 2 files changed, 202 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 9da08fe2e9771..5d5b98307d264 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -95,6 +95,7 @@ GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -591,6 +592,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, 
ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); @@ -1629,47 +1631,74 @@ static void ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; + struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); + GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; + size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; + id id_src7 = src7 ? 
ggml_metal_get_buffer(src7, &offs_src7) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); const uint64_t nb30 = src3->nb[0]; const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); const uint64_t nb40 = src4->nb[0]; const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); const uint64_t nb50 = src5->nb[0]; const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; + + const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); + + const uint64_t nb70 = src7->nb[0]; const int64_t d_state = ne00; const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_seqs = ne13; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + if (ne30 == 1) { + // Mamba-2 + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1678,33 +1707,49 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder 
setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; + [encoder 
setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; + [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; + [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; + [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_MUL_MAT: { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b200032394b1..c75fa25c34e7d 100644 --- 
a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -795,7 +795,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part // TODO: optimize kernel void kernel_ssm_scan_f32( device const void * src0, @@ -804,14 +804,19 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, + device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, constant int64_t & n_seq_tokens, constant int64_t & n_seqs, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -824,47 +829,148 @@ kernel void kernel_ssm_scan_f32( constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, + constant uint64_t & nb43, constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq const int64_t nc = d_state; const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; const int64_t n_s = n_seqs; + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device 
char *) dst + ir*nb01 + i3*nb03 + s_off); + for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); - - if (i2 > 0) { - s0 = s; + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? 
log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; } - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device const void * src7, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb43, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const int64_t nc = d_state; + 
const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? 
log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = expf(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } - y[0] = sumf; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; } } From 87b97d08f43652c7a2e73929e34432ae5f9e8713 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:41:10 -0400 Subject: [PATCH 10/18] metal : fix SSM_SCAN pipeline scope --- ggml/src/ggml-metal.m | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5d5b98307d264..477f720a0e32f 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1693,11 +1693,13 @@ static void ggml_metal_encode_node( const int64_t n_seq_tokens = ne11; const int64_t n_seqs = ne13; + id pipeline = nil; + if (ne30 == 1) { // Mamba-2 - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; } else { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; } [encoder setComputePipelineState:pipeline]; From 03d0e6eabe6172a56a7d470bfd844012f2c2b291 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:58:41 -0400 Subject: [PATCH 11/18] metal : use log and exp instead of log1pf and expf in SSM_SCAN --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c75fa25c34e7d..cee9980a75619 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32( device const float * 
D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { const int64_t i = i0 + i1*nc; - const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } @@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - const float dA = expf(dt_soft_plus * A[0]); + const float dA = exp(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { From 7a351abc28e36aeb73d1fd8ce172db56fbb3ebcb Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:28:16 -0400 Subject: [PATCH 12/18] metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. 
--- ggml/src/ggml-metal.m | 53 ++++++++++++++++----------------------- ggml/src/ggml-metal.metal | 34 +++++++++---------------- 2 files changed, 34 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 477f720a0e32f..5127b34f8edaa 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1655,7 +1655,7 @@ static void ggml_metal_encode_node( const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); @@ -1663,7 +1663,7 @@ static void ggml_metal_encode_node( const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; const uint64_t nb43 = src4->nb[3]; @@ -1673,18 +1673,18 @@ static void ggml_metal_encode_node( const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; const uint64_t nb53 = src5->nb[3]; const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); - const uint64_t nb60 = src6->nb[0]; + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - const uint64_t nb70 = src7->nb[0]; + const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); const int64_t d_state = ne00; const int64_t d_inner = ne01; @@ -1718,32 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - 
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; - [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; - [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; - [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) 
atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + // NOTE: max index is 31 if (ne30 == 1) { // Mamba-2 diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index cee9980a75619..3745f2f225512 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,30 +812,21 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -843,12 +834,16 @@ kernel void kernel_ssm_scan_f32( const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const 
int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -864,7 +859,7 @@ kernel void kernel_ssm_scan_f32( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; @@ -901,30 +896,21 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -932,12 +918,16 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq + const uint64_t 
nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -953,7 +943,7 @@ kernel void kernel_ssm_scan_f32_group( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; From 8b15bc6fa0fbb7a0d831b90955430c0a9e281ac2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:47:56 -0400 Subject: [PATCH 13/18] metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. 
--- ggml/src/ggml-metal.m | 33 +++++++++++++++++---------------- ggml/src/ggml-metal.metal | 2 ++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5127b34f8edaa..3f7183060d83d 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1718,22 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; + [encoder 
setBytes:&nb31 length:sizeof(nb31) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3745f2f225512..c36eedb010de1 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,6 +812,7 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, @@ -896,6 +897,7 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, From 5b8ec2b978b84dfdb05e6fca4def928f72b1090c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 12:11:45 -0400 Subject: [PATCH 14/18] metal : fix SSM_SCAN state head offset --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c36eedb010de1..9e1d14ff5d8b5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + 
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} @@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} From 62b09b343c6c4e35486368f1a7b653c9ae58574a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 21:35:50 -0400 Subject: [PATCH 15/18] metal : fix wrong number of tokens per sequence in SSM_SCAN --- ggml/src/ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 3f7183060d83d..a39770bd4ed1b 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1690,7 +1690,7 @@ static void ggml_metal_encode_node( const int64_t d_inner = ne01; const int64_t n_head = ne02; const int64_t n_group = ne41; - const int64_t n_seq_tokens = ne11; + const int64_t n_seq_tokens = ne12; const int64_t n_seqs = ne13; id pipeline = nil; From 805512a73b9876853f0e7d0cd612259806fa5d93 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Oct 2024 16:20:26 -0400 Subject: [PATCH 16/18] ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this 
"optimisation" is no longer necessary, or at least not worth the additional code complexity. --- ggml/src/ggml.c | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e8a5e3d153548..8fd335270dd5a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10173,37 +10173,7 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (ne00 > 1 && ne10 == 1) { - // fast broadcast path - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - const float scale = src1_ptr[0]; - - if (scale == 0.0f) { - // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, - // but it is useful when resetting the state of recurrent models. - memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); - } else { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); - } - if (scale != 1.0f) { - ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); - } - } - } - } else if (nb10 == sizeof(float)) { + if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); From 3bc7103d2ef1c41cd380a1ad8d918cf9c26694d8 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 11:36:37 -0500 Subject: [PATCH 17/18] ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. 
* convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks --- convert_hf_to_gguf.py | 26 ++++++++++++++++- ggml/include/ggml.h | 1 - ggml/src/ggml-metal.m | 57 ++++++++++++++++---------------------- ggml/src/ggml-metal.metal | 14 +++------- ggml/src/ggml.c | 20 ++++--------- src/llama.cpp | 54 +++++++++++++++++++----------------- tests/test-backend-ops.cpp | 25 ++++++++--------- 7 files changed, 100 insertions(+), 97 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f307b1ac69202..f0a63d921d65f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -264,6 +264,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + del new_name, bid # unused + + return data_torch.squeeze() + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused @@ -295,7 +301,7 @@ def prepare_tensors(self): break for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): - data = data_torch.squeeze().numpy() + data = self.reshape_tensors(data_torch, new_name, bid).numpy() # if data ends up empty, it means data_torch was a scalar tensor -> restore if len(data.shape) == 0: @@ -3063,6 +3069,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for 
t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + return data_torch.reshape((*data_torch.shape, 1)) + + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + return data_torch.reshape((n_group, d_inner // n_group)) + + return data_torch.squeeze() + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0d2e5cb011a3b..735f56b005a28 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1828,7 +1828,6 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 73e2fedc36544..902728d8e6b55 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1649,25 +1649,21 @@ static void ggml_metal_encode_node( struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; struct ggml_tensor * src6 = node->src[6]; - struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); GGML_ASSERT(src6); - GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; size_t offs_src6 = 0; - size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - id id_src7 = src7 ? 
ggml_metal_get_buffer(src7, &offs_src7) : nil; const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); @@ -1699,10 +1695,6 @@ static void ggml_metal_encode_node( const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); - const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - - const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); - const int64_t d_state = ne00; const int64_t d_inner = ne01; const int64_t n_head = ne02; @@ -1727,31 +1719,30 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; - [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; - [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 
length:sizeof(nb52) atIndex:28]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:8]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:9]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:10]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:11]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:12]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:13]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2f5a4d12eeec3..05d04e8f3fdbf 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -805,7 +805,6 @@ kernel void kernel_ssm_scan_f32( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -838,7 +837,6 @@ kernel void kernel_ssm_scan_f32( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = 
sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -848,7 +846,7 @@ kernel void kernel_ssm_scan_f32( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -859,7 +857,6 @@ kernel void kernel_ssm_scan_f32( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? 
log(1.0f + exp(dt[0])) : dt[0]; @@ -873,7 +870,7 @@ kernel void kernel_ssm_scan_f32( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; @@ -890,7 +887,6 @@ kernel void kernel_ssm_scan_f32_group( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -923,7 +919,6 @@ kernel void kernel_ssm_scan_f32_group( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -933,7 +928,7 @@ kernel void kernel_ssm_scan_f32_group( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -944,7 +939,6 @@ kernel void kernel_ssm_scan_f32_group( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? 
log(1.0f + exp(dt[0])) : dt[0]; @@ -959,7 +953,7 @@ kernel void kernel_ssm_scan_f32_group( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 91b256a4c25f0..9036fc0be9858 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7181,7 +7181,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -7205,7 +7204,6 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); @@ -7235,8 +7233,6 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); - GGML_ASSERT(D->ne[0] == n_head); - GGML_ASSERT(ggml_is_vector(D)); GGML_ASSERT(ids->ne[0] == n_seqs); GGML_ASSERT(ggml_is_vector(ids)); GGML_ASSERT(A->ne[1] == n_head); @@ -7258,8 +7254,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = D; - result->src[7] = ids; + result->src[6] = ids; return result; } @@ -16217,8 +16212,7 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} - const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} - const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16240,8 +16234,7 @@ static void 
ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - GGML_ASSERT(src6->nb[0] == sizeof(float)); - GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16252,7 +16245,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); - const int32_t * ids = (const int32_t *) src7->data; + const int32_t * ids = (const int32_t *) src6->data; for (int i3 = 0; i3 < ns; ++i3) { const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} @@ -16264,7 +16257,6 @@ static void ggml_compute_forward_ssm_scan_f32( const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} - const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} if (src3->ne[0] == 1) { @@ -16325,7 +16317,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } else { @@ -16353,7 +16345,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } diff --git a/src/llama.cpp b/src/llama.cpp index e84510ce8ffd1..52052caf250b1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7120,6 +7120,7 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_SSM_CONV1D, 
{LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -7227,23 +7228,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a + const int64_t d_state = w->ne[0] == 1 ? 
hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV: { @@ -8572,10 +8577,10 @@ static bool llm_load_tensors( layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {n_head}, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}, 0); + layer.ssm_norm = 
create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); @@ -9994,7 +9999,7 @@ static struct ggml_tensor * llm_build_rs( return states; } -// TODO: split +// TODO: split conv and ssm static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, struct llama_context & lctx, @@ -10102,13 +10107,14 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + cur = x; x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10120,6 +10126,7 @@ static struct ggml_tensor * llm_build_mamba( // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, cur, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -10184,7 +10191,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); // split the above in three - struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * z = ggml_view_4d(ctx, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], 
zxBCdt->nb[2], 0); struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); @@ -10230,11 +10237,9 @@ static struct ggml_tensor * llm_build_mamba2( dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); - // Use the same shape semantics for A as Mamba-1 - struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10242,17 +10247,16 @@ static struct ggml_tensor * llm_build_mamba2( ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + struct ggml_tensor * y = ggml_view_4d(ctx, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = llm_build_norm(ctx, y, hparams, - ggml_reshape_2d(ctx, 
model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, - LLM_NORM_RMS, cb, il); + y = llm_build_norm(ctx, y, hparams, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6ca254a45f23f..95f8abbd80968 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1589,35 +1589,34 @@ struct test_ssm_scan : public test_case { const ggml_type type; const int64_t d_state; - const int64_t d_inner; + const int64_t head_dim; const int64_t n_head; const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, int64_t d_state = 32, - int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t head_dim = 1, // non-zero for Mamba-2 int64_t n_head = 32, int64_t n_group = 1, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); - ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); - ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 
1 : d_state, n_head); - ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); return out; } From b4e9c5998dea2d657cfd22bc2e6fa0630fba2fa9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 15:26:15 -0500 Subject: [PATCH 18/18] convert : fix flake8 lint --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f0efe5d5b0c7c..019e7b7ef93b6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3088,7 +3088,6 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.squeeze() - @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R