diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py index 0a6aff8c819922..5613f83a86b4e7 100644 --- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py @@ -94,8 +94,17 @@ def set_recursively(key, value, full_name, weight_type, hf_pointer): hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]] weight_type = "param" + # fairseq uses nn.utils.weight_norm() while transformers switches to nn.utils.parametrizations.weight_norm() + # the mapping between two versions: + # https://github.com/pytorch/pytorch/blob/56935684c3dfad7841c83c719eeebecb560fe466/torch/nn/utils/parametrizations.py#L389-L395 + if weight_type is not None and weight_type != "param": - hf_shape = getattr(hf_pointer, weight_type).shape + if weight_type == "weight_g" and not hasattr(hf_pointer, "weight_g"): + hf_shape = hf_pointer.parametrizations.weight.original0.shape + elif weight_type == "weight_v" and not hasattr(hf_pointer, "weight_v"): + hf_shape = hf_pointer.parametrizations.weight.original1.shape + else: + hf_shape = getattr(hf_pointer, weight_type).shape elif weight_type is not None and weight_type == "param": shape_pointer = hf_pointer for attribute in hf_param_name.split("."): @@ -116,9 +125,15 @@ def set_recursively(key, value, full_name, weight_type, hf_pointer): if weight_type == "weight": hf_pointer.weight.data = value elif weight_type == "weight_g": - hf_pointer.weight_g.data = value + if hasattr(hf_pointer, "weight_g"): + hf_pointer.weight_g.data = value + else: + hf_pointer.parametrizations.weight.original0.data = value elif weight_type == "weight_v": - hf_pointer.weight_v.data = value + if hasattr(hf_pointer, "weight_v"): + hf_pointer.weight_v.data = value + else: + hf_pointer.parametrizations.weight.original1.data = value elif weight_type == "bias": hf_pointer.bias.data = value elif weight_type == "param":