fix onnx export of speech foundation models (#34224)
* added expanded attention/padding masks prior to indexing the hidden_states

* consistency fix in WavLMForSequenceClassification

---------

Co-authored-by: Nikos Antoniou <[email protected]>
nikosanto13 and Nikos Antoniou authored Dec 20, 2024
1 parent f42084e commit ff9141b
Showing 10 changed files with 31 additions and 17 deletions.
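
For context: every hunk below applies the same change. Instead of indexing a 3-D hidden_states tensor with a 2-D boolean attention/padding mask, the mask is first expanded along the hidden dimension so the boolean index has the same rank as the tensor it writes into. A minimal standalone sketch of the pattern (the shapes and values here are made up for illustration and are not taken from the repository):

import torch

# Illustrative shapes: batch of 2 sequences, 5 feature frames, hidden size 4.
hidden_states = torch.randn(2, 5, 4)
padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]], dtype=torch.bool)  # (batch, frames)

# Old pattern: hidden_states[~padding_mask] = 0.0  (2-D mask indexing a 3-D tensor)

# New pattern from this commit: expand the mask to (batch, frames, hidden_size) first.
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_padding_mask] = 0.0

# Mean-pool over the non-padded frames, as the *ForSequenceClassification heads do.
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
print(pooled_output.shape)  # torch.Size([2, 4])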
3 changes: 2 additions & 1 deletion src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -1421,7 +1421,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
3 changes: 2 additions & 1 deletion src/transformers/models/hubert/modeling_hubert.py
@@ -1629,7 +1629,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
9 changes: 5 additions & 4 deletions src/transformers/models/sew/modeling_sew.py
@@ -882,15 +882,15 @@ def forward(
         all_self_attentions = () if output_attentions else None
 
         if attention_mask is not None:
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
             if self._use_flash_attention_2:
                 # make sure padded tokens output 0
-                hidden_states[~attention_mask] = 0.0
+                hidden_states[~expand_attention_mask] = 0.0
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
             else:
                 # make sure padded tokens output 0
-                hidden_states[~attention_mask] = 0.0
-
+                hidden_states[~expand_attention_mask] = 0.0
                 input_lengths = (attention_mask.long()).sum(-1)
                 # apply pooling formula to get real output_lengths
                 output_lengths = input_lengths // self.config.squeeze_factor
@@ -1473,7 +1473,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
6 changes: 4 additions & 2 deletions src/transformers/models/sew_d/modeling_sew_d.py
@@ -1175,7 +1175,8 @@ def forward(
             )
         else:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask.bool()] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask.bool()] = 0.0
 
             input_lengths = (attention_mask.long()).sum(-1)
             # apply pooling formula to get real output_lengths
@@ -1721,7 +1722,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
3 changes: 2 additions & 1 deletion src/transformers/models/unispeech/modeling_unispeech.py
@@ -1876,7 +1876,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
@@ -1886,7 +1886,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
3 changes: 2 additions & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -2376,7 +2376,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
@@ -1359,7 +1359,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
@@ -878,7 +878,8 @@ def forward(
 
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0.0
 
             # extend attention_mask
             attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
@@ -1791,7 +1792,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
9 changes: 6 additions & 3 deletions src/transformers/models/wavlm/modeling_wavlm.py
@@ -691,7 +691,8 @@ def forward(
 
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
 
             position_embeddings = self.pos_conv_embed(hidden_states)
             hidden_states = hidden_states + position_embeddings
@@ -776,7 +777,8 @@ def forward(
 
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
 
             position_embeddings = self.pos_conv_embed(hidden_states)
             hidden_states = hidden_states + position_embeddings
@@ -1508,7 +1510,8 @@ def forward(
             pooled_output = hidden_states.mean(dim=1)
         else:
             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
-            hidden_states[~padding_mask] = 0.0
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
 
         logits = self.classifier(pooled_output)
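
Since the commit title targets ONNX export, one way to sanity-check the expanded-mask pattern outside the full models is to export a toy module that uses it. Everything below is an assumption for illustration (the module, output path, and opset are not part of this commit), and whether the in-place boolean assignment traces cleanly still depends on your torch/opset versions:

import torch
from torch import nn


class MaskedMeanPool(nn.Module):
    # Toy stand-in (not a transformers class) that mimics the pooling pattern above.
    def forward(self, hidden_states: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states.clone()  # avoid mutating the export input in place
        expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
        hidden_states[~expand_padding_mask] = 0.0
        return hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)


dummy_hidden = torch.randn(2, 5, 4)
dummy_mask = torch.ones(2, 5, dtype=torch.bool)

torch.onnx.export(
    MaskedMeanPool(),
    (dummy_hidden, dummy_mask),
    "masked_mean_pool.onnx",  # arbitrary output path
    input_names=["hidden_states", "padding_mask"],
    output_names=["pooled_output"],
    dynamic_axes={"hidden_states": {0: "batch", 1: "frames"},
                  "padding_mask": {0: "batch", 1: "frames"}},
    opset_version=14,  # arbitrary, reasonably recent opset
)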
