Commit c946ed3

remove fa2 test
ArthurZucker committed Dec 19, 2024
1 parent 56ff1e9 commit c946ed3
Showing 1 changed file with 0 additions and 30 deletions.
tests/test_modeling_common.py (30 changes: 0 additions & 30 deletions)
@@ -2769,8 +2769,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n
                 attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
 
             for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
-                if isinstance(pt_output, DynamicCache):
-                    pt_output = pt_output.to_legacy_cache()
                 self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
 
         elif isinstance(fx_outputs, jnp.ndarray):
@@ -3610,34 +3608,6 @@ def test_model_is_small(self):
                 num_params < 1000000
             ), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
 
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_conversion(self):
-        if not self.has_attentions:
-            self.skipTest(reason="Model architecture does not support attentions")
-
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in self.all_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
-                ).to(torch_device)
-
-                for _, module in model.named_modules():
-                    if "FlashAttention" in module.__class__.__name__:
-                        return
-
-                self.assertTrue(False, "FlashAttention2 modules not found in model")
-
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
