diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index efa7b3d9ec974f..ed985ac5f0ded6 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -1002,9 +1002,9 @@ def forward(
 
 
 class HybridLayer(nn.Module):
-    def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
+    def __init__(self, shared_transformer: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
         super().__init__()
-        self.shared_transf = shared_transf
+        self.shared_transformer = shared_transformer
         self.linear = linear
         self.mamba = mamba
 
@@ -1040,7 +1040,7 @@ def forward(
             Indices depicting the position of the input sequence tokens in the sequence.
         """
 
-        layer_outputs = self.shared_transf(
+        layer_outputs = self.shared_transformer(
            hidden_states,
            original_hidden_states=original_hidden_states,
            layer_idx=layer_idx,
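
The diff above is a pure rename (`shared_transf` -> `shared_transformer`). For context, here is a minimal, self-contained sketch of the composition pattern `HybridLayer` expresses after the rename: a shared transformer block whose output is projected by a per-layer linear and handed to a mamba block. Only the `__init__` signature and the `self.shared_transformer(...)` call are taken from the diff; the `Toy*` stub modules, the `transformer_hidden_states` keyword, and the wiring of the linear projection are simplifying assumptions for illustration, not the actual `ZambaAttentionDecoderLayer` / `ZambaMambaDecoderLayer` implementations.

```python
# Simplified sketch of the HybridLayer composition pattern. The Toy* classes are
# hypothetical stand-ins, not the real Zamba decoder layers.
import torch
from torch import nn


class ToySharedTransformerBlock(nn.Module):
    """Hypothetical stand-in for ZambaAttentionDecoderLayer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, original_hidden_states=None, layer_idx=None):
        # The real layer also consumes attention masks, caches, etc.
        return (self.proj(hidden_states),)


class ToyMambaBlock(nn.Module):
    """Hypothetical stand-in for ZambaMambaDecoderLayer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mix = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, transformer_hidden_states=None, layer_idx=None):
        # Assumed behavior: fold the projected transformer output back into the stream.
        if transformer_hidden_states is not None:
            hidden_states = hidden_states + transformer_hidden_states
        return (self.mix(hidden_states),)


class ToyHybridLayer(nn.Module):
    """Simplified mirror of HybridLayer after the rename to `shared_transformer`."""

    def __init__(self, shared_transformer: nn.Module, linear: nn.Linear, mamba: nn.Module):
        super().__init__()
        self.shared_transformer = shared_transformer
        self.linear = linear
        self.mamba = mamba

    def forward(self, hidden_states, original_hidden_states=None, layer_idx=None):
        # Run the shared transformer block, project its output, then feed the
        # projection into the mamba block alongside the running hidden states.
        transformer_out = self.shared_transformer(
            hidden_states,
            original_hidden_states=original_hidden_states,
            layer_idx=layer_idx,
        )[0]
        transformer_out = self.linear(transformer_out)
        return self.mamba(
            hidden_states,
            transformer_hidden_states=transformer_out,
            layer_idx=layer_idx,
        )


if __name__ == "__main__":
    hidden_size = 16
    layer = ToyHybridLayer(
        ToySharedTransformerBlock(hidden_size),
        nn.Linear(hidden_size, hidden_size),
        ToyMambaBlock(hidden_size),
    )
    out = layer(torch.randn(2, 8, hidden_size), layer_idx=0)
    print(out[0].shape)  # torch.Size([2, 8, 16])
```

Because the attribute name changes, any checkpoints or code that address submodules by the old `shared_transf` path would need the new `shared_transformer` name; the rename itself does not change behavior.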