Skip to content

Commit

Permalink
[FA-2] Final fix for FA2 dtype (#26846)
Browse files Browse the repository at this point in the history
* final fix for FA2 dtype

* try

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <[email protected]>

* apply fix everywhere

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
younesbelkada and ArthurZucker authored Oct 18, 2023
1 parent 732d2a8 commit 5a73316
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 19 deletions.
15 changes: 9 additions & 6 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,15 +613,18 @@ def forward(
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_layer.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype)

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_layer = query_layer.to(torch.float16)
key_layer = key_layer.to(torch.float16)
value_layer = value_layer.to(torch.float16)
query_layer = query_layer.to(target_dtype)
key_layer = key_layer.to(target_dtype)
value_layer = value_layer.to(target_dtype)

attn_output = self._flash_attention_forward(
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout
Expand Down
18 changes: 11 additions & 7 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,20 +469,24 @@ def forward(

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
Expand Down
15 changes: 9 additions & 6 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,18 @@ def forward(
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

# Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
is_pt_flax_cross_test,
is_pt_tf_cross_test,
require_accelerate,
require_bitsandbytes,
require_flash_attn,
require_safetensors,
require_torch,
Expand Down Expand Up @@ -2959,6 +2960,45 @@ def test_flash_attn_2_generate_use_cache(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
import torch

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return

config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)

model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
low_cpu_mem_usage=True,
load_in_4bit=True,
)

for _, param in model.named_parameters():
# upcast only layer norms
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
param.data = param.data.to(torch.float32)

_ = model(input_ids=dummy_input)

# with attention mask
_ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)


global_rng = random.Random()

Expand Down

0 comments on commit 5a73316

Please sign in to comment.