even more decorators
poedator committed Dec 13, 2023
1 parent 9ec1083 commit 561ba9a
Showing 1 changed file with 31 additions and 14 deletions.
tests/test_modeling_utils.py: 31 additions & 14 deletions
@@ -1858,21 +1858,14 @@ def test_not_available_sdpa(self):
@require_torch_gpu
@slow
class Mask4DTestBase(unittest.TestCase):
    model_dtype = None

    def setUp(self):
        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

    @require_torch
    @require_torch_gpu
    def tearDown(self):
        r"""
        tearDown needs to run at the end of each test to free the GPU memory and cache, and to
        avoid unexpected behaviors. See: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        gc.collect()
        torch.cuda.empty_cache()

    @require_torch
    @require_torch_gpu
    def get_test_data(self):
        texts = ["the cat sat", "the cat had", "the cat is"]
        encoded = [self.tokenizer.encode(t) for t in texts]
@@ -1913,8 +1906,17 @@ def get_test_data(self):
@require_torch_gpu
@slow
class Mask4DTestFP32(Mask4DTestBase):
    model_dtype = torch.float32
    @require_torch
    @require_torch_gpu
    def setUp(self):
        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)

    @require_torch
    @require_torch_gpu
    @slow
    def test_attention(self):
        """comparing outputs of attention layer"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
@@ -1933,6 +1935,9 @@ def test_attention(self):
        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)

    @require_torch
    @require_torch_gpu
    @slow
    def test_inner_model(self):
        """comparing hidden outputs of whole inner model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
@@ -1947,6 +1952,9 @@ def test_inner_model(self):
            logits_1_last_tokens,
        )

    @require_torch
    @require_torch_gpu
    @slow
    def test_causal_model_logits(self):
        """comparing logits outputs of the whole causal model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
@@ -1966,10 +1974,19 @@ def test_causal_model_logits(self):
@require_torch_gpu
@slow
class Mask4DTestFP16(Mask4DTestBase):
    model_dtype = torch.float16

    test_attention = Mask4DTestFP32.test_attention

    @require_torch
    @require_torch_gpu
    def setUp(self):
        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        model_dtype = torch.float16
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)

    @require_torch
    @require_torch_gpu
    @slow
    def test_causal_model_logits(self):
        """comparing logits outputs of the whole causal model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
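
The @require_torch, @require_torch_gpu, and @slow decorators added throughout this commit come from transformers.testing_utils: each wraps a test in an environment-gated skip. A minimal sketch of that gating, assuming the usual RUN_SLOW convention (this is not the exact library source):

import os
import unittest

import torch


def slow(test_case):
    """Skip the decorated test unless RUN_SLOW is set to a truthy value."""
    run_slow = os.environ.get("RUN_SLOW", "0").lower() in ("1", "yes", "true")
    return unittest.skipUnless(run_slow, "test is slow")(test_case)


def require_torch_gpu(test_case):
    """Skip the decorated test unless torch can see a CUDA device."""
    return unittest.skipUnless(torch.cuda.is_available(), "test requires CUDA")(test_case)

With gating like this in place, the Mask4D suite runs only on a CUDA machine with slow tests enabled, e.g. RUN_SLOW=1 pytest tests/test_modeling_utils.py -k Mask4D.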
