diff --git a/tests/models/rag/test_modeling_rag.py b/tests/models/rag/test_modeling_rag.py index 3e3f7b9c457589..b219d5c74edff0 100644 --- a/tests/models/rag/test_modeling_rag.py +++ b/tests/models/rag/test_modeling_rag.py @@ -33,7 +33,7 @@ require_sentencepiece, require_tokenizers, require_torch, - require_torch_non_multi_gpu, + require_torch_non_multi_accelerator, slow, torch_device, ) @@ -678,7 +678,7 @@ def config_and_inputs(self): @require_retrieval @require_sentencepiece @require_tokenizers -@require_torch_non_multi_gpu +@require_torch_non_multi_accelerator class RagModelIntegrationTests(unittest.TestCase): def tearDown(self): super().tearDown() @@ -1002,7 +1002,7 @@ def test_rag_token_generate_batch(self): torch_device ) - if torch_device == "cuda": + if torch_device != "cpu": rag_token.half() input_dict = tokenizer(