From 8a1a23ae4dfb022a2e483506db973ceed41f5fac Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 3 Jun 2024 19:25:15 +0200 Subject: [PATCH] Fix GPU OOM for `mistral.py::Mask4DTestHard` (#31212) * build * build * build * build --------- Co-authored-by: ydshieh --- tests/models/mistral/test_modeling_mistral.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 9d3570bd4333c3..62197368383caf 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -734,15 +734,24 @@ def test_compile_static_cache(self): @slow @require_torch_gpu class Mask4DTestHard(unittest.TestCase): + model_name = "mistralai/Mistral-7B-v0.1" + _model = None + def tearDown(self): gc.collect() torch.cuda.empty_cache() + @property + def model(self): + if self.__class__._model is None: + self.__class__._model = MistralForCausalLM.from_pretrained( + self.model_name, torch_dtype=self.model_dtype + ).to(torch_device) + return self.__class__._model + def setUp(self): - model_name = "mistralai/Mistral-7B-v0.1" - self.model_dtype = torch.float32 - self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) - self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + self.model_dtype = torch.float16 + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False) def get_test_data(self): template = "my favorite {}"