From 74cae670ce542b62c44a5603f0675ff31932793c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 15 Dec 2023 09:45:31 -0500 Subject: [PATCH] Make GPT2 traceable in meta state (#28054) * Put device in tensor constructor instead of to() * Fix copy --- .../decision_transformer/modeling_decision_transformer.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index d07a25c8915877..fdfb5b37d22e62 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -185,7 +185,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) if attention_mask is not None: diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index bc95c774039ffc..494aecaeabe1e3 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -198,7 +198,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) if attention_mask is not None: