From 21ffe37ec42c85a589596b25c00adb9d9a0d92ba Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Wed, 18 Oct 2023 19:16:16 +0200
Subject: [PATCH] Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/models/falcon/modeling_falcon.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index a834131a27caaf..7e7c8d24bd6f4d 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -614,10 +614,7 @@ def forward(
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.query_key_value.weight.dtype
+            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype)
             logger.warning_once(
                 f"The input hidden states seems to be silently casted in float32, this might be related to"
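
Note (editor's sketch, not part of the patch): the one-line replacement above assumes that `getattr` with a default behaves exactly like the removed `hasattr`/if-else fallback. The minimal Python illustration below checks that equivalence; `DummyConfig`, `resolve_target_dtype`, and `weight_dtype` are hypothetical stand-ins for `self.config` and `self.query_key_value.weight.dtype`, not names from the patched file.

    import torch


    class DummyConfig:
        """Hypothetical stand-in for the model config; may or may not carry
        a `_pre_quantization_dtype` attribute, as with quantized models."""
        pass


    def resolve_target_dtype(config, weight_dtype):
        # Old form removed by the patch:
        if hasattr(config, "_pre_quantization_dtype"):
            old = config._pre_quantization_dtype
        else:
            old = weight_dtype
        # New form added by the patch; `getattr` returns the default when
        # the attribute is absent, so both branches collapse into one call.
        new = getattr(config, "_pre_quantization_dtype", weight_dtype)
        assert old == new
        return new


    config = DummyConfig()
    print(resolve_target_dtype(config, torch.float16))   # falls back to the weight dtype
    config._pre_quantization_dtype = torch.bfloat16
    print(resolve_target_dtype(config, torch.float16))   # uses the pre-quantization dtype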