From 8d299644297d05701460ef408931bdb43d9d2bdf Mon Sep 17 00:00:00 2001 From: pglorio <85982602+pglorio@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:01:35 -0700 Subject: [PATCH] Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/zamba/modeling_zamba.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 91b50508008f58..e8a10819688d7c 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -196,16 +196,8 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.conv_states += [ torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) ] - self.ssm_states += [ - torch.zeros( - batch_size, - self.n_mamba_heads, - intermediate_size // self.n_mamba_heads, - ssm_state_size, - device=device, - dtype=dtype, - ) - ] + cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) + self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] if self.layers_block_type[i] == "hybrid": self.transformer_layers.append(i)