llama : fix Mamba-2 conv state saving
* ggml : make the ggml_mul fast broadcast path more consistently formatted
compilade committed Aug 21, 2024
1 parent 2bfe9de commit aff9692
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ggml/src/ggml.c
@@ -10226,11 +10226,11 @@ static void ggml_compute_forward_mul_f32(
         if (scale == 0.0f) {
             // NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
             //       but it is useful when resetting the state of recurrent models.
-            memset((char *)dst->data + ir*nb1, 0, nb1);
+            memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float));
         } else {
             if (dst->data != src0->data) {
                 // src0 is same shape as dst => same indices
-                memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
+                memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float));
             }
             if (scale != 1.0f) {
                 ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
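
For context, a minimal standalone sketch of the fast broadcast path touched above (illustrative only, not the actual ggml implementation; the helper name mul_row_by_scalar is made up, while ir, ne0, nb1, nb01, and scale mirror the names in the diff). In this path src1 contributes a single broadcast scalar, so the per-row multiplication reduces to a memset, a memcpy, or a vector scale. Zeroing ne0 * sizeof(float) bytes matches what the memcpy branch writes; nb1 is the destination row stride in bytes and can exceed that when dst is a non-contiguous view.

#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Illustrative stand-in for one row of the scalar-broadcast fast path
// in ggml_compute_forward_mul_f32 (names follow the diff above).
static void mul_row_by_scalar(
        char * dst_data, const char * src0_data,
        int64_t ir,               // row index
        int64_t ne0,              // elements per row
        size_t  nb1, size_t nb01, // row strides in bytes of dst and src0
        float   scale) {          // the broadcast scalar taken from src1
    if (scale == 0.0f) {
        // Zero exactly the row's ne0 floats (this also flushes NaNs to zero,
        // which is what resetting recurrent state relies on).
        memset(dst_data + ir*nb1, 0, ne0 * sizeof(float));
    } else {
        if (dst_data != src0_data) {
            // src0 has the same shape as dst => same row indices
            memcpy(dst_data + ir*nb1, src0_data + ir*nb01, ne0 * sizeof(float));
        }
        if (scale != 1.0f) {
            float * row = (float *)(dst_data + ir*nb1);
            for (int64_t i = 0; i < ne0; ++i) {
                row[i] *= scale; // stand-in for ggml_vec_scale_f32
            }
        }
    }
}
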
2 changes: 1 addition & 1 deletion src/llama.cpp
@@ -9335,7 +9335,7 @@ static struct ggml_tensor * llm_build_mamba2(
             ggml_cpy(ctx, last_conv,
                 ggml_view_1d(ctx, conv_states_all,
                     (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+                    kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
 
         // 1D convolution
         // The equivalent is to make a self-overlapping view of conv_x
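
And a small, self-contained sketch of why the offset term in the second hunk matters (the dimension values below are hypothetical, chosen only for illustration). Each sequence's conv state occupies (d_conv - 1)*(d_inner + 2*n_group*d_state) elements of conv_states_all, so the byte offset for slot kv_head has to use that same per-sequence stride; dropping the 2*n_group*d_state term lands the copy short of the intended slot for any kv_head > 0.

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Hypothetical Mamba-2 dimensions, chosen only for illustration.
    const int64_t d_conv  = 4;
    const int64_t d_inner = 2048;
    const int64_t n_group = 1;
    const int64_t d_state = 128;
    const size_t  elt     = sizeof(float); // stand-in for ggml_element_size(conv_states_all)

    // Elements of conv_states_all taken up by one sequence's conv state.
    const int64_t per_seq = (d_conv - 1)*(d_inner + 2*n_group*d_state);

    for (int64_t kv_head = 0; kv_head < 3; ++kv_head) {
        size_t off_old = (size_t)(kv_head*(d_conv - 1)*(d_inner))*elt;                     // before the fix
        size_t off_new = (size_t)(kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state))*elt; // after the fix
        printf("kv_head=%lld: old offset %zu bytes, fixed offset %zu bytes (one state = %lld elements)\n",
               (long long) kv_head, off_old, off_new, (long long) per_seq);
    }
    return 0;
}
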
