fix FA2 when using quantization for remaining models (huggingface#28341)

* fix fa2 autocasting when using quantization

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>
2 people authored and wgifford committed Jan 21, 2024
1 parent 111f6c9 commit d1c9017
Showing 8 changed files with 24 additions and 8 deletions.
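
All eight files receive the same change to their Flash Attention 2 forward pass: when the query states have been upcast to float32, the dtype to cast back to is now resolved in this order: the active autocast dtype if autocast is enabled, then the dtype recorded before quantization if the model is quantized, and finally the dtype of the attention projection weights. Below is a minimal, self-contained sketch of that fallback chain; the helper name resolve_fa2_target_dtype is illustrative and not part of transformers, where the same logic is written inline in each attention class.

import torch
import torch.nn as nn


def resolve_fa2_target_dtype(query_states: torch.Tensor, config, proj: nn.Linear) -> torch.dtype:
    """Pick the dtype to cast query/key/value states to before calling Flash Attention 2."""
    if query_states.dtype != torch.float32:
        # Nothing was upcast; keep the current dtype.
        return query_states.dtype
    if torch.is_autocast_enabled():
        # Under torch.autocast, follow the active autocast dtype.
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        # Quantized models: the projection weights carry a quantized dtype,
        # so fall back to the dtype recorded before quantization.
        return config._pre_quantization_dtype
    # Unquantized models: use the dtype of the projection weights.
    return proj.weight.dtype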
4 changes: 3 additions & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -382,8 +382,10 @@ def forward(
 
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
4 changes: 3 additions & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
@@ -322,8 +322,10 @@ def reshape(x: torch.Tensor) -> torch.Tensor:
         # in fp32. (LlamaRMSNorm handles it correctly)
 
         if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_lin.weight.dtype
4 changes: 3 additions & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -357,8 +357,10 @@ def forward(
         # in fp32. (LlamaRMSNorm handles it correctly)
 
         if query.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
4 changes: 3 additions & 1 deletion src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -384,8 +384,10 @@ def forward(
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         input_dtype = query.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
4 changes: 3 additions & 1 deletion src/transformers/models/mbart/modeling_mbart.py
@@ -372,8 +372,10 @@ def forward(
 
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
4 changes: 3 additions & 1 deletion src/transformers/models/opt/modeling_opt.py
@@ -363,8 +363,10 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
4 changes: 3 additions & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -484,8 +484,10 @@ def forward(
         # in fp32.
 
         if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
4 changes: 3 additions & 1 deletion src/transformers/models/whisper/modeling_whisper.py
@@ -562,8 +562,10 @@ def forward(
 
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
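
With the fix applied, loading one of the touched models with both a quantization config and Flash Attention 2 and running it under autocast should cast back to the autocast dtype instead of unconditionally using the pre-quantization dtype. A hypothetical usage sketch, assuming a CUDA GPU with flash-attn and bitsandbytes installed; the checkpoint name and generation settings are placeholders, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "facebook/opt-350m"  # any architecture touched by this commit behaves the same way
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    attn_implementation="flash_attention_2",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))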
