diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 9d3570bd4333c3..62197368383caf 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -734,15 +734,24 @@ def test_compile_static_cache(self):
 @slow
 @require_torch_gpu
 class Mask4DTestHard(unittest.TestCase):
+    model_name = "mistralai/Mistral-7B-v0.1"
+    _model = None
+
     def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    @property
+    def model(self):
+        if self.__class__._model is None:
+            self.__class__._model = MistralForCausalLM.from_pretrained(
+                self.model_name, torch_dtype=self.model_dtype
+            ).to(torch_device)
+        return self.__class__._model
+
     def setUp(self):
-        model_name = "mistralai/Mistral-7B-v0.1"
-        self.model_dtype = torch.float32
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
-        self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+        self.model_dtype = torch.float16
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
 
     def get_test_data(self):
         template = "my favorite {}"