diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index ed985ac5f0ded6..8b7785117eefe2 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -438,7 +438,7 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - softmax_scale = 1 / (query_states.shape[-1] / 2) ** 0.5 + softmax_scale = 1 / math.sqrt(self.head_dim / 2) attn_output = _flash_attention_forward( query_states,