Skip to content

Commit

Permalink
Update src/transformers/models/zamba/modeling_zamba.py
Browse files Browse the repository at this point in the history
Co-authored-by: Arthur <[email protected]>
  • Loading branch information
pglorio and ArthurZucker authored Sep 27, 2024
1 parent 4fcd130 commit 8d29964
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions src/transformers/models/zamba/modeling_zamba.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -196,16 +196,8 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.conv_states += [
torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
]
- self.ssm_states += [
- torch.zeros(
- batch_size,
- self.n_mamba_heads,
- intermediate_size // self.n_mamba_heads,
- ssm_state_size,
- device=device,
- dtype=dtype,
- )
- ]
+ cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size)
+ self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)]
if self.layers_block_type[i] == "hybrid":
self.transformer_layers.append(i)

Expand Down

0 comments on commit 8d29964

Please sign in to comment.