diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index cc57d7bc41846f..68912432247bfa 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -490,8 +490,8 @@ def forward(
         value_states = self.v_proj(hidden_states)
 
         # Flash attention requires the input to have the shape
-        # batch_size x seq_length x num_heads x head_dim
-        # but rotary embeddings require batch_size x num_heads x seq_length x head_dim
+        # batch_size, seq_length, num_heads, head_dim
+        # but rotary embeddings require batch_size, num_heads, seq_length, head_dim
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
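
For reference, a minimal standalone sketch of the shape handling the updated comments describe: the projection output is viewed as (batch_size, seq_length, num_heads, head_dim) and then transposed to (batch_size, num_heads, seq_length, head_dim) for the rotary embeddings. The dimension values below are illustrative placeholders, not taken from the patch.

import torch

# Illustrative dummy dimensions (assumptions for this sketch, not from the patch)
bsz, q_len, num_heads, head_dim = 2, 16, 8, 64

# Projection output arrives flattened as (bsz, q_len, num_heads * head_dim)
query_states = torch.randn(bsz, q_len, num_heads * head_dim)

# View to (bsz, q_len, num_heads, head_dim), the layout flash attention expects,
# then transpose to (bsz, num_heads, q_len, head_dim) for the rotary embeddings
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
assert query_states.shape == (bsz, num_heads, q_len, head_dim)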