Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FA-2] Final fix for FA2 dtype #26846

Merged
merged 6 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,15 +613,21 @@ def forward(
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_layer.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.query_key_value.weight.dtype
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_layer = query_layer.to(torch.float16)
key_layer = key_layer.to(torch.float16)
value_layer = value_layer.to(torch.float16)
query_layer = query_layer.to(target_dtype)
key_layer = key_layer.to(target_dtype)
value_layer = value_layer.to(target_dtype)

attn_output = self._flash_attention_forward(
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout
Expand Down
22 changes: 14 additions & 8 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,26 +469,32 @@ def forward(

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if not output_attentions:
Expand Down
18 changes: 12 additions & 6 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,21 @@ def forward(
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
Comment on lines 414 to 418
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove this now no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think we need to keep it to inform users about that


query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

# Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
is_pt_flax_cross_test,
is_pt_tf_cross_test,
require_accelerate,
require_bitsandbytes,
require_flash_attn,
require_safetensors,
require_torch,
Expand Down Expand Up @@ -2955,6 +2956,45 @@ def test_flash_attn_2_generate_use_cache(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
import torch

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return

config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)

model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
low_cpu_mem_usage=True,
load_in_4bit=True,
)

for _, param in model.named_parameters():
# upcast only layer norms
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
param.data = param.data.to(torch.float32)

_ = model(input_ids=dummy_input)

# with attention mask
_ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)


global_rng = random.Random()

Expand Down
Loading