From 816f4424964c1a1631e303b663fc3d68f731e923 Mon Sep 17 00:00:00 2001
From: Matthew Hoffman
Date: Fri, 18 Oct 2024 09:15:26 -0700
Subject: [PATCH] Only cast logits to float when computing loss (#34147)

* Only cast logits to float when computing loss

Some misses from #31292 and #33902

* Move logits.float() into existing if labels is not None branch
---
 src/transformers/models/chameleon/modeling_chameleon.py   | 3 ++-
 src/transformers/models/granite/modeling_granite.py       | 3 ++-
 src/transformers/models/granitemoe/modeling_granitemoe.py | 3 ++-
 src/transformers/models/idefics3/modeling_idefics3.py     | 3 ++-
 src/transformers/models/paligemma/modeling_paligemma.py   | 3 ++-
 src/transformers/models/phimoe/modeling_phimoe.py         | 8 +-------
 src/transformers/models/qwen2_vl/modeling_qwen2_vl.py     | 3 ++-
 .../models/recurrent_gemma/modeling_recurrent_gemma.py    | 3 ++-
 src/transformers/models/zamba/modeling_zamba.py           | 8 +-------
 9 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index d0b964a7a6f484..797908277930cf 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1602,7 +1602,6 @@ def forward(
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
         # Disallow image tokens which does not include special begin-image and end-image tokens
         image_tokens = self.model.vocabulary_mapping.image_tokens
@@ -1610,6 +1609,8 @@ def forward(
 
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index bb8c157df30c89..50c5b538af306c 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -1101,10 +1101,11 @@ def forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits / self.config.logits_scaling
-        logits = logits.float()
 
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index f3e2d67734a703..07b42822621a3e 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -1345,10 +1345,11 @@ def forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits / self.config.logits_scaling
-        logits = logits.float()
 
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index fb9f0a7c58fa5a..748eda8c026377 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -1210,10 +1210,11 @@ def forward(
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             labels = labels.to(logits.device)
             # Shift so that tokens < n predict n
             if attention_mask is not None:
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 0eb2d50e0ad4c4..1607261eaac673 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -526,9 +526,10 @@ def forward(
         )
 
         logits = outputs.logits
-        logits = logits.float()
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             shift_logits = logits[..., :-1, :]
             shift_labels = labels[..., 1:]
             if attention_mask is not None:
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 9e24e59c64c2fe..e96eae799cda88 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -38,7 +38,6 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1463,13 +1462,8 @@ def forward(
         )
 
         hidden_states = outputs[0]
-        if labels is None and not is_torchdynamo_compiling():
-            logger.warning_once(
-                "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
-            )
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        # TODO: remove the float() operation in v4.46
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
 
         loss = None
         if labels is not None:
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 5464b40546498a..f4cb84a2444eb6 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1760,10 +1760,11 @@ def forward(
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 17744188d40178..d3164b17fe130c 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -870,9 +870,10 @@ def forward(
         cap = self.config.logits_soft_cap
         logits = nn.functional.tanh(logits / cap) * cap
 
-        logits = logits.float()
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 8a61c15e30a0a9..921d07f287dca5 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -51,7 +51,6 @@
 from ...utils.import_utils import (
     is_causal_conv1d_available,
     is_mamba_ssm_available,
-    is_torchdynamo_compiling,
 )
 
 from .configuration_zamba import ZambaConfig
@@ -1473,13 +1472,8 @@ def forward(
        )
 
         hidden_states = outputs[0]
-        if labels is None and not is_torchdynamo_compiling():
-            logger.warning_once(
-                "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
-            )
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        # TODO: remove the float() operation in v4.46
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
 
         loss = None
         if labels is not None:
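
For reference, a minimal sketch of the pattern these diffs converge on, using a hypothetical TinyCausalLM module (illustrative names only, not an actual transformers class): the lm_head output stays in the model's dtype, and the upcast to float32 happens only inside the `labels is not None` branch where the cross-entropy loss is computed.

from typing import Optional

import torch
import torch.nn as nn


class TinyCausalLM(nn.Module):
    """Toy stand-in for a causal LM; module and argument names are illustrative only."""

    def __init__(self, hidden_size: int = 16, vocab_size: int = 32):
        super().__init__()
        self.backbone = nn.Linear(hidden_size, hidden_size)  # stand-in for the decoder stack
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, inputs_embeds: torch.Tensor, labels: Optional[torch.Tensor] = None):
        hidden_states = self.backbone(inputs_embeds)
        # Keep logits in the model's dtype (e.g. bf16/fp16); no unconditional .float() here.
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float only when computing the loss, to avoid precision issues in the
            # cross-entropy without paying the fp32 memory cost at inference time.
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        return loss, logits

At inference the returned logits keep the model's dtype, which is the behavior announced by the v4.46 warning removed in the phimoe and zamba hunks above.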