Fix some fa2 tests (#35340)
* remove fa2 test

* remove other failing tests

* style
ArthurZucker authored Dec 19, 2024
1 parent 667ed56 commit 1fa807f
Showing 4 changed files with 0 additions and 119 deletions.
29 changes: 0 additions & 29 deletions tests/models/granite/test_modeling_granite.py
@@ -14,14 +14,12 @@
# limitations under the License.
"""Testing suite for the PyTorch Granite model."""

import tempfile
import unittest

from parameterized import parameterized

from transformers import GraniteConfig, is_torch_available, set_seed
from transformers.testing_utils import (
    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_gpu,
@@ -417,33 +415,6 @@ def test_model_rope_scaling(self):
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_sin_long, original_sin_long)

    @require_flash_attn
    @require_torch_gpu
    @slow
    def test_use_flash_attention_2_true(self):
        """
        NOTE: this is the only test checking that the legacy `use_flash_attention_2=True` argument still works as intended.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            with tempfile.TemporaryDirectory() as tmp_dir:
                model = model_class(config)
                model.save_pretrained(tmp_dir)

                new_model = GraniteForCausalLM.from_pretrained(
                    tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
                ).to("cuda")

                self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")

                has_flash = False
                for name, submodule in new_model.named_modules():
                    if "FlashAttention" in submodule.__class__.__name__:
                        has_flash = True
                        break
                if not has_flash:
                    raise ValueError("The flash model should have flash attention layers")


@require_torch_gpu
class GraniteIntegrationTest(unittest.TestCase):
29 changes: 0 additions & 29 deletions tests/models/granitemoe/test_modeling_granitemoe.py
@@ -14,14 +14,12 @@
# limitations under the License.
"""Testing suite for the PyTorch GraniteMoe model."""

import tempfile
import unittest

from parameterized import parameterized

from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed
from transformers.testing_utils import (
    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_gpu,
@@ -416,33 +414,6 @@ def test_model_rope_scaling(self):
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_sin_long, original_sin_long)

    @require_flash_attn
    @require_torch_gpu
    @slow
    def test_use_flash_attention_2_true(self):
        """
        NOTE: this is the only test checking that the legacy `use_flash_attention_2=True` argument still works as intended.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            with tempfile.TemporaryDirectory() as tmp_dir:
                model = model_class(config)
                model.save_pretrained(tmp_dir)

                new_model = GraniteMoeForCausalLM.from_pretrained(
                    tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
                ).to("cuda")

                self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")

                has_flash = False
                for name, submodule in new_model.named_modules():
                    if "FlashAttention" in submodule.__class__.__name__:
                        has_flash = True
                        break
                if not has_flash:
                    raise ValueError("The flash model should have flash attention layers")


@require_torch_gpu
class GraniteMoeIntegrationTest(unittest.TestCase):
31 changes: 0 additions & 31 deletions tests/models/llama/test_modeling_llama.py
@@ -14,18 +14,15 @@
# limitations under the License.
"""Testing suite for the PyTorch LLaMA model."""

import tempfile
import unittest

import pytest
from packaging import version
from parameterized import parameterized

from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    cleanup,
    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_accelerator,
@@ -543,34 +540,6 @@ def _reinitialize_config(base_config, new_kwargs):
        with self.assertRaises(KeyError):
            config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}})  # missing "factor"

    @require_flash_attn
    @require_torch_gpu
    @slow
    @pytest.mark.flash_attn_test
    def test_use_flash_attention_2_true(self):
        """
        NOTE: this is the only test checking that the legacy `use_flash_attention_2=True` argument still works as intended.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            with tempfile.TemporaryDirectory() as tmp_dir:
                model = model_class(config)
                model.save_pretrained(tmp_dir)

                new_model = LlamaForCausalLM.from_pretrained(
                    tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
                ).to("cuda")

                self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")

                has_flash = False
                for name, submodule in new_model.named_modules():
                    if "FlashAttention" in submodule.__class__.__name__:
                        has_flash = True
                        break
                if not has_flash:
                    raise ValueError("The flash model should have flash attention layers")


@require_torch_gpu
class LlamaIntegrationTest(unittest.TestCase):
30 changes: 0 additions & 30 deletions tests/test_modeling_common.py
@@ -2769,8 +2769,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n
                attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])

            for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
                if isinstance(pt_output, DynamicCache):
                    pt_output = pt_output.to_legacy_cache()
                self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)

        elif isinstance(fx_outputs, jnp.ndarray):
@@ -3612,34 +3610,6 @@ def test_model_is_small(self):
                num_params < 1000000
            ), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_conversion(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
                ).to(torch_device)

                for _, module in model.named_modules():
                    if "FlashAttention" in module.__class__.__name__:
                        return

                self.assertTrue(False, "FlashAttention2 modules not found in model")

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
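For context, the deleted tests all followed the same pattern: save a tiny model, reload it with Flash Attention 2 enabled (via the legacy `use_flash_attention_2=True` flag or the newer `attn_implementation="flash_attention_2"` argument), confirm that `config._attn_implementation` resolved to `"flash_attention_2"`, and scan the module tree for classes whose names contain `FlashAttention`. A minimal standalone sketch of the config-level check is below; the checkpoint name is purely illustrative, a CUDA GPU with the `flash-attn` package installed is assumed, and the module-name scan is omitted because it depends on how a given transformers release names its attention classes.

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any causal LM that supports Flash Attention 2 works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,                # FA2 requires fp16/bf16 weights
    attn_implementation="flash_attention_2",  # modern equivalent of use_flash_attention_2=True
).to("cuda")

# The resolved attention backend is recorded on the model config.
assert model.config._attn_implementation == "flash_attention_2"
```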
