Skip to content

Commit

Permalink
metal : fix SSM_SCAN state head offset
Browse files Browse the repository at this point in the history
  • Loading branch information
compilade committed Oct 2, 2024
1 parent 8b15bc6 commit 5b8ec2b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32(

device const int32_t * ids = (device const int32_t *) src7;

- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
- device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
+ device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);

for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
Expand Down Expand Up @@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group(

device const int32_t * ids = (device const int32_t *) src7;

- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
- device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
+ device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);

for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
Expand Down

0 comments on commit 5b8ec2b

Please sign in to comment.