diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 5fb155775a2f53..3b280930c0cbdf 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -613,15 +613,18 @@ def forward( # cast them back in float16 just to be sure everything works as expected. input_dtype = query_layer.dtype if input_dtype == torch.float32: + # Handle the case where the model is quantized + target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype) + logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to" - " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - " float16." + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." ) - query_layer = query_layer.to(torch.float16) - key_layer = key_layer.to(torch.float16) - value_layer = value_layer.to(torch.float16) + query_layer = query_layer.to(target_dtype) + key_layer = key_layer.to(target_dtype) + value_layer = value_layer.to(target_dtype) attn_output = self._flash_attention_forward( query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b697387f5f5b36..b0767cf4967dca 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -469,20 +469,24 @@ def forward( # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. 
+ # to cast them back to the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype if input_dtype == torch.float32: + # Handle the case where the model is quantized + target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype) + logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to" - " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - " float16." + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." ) - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 28d7b914d6f80e..f636e514823e7d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -408,15 +408,18 @@ def forward( # cast them back in float16 just to be sure everything works as expected. 
input_dtype = query_states.dtype if input_dtype == torch.float32: + # Handle the case where the model is quantized + target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype) + logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to" - " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - " float16." + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." ) - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) # Reashape to the expected shape for Flash Attention query_states = query_states.transpose(1, 2) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 019650a98ef78a..34f5bae3746f03 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -64,6 +64,7 @@ is_pt_flax_cross_test, is_pt_tf_cross_test, require_accelerate, + require_bitsandbytes, require_flash_attn, require_safetensors, require_torch, @@ -2959,6 +2960,45 @@ def test_flash_attn_2_generate_use_cache(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False ) + @require_flash_attn + @require_torch_gpu + @require_bitsandbytes + @mark.flash_attn_test + @slow + def test_flash_attn_2_fp32_ln(self): + import torch + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + 
model.save_pretrained(tmpdirname) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + use_flash_attention_2=True, + low_cpu_mem_usage=True, + load_in_4bit=True, + ) + + for _, param in model.named_parameters(): + # upcast every param still in half precision (not only layer norms) to float32 + if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): + param.data = param.data.to(torch.float32) + + _ = model(input_ids=dummy_input) + + # with attention mask + _ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + global_rng = random.Random()