From 010566e2f600c2eace5290c0542112ad856f68a8 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Tue, 26 Dec 2023 08:36:41 +0530
Subject: [PATCH] fix FA2 when using quantization (#28203)

---
 src/transformers/models/falcon/modeling_falcon.py           | 6 +++---
 src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 6 +++---
 src/transformers/models/llama/modeling_llama.py             | 6 +++---
 src/transformers/models/mistral/modeling_mistral.py         | 6 +++---
 src/transformers/models/mixtral/modeling_mixtral.py         | 6 +++---
 5 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index a92650ec219933..7c2c63f5c5f5e2 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -617,11 +617,11 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.query_key_value.weight.dtype

diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 468916c28cb935..71f709e3e15d57 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -375,11 +375,11 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.c_attn.weight.dtype

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 2cbe3bc9592131..5f54fea8c40a9d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -528,11 +528,11 @@ def forward(

         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype

diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index a83c31bce47d42..c8b5c211777472 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -428,11 +428,11 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 0f6985c30855ea..a32b9b1457a5b9 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -477,11 +477,11 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype
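Note: all five hunks apply the same reordering: when the attention inputs arrive in float32, the autocast dtype is now consulted before the `_pre_quantization_dtype` stored on the config, and the projection weight dtype remains the last resort. Below is a minimal standalone sketch of that resolution order; the helper name `resolve_fa2_target_dtype` and the `proj_weight` parameter are illustrative only and are not part of this patch.

import torch


def resolve_fa2_target_dtype(config, proj_weight: torch.Tensor) -> torch.dtype:
    """Pick the dtype that float32 hidden states are cast to before Flash Attention 2.

    Illustrative helper mirroring the ordering introduced by this patch; not part of the patch itself.
    """
    # Autocast is checked first: under mixed-precision autocast, its dtype is the
    # one the attention computation will actually run in.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized: use the dtype the weights had
    # before quantization, which is stored on the config as _pre_quantization_dtype.
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    # Otherwise, fall back to the dtype of the attention projection weights.
    return proj_weight.dtype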