diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
index 9b8fa4ab004f29..cc57747c59a4be 100644
--- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
+++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
@@ -1256,7 +1256,7 @@ def forward(
         )

         if attention_mask is None:
-            attention_mask = torch.ones(input_ids.shape)
+            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

         has_missing_labels = (
             spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
diff --git a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py
index 19b13ac77ebfb4..7a4ec2c723b428 100644
--- a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py
+++ b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py
@@ -25,7 +25,7 @@
     FastSpeech2ConformerWithHifiGanConfig,
     is_torch_available,
 )
-from transformers.testing_utils import require_g2p_en, require_torch, slow, torch_device
+from transformers.testing_utils import require_g2p_en, require_torch, require_torch_accelerator, slow, torch_device

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
@@ -117,6 +117,7 @@ def prepare_config_and_inputs_for_common(self):
         return config, inputs_dict


+@require_torch_accelerator
 @require_torch
 class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else ()