From c99f25476312521d4425335f970b198da42f832d Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 7 Dec 2023 13:34:43 +0100
Subject: [PATCH] Fix device of masks in tests (#27887)

Create the attention masks built in the Llama, Mistral and Persimmon test
inputs on torch_device instead of leaving them on the CPU default, so the
tests do not mix devices when running on a GPU.
---
 tests/models/llama/test_modeling_llama.py         | 2 +-
 tests/models/mistral/test_modeling_mistral.py     | 2 +-
 tests/models/persimmon/test_modeling_persimmon.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 1de49bf19eebb1..26aecaeb1ad9b7 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -104,7 +104,7 @@ def prepare_config_and_inputs(self):
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
+            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
 
         token_type_ids = None
         if self.use_token_type_ids:
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index e59a1726b0735d..fcb1f2495aab8a 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -107,7 +107,7 @@ def prepare_config_and_inputs(self):
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
+            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
 
         token_type_ids = None
         if self.use_token_type_ids:
diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py
index e25a31b3ff4f0d..864db992772772 100644
--- a/tests/models/persimmon/test_modeling_persimmon.py
+++ b/tests/models/persimmon/test_modeling_persimmon.py
@@ -104,7 +104,7 @@ def prepare_config_and_inputs(self):
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
+            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
 
         token_type_ids = None
         if self.use_token_type_ids:
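
For context, below is a minimal standalone sketch of the failure mode these
three one-line changes avoid. The "device" variable stands in for the
torch_device helper from transformers.testing_utils, which resolves to "cuda"
on GPU CI runners and "cpu" elsewhere; the tensor shapes and the final
multiplication are illustrative stand-ins, not the tests' actual code.

    import torch

    # Stand-in for transformers.testing_utils.torch_device (an assumption
    # for this sketch): "cuda" on a GPU machine, "cpu" otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    batch_size, seq_length = 2, 8
    input_ids = torch.randint(0, 100, (batch_size, seq_length), device=device)

    # Before the patch: torch.ones() defaults to CPU, so the mask lands on a
    # different device than the model inputs whenever device == "cuda".
    input_mask = torch.tril(torch.ones(batch_size, seq_length))

    # Combining the two on a GPU runner would then raise a RuntimeError along
    # the lines of "Expected all tensors to be on the same device":
    # masked = input_ids * input_mask.long()  # fails when device == "cuda"

    # After the patch: move the mask to the test device, as the diff does.
    input_mask = input_mask.to(device)
    masked = input_ids * input_mask.long()  # both operands now share a device

The sketch only fails on a CUDA machine; on CPU-only setups both tensors land
on the same device either way, which is why the mismatch surfaced on GPU test
runs rather than in local CPU runs.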