diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 60fea55d87be5d..683b683008f2da 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -285,9 +285,9 @@ def forward( shape_q = (*query_states.shape[:-1], -1, self.head_dim) shape_kv = (*key_states.shape[:-1], -1, self.head_dim) - query_states = query_states.reshape(shape_q).transpose(1, 2) - key_states = key_states.reshape(shape_kv).transpose(1, 2) - value_states = value_states.reshape(shape_kv).transpose(1, 2) + query_states = query_states.view(shape_q).transpose(1, 2) + key_states = key_states.view(shape_kv).transpose(1, 2) + value_states = value_states.view(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past