From 6529a5b5c13210b41bcd87c555c72696cd7083a5 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:05:23 +0100 Subject: [PATCH] Fix `FastSpeech2ConformerModelTest` and skip it on CPU (#28888) * fix * fix --------- Co-authored-by: ydshieh --- .../fastspeech2_conformer/modeling_fastspeech2_conformer.py | 2 +- .../test_modeling_fastspeech2_conformer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 9b8fa4ab004f29..cc57747c59a4be 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -1256,7 +1256,7 @@ def forward( ) if attention_mask is None: - attention_mask = torch.ones(input_ids.shape) + attention_mask = torch.ones(input_ids.shape, device=input_ids.device) has_missing_labels = ( spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None diff --git a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py index 19b13ac77ebfb4..7a4ec2c723b428 100644 --- a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py +++ b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py @@ -25,7 +25,7 @@ FastSpeech2ConformerWithHifiGanConfig, is_torch_available, ) -from transformers.testing_utils import require_g2p_en, require_torch, slow, torch_device +from transformers.testing_utils import require_g2p_en, require_torch, require_torch_accelerator, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor @@ -117,6 +117,7 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict +@require_torch_accelerator @require_torch class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else ()