From 0de3ff273f432b871345861ffdccf7a4c3a81953 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 18 Oct 2023 22:17:37 +0000
Subject: [PATCH] revert

---
 src/transformers/models/falcon/modeling_falcon.py   | 5 ++++-
 src/transformers/models/llama/modeling_llama.py     | 5 ++++-
 src/transformers/models/mistral/modeling_mistral.py | 5 ++++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 3b280930c0cbdf..e9dca6df989472 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -614,7 +614,10 @@ def forward(
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
             # Handle the case where the model is quantized
-            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype)
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.query_key_value.weight.dtype
 
             logger.warning_once(
                 f"The input hidden states seems to be silently casted in float32, this might be related to"
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index b0767cf4967dca..b67719ac327162 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -476,7 +476,10 @@ def forward(
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
             # Handle the case where the model is quantized
-            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
 
             logger.warning_once(
                 f"The input hidden states seems to be silently casted in float32, this might be related to"
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index f636e514823e7d..cfef5a427118e9 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -409,7 +409,10 @@ def forward(
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
             # Handle the case where the model is quantized
-            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
 
             logger.warning_once(
                 f"The input hidden states seems to be silently casted in float32, this might be related to"
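
A short, hypothetical sketch (not part of the patch) of why the explicit hasattr branch can behave differently from the one-line getattr form it reverts: Python's getattr evaluates its default argument eagerly, so if the fallback expression (e.g. self.q_proj.weight.dtype) raises on a quantized layer, the whole call fails even when _pre_quantization_dtype is set on the config. The patch itself does not state this motivation, and _FakeConfig and _FakeQuantizedLinear below are illustrative stand-ins, not transformers classes.

    # Illustrative stand-ins only; not transformers APIs.
    class _FakeConfig:
        _pre_quantization_dtype = "float16"  # stands in for a torch.dtype

    class _FakeQuantizedLinear:
        @property
        def weight(self):
            # Simulates a quantized layer whose dense weight is unavailable.
            raise AttributeError("quantized layer exposes no dense .weight")

    config = _FakeConfig()
    q_proj = _FakeQuantizedLinear()

    # getattr evaluates the fallback eagerly, so this raises even though
    # the config attribute exists:
    try:
        target_dtype = getattr(config, "_pre_quantization_dtype", q_proj.weight.dtype)
    except AttributeError as err:
        print("getattr form failed:", err)

    # The explicit branch only evaluates the fallback when it is needed:
    if hasattr(config, "_pre_quantization_dtype"):
        target_dtype = config._pre_quantization_dtype
    else:
        target_dtype = q_proj.weight.dtype
    print("hasattr form:", target_dtype)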