Skip to content

Commit

Permalink
Fix GPU OOM for mistral.py::Mask4DTestHard (#31212)
Browse files Browse the repository at this point in the history
* build

* build

* build

* build

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Jun 3, 2024
1 parent df5abae commit 8a1a23a
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tests/models/mistral/test_modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,15 +734,24 @@ def test_compile_static_cache(self):
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
    # Hub checkpoint id shared by every test in this class.
    model_name = "mistralai/Mistral-7B-v0.1"
    # Class-level cache for the loaded model; filled lazily by the `model`
    # property so the 7B weights are loaded at most once per test class.
    _model = None

    def tearDown(self):
        """Reclaim memory after each test to limit GPU OOM across the suite."""
        # Collect Python garbage first so dropped tensor references are freed,
        # then release the now-unused blocks from the CUDA caching allocator.
        gc.collect()
        torch.cuda.empty_cache()

@property
def model(self):
if self.__class__._model is None:
self.__class__._model = MistralForCausalLM.from_pretrained(
self.model_name, torch_dtype=self.model_dtype
).to(torch_device)
return self.__class__._model

def setUp(self):
model_name = "mistralai/Mistral-7B-v0.1"
self.model_dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
self.model_dtype = torch.float16
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)

def get_test_data(self):
template = "my favorite {}"
Expand Down

0 comments on commit 8a1a23a

Please sign in to comment.