diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8f13d4aa230085..bf06d9c4053822 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -207,29 +207,21 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil # if no floating dtype was found return whatever the first dtype is return last_dtype - for t in parameter.buffers(): - last_dtype = t.dtype - if t.is_floating_point(): - return t.dtype - - if last_dtype is not None: - # if no floating dtype was found return whatever the first dtype is - return last_dtype - - # For nn.DataParallel compatibility in PyTorch > 1.5 - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - last_tuple = None - for tuple in gen: - last_tuple = tuple - if tuple[1].is_floating_point(): - return tuple[1].dtype + else: + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples - # fallback to the last dtype - return last_tuple[1].dtype + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for tuple in gen: + last_tuple = tuple + if tuple[1].is_floating_point(): + return tuple[1].dtype + + # fallback to the last dtype + return last_tuple[1].dtype def get_state_dict_float_dtype(state_dict):