Skip to content

Commit

Permalink
Support BatchNorm in Hubert pos_conv_emb as in fairseq (#34389)
Browse files Browse the repository at this point in the history
* Support BatchNorm in Hubert pos_conv_emb as in fairseq

* Correct the new defaults (#34377)

* Correct the new defaults

* CIs

* add check

* Update utils.py

* Update utils.py

* Add the max_length in generate test checking shape without passing length

* style

* CIs

* fix fx CI issue

* [auto. ping] Avoid sending empty info + add more team members (#34383)

* update

* update

---------

Co-authored-by: ydshieh <[email protected]>

* Fix glm  (#34388)

* Fix duplicated

* fix import

* Use non nested images and batched text Idefics2/3  (#34222)

* add support for non nested images and add tests

* add tests error scenario

* fix style

* added single and no image to error tests

* Fix onnx non-expotable inplace aten op (#34376)

* fix onnx non-expotable inplace op

* mistral, qwen2, qwen2_vl, starcoder2

* fixup copies

* Fix right padding in LLaVA models (#34305)

* fix right pad llavas

* device mismatch

* no filter (#34391)

* no filter

* no filter

* no filter

---------

Co-authored-by: ydshieh <[email protected]>

* SynthID: better example (#34372)

* better example

* Update src/transformers/generation/configuration_utils.py

* Update src/transformers/generation/logits_process.py

* nits

* Tests: upgrade `test_eager_matches_sdpa_generate` (#34386)

* Fix bnb training test failure (#34414)

* Fix bnb training test: compatibility with OPTSdpaAttention

* Avoid check expected exception when it is on CUDA (#34408)

* update

* update

---------

Co-authored-by: ydshieh <[email protected]>

* Fix typos in agents_advanced.md (#34405)

* [docs] Cache implementations (#34325)

cache

* [run-slow] hubert

* Support BatchNorm in Hubert pos_conv_emb as in fairseq
Add conversion integration test, and make batchnorm explicit variable

* Support BatchNorm in Hubert pos_conv_emb as in fairseq
fix make fixup styling changes

* [run-slow] hubert

* Support BatchNorm in Hubert pos_conv_emb as in fairseq

* [run-slow] hubert

* Support BatchNorm in Hubert pos_conv_emb as in fairseq
Add conversion integration test, and make batchnorm explicit variable

* Support BatchNorm in Hubert pos_conv_emb as in fairseq
fix make fixup styling changes

* [run-slow] hubert

* [run-slow] hubert

---------

Co-authored-by: Cyril Vallez <[email protected]>
Co-authored-by: Yih-Dar <[email protected]>
Co-authored-by: ydshieh <[email protected]>
Co-authored-by: Yoni Gozlan <[email protected]>
Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: Raushan Turganbay <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
Co-authored-by: Rudy Delouya <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
  • Loading branch information
12 people authored Dec 10, 2024
1 parent 80f2b16 commit 6acb4e4
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 19 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/hubert/configuration_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class HubertConfig(PretrainedConfig):
embeddings layer.
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
Number of groups of 1D convolutional positional embeddings layer.
conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
Whether to use batch norm instead of weight norm in conv_pos
do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
Whether do apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
Expand Down Expand Up @@ -182,6 +184,7 @@ def __init__(
conv_bias=False,
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
conv_pos_batch_norm=False,
do_stable_layer_norm=False,
apply_spec_augment=True,
mask_time_prob=0.05,
Expand Down Expand Up @@ -209,6 +212,7 @@ def __init__(
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.conv_pos_batch_norm = conv_pos_batch_norm
self.num_feat_extract_layers = len(self.conv_dim)
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

MAPPING = {
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"encoder.pos_conv.0": "encoder.pos_conv_embed.batch_norm",
"encoder.pos_conv.1": "encoder.pos_conv_embed.conv",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
Expand Down Expand Up @@ -76,6 +77,12 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
elif weight_type == "running_mean":
hf_pointer.running_mean.data = value
elif weight_type == "running_var":
hf_pointer.running_var.data = value
elif weight_type == "num_batches_tracked":
hf_pointer.num_batches_tracked.data = value
else:
hf_pointer.data = value

Expand Down Expand Up @@ -116,6 +123,12 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
weight_type = "weight"
elif "bias" in name:
weight_type = "bias"
elif "running_mean" in name:
weight_type = "running_mean"
elif "running_var" in name:
weight_type = "running_var"
elif "num_batches_tracked" in name:
weight_type = "num_batches_tracked"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
Expand Down
40 changes: 22 additions & 18 deletions src/transformers/models/hubert/modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def forward(self, hidden_states):
return hidden_states


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
class HubertPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
Expand All @@ -272,32 +271,37 @@ def __init__(self, config):
groups=config.num_conv_pos_embedding_groups,
)

weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
self.batch_norm = None
if config.conv_pos_batch_norm:
self.batch_norm = nn.BatchNorm1d(config.hidden_size)
else:
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

if is_deepspeed_zero3_enabled():
import deepspeed
if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)

self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)

if self.batch_norm is not None:
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
Expand Down
37 changes: 37 additions & 0 deletions tests/models/hubert/test_modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,40 @@ def test_inference_distilhubert(self):
self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)

def test_inference_hubert_25hz(self):
model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)

sample = self._load_datasamples(1)
input_speech = torch.tensor(sample[0], dtype=torch.float, device=torch_device).unsqueeze(0)

with torch.no_grad():
outputs = model(input_speech, output_hidden_states=True).hidden_states[11]

# expected outputs taken from the original textlesslib implementation by:
# model = SpeechEncoder.by_name(dense_model_name='mhubert-base-25hz', quantizer_model_name='kmeans',
# vocab_size=500, deduplicate=False, need_f0=False)
# model(wav)['dense']
expected_outputs_first = torch.tensor(
[
[0.0267, 0.1776, -0.1706, -0.4559],
[-0.2430, -0.2943, -0.1864, -0.1187],
[-0.1812, -0.4239, -0.1916, -0.0858],
[-0.1495, -0.4758, -0.4036, 0.0302],
],
device=torch_device,
)
expected_outputs_last = torch.tensor(
[
[0.3366, -0.2734, -0.1415, -0.3055],
[0.2329, -0.3580, -0.1421, -0.3197],
[0.1631, -0.4301, -0.1965, -0.2956],
[0.3342, -0.2185, -0.2253, -0.2363],
],
device=torch_device,
)
expected_output_sum = 1681.7603

self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)

0 comments on commit 6acb4e4

Please sign in to comment.