Commit

tests non model
Cyrilvallez committed Dec 17, 2024
1 parent 98b7f97 commit 88e2fe5
Showing 2 changed files with 1 addition and 41 deletions.
7 changes: 1 addition & 6 deletions src/transformers/modeling_utils.py
@@ -1488,12 +1488,7 @@ def _autoset_attn_implementation(
                     message += (
                         ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
                     )
-                if cls.model_type + "_" + config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
-                    config._attn_implementation_internal = cls.model_type + "_" + config._attn_implementation
-                if config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
-                    pass
-                else:
-                    raise ValueError(message + ".")
+                raise ValueError(message + ".")

             # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
             requested_attn_implementation = config._attn_implementation_internal
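The removed branch used to probe for a model-type-prefixed key in ALL_ATTENTION_FUNCTIONS before giving up; after this change, any `attn_implementation` that is neither "eager" nor a registered attention function fails immediately with the error message assembled above. A minimal sketch of the resulting behaviour, assuming a tiny Mistral checkpoint comparable to the TINY_MISTRAL constant used in the tests below (the exact checkpoint name here is illustrative):

```python
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; the test file below refers to it as TINY_MISTRAL.
checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"

# A registered implementation is accepted and recorded on the config
# ("sdpa" assumes a torch build with scaled_dot_product_attention).
model = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="sdpa")
assert model.config._attn_implementation == "sdpa"

# An unknown value is now rejected up front with the message built above,
# instead of first being probed as a model-type-prefixed attention function.
try:
    AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="foo-bar-baz")
except ValueError as err:
    print(err)  # lists the supported choices ("eager", "sdpa", ...)
```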
35 changes: 0 additions & 35 deletions tests/utils/test_modeling_utils.py
@@ -563,32 +563,17 @@ def test_model_from_pretrained_attn_implementation(self):
         if is_flash_attn_2_available():
             attn_implementation_available.append("flash_attention_2")

-        mistral_attention_classes = {
-            "eager": "MistralAttention",
-            "sdpa": "MistralSdpaAttention",
-            "flash_attention_2": "MistralFlashAttention2",
-        }
         for requested_attn_implementation in attn_implementation_available:
             model = AutoModelForCausalLM.from_pretrained(
                 TINY_MISTRAL, attn_implementation=requested_attn_implementation
             )
             self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
-            for module in model.modules():
-                if "Attention" in module.__class__.__name__:
-                    self.assertEqual(
-                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
-                    )

             config = AutoConfig.from_pretrained(TINY_MISTRAL)
             model = AutoModelForCausalLM.from_pretrained(
                 TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation
             )
             self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
-            for module in model.modules():
-                if "Attention" in module.__class__.__name__:
-                    self.assertEqual(
-                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
-                    )

     def test_model_from_config_attn_implementation(self):
         # test that the model can be instantiated with attn_implementation of either
@@ -602,47 +587,27 @@ def test_model_from_config_attn_implementation(self):
         if is_flash_attn_2_available():
             attn_implementation_available.append("flash_attention_2")

-        mistral_attention_classes = {
-            "eager": "MistralAttention",
-            "sdpa": "MistralSdpaAttention",
-            "flash_attention_2": "MistralFlashAttention2",
-        }
         for requested_attn_implementation in attn_implementation_available:
             config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
             # Ensure the config was set correctly
             self.assertEqual(config._attn_implementation, requested_attn_implementation)
             self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
             model = AutoModelForCausalLM.from_config(config)
             self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
-            for module in model.modules():
-                if "Attention" in module.__class__.__name__:
-                    self.assertEqual(
-                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
-                    )

             config = AutoConfig.from_pretrained(TINY_MISTRAL)
             # When the config is not set, the default is "eager"
             self.assertEqual(config._attn_implementation, "eager")
             self.assertEqual(config._attn_implementation_internal, None)
             model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
             self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
-            for module in model.modules():
-                if "Attention" in module.__class__.__name__:
-                    self.assertEqual(
-                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
-                    )

             # Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
             config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
             self.assertEqual(config._attn_implementation, "foo-bar-baz")
             self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
             model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
             self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
-            for module in model.modules():
-                if "Attention" in module.__class__.__name__:
-                    self.assertEqual(
-                        module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
-                    )

     def test_torch_dtype_byte_sizes(self):
         torch_dtypes_and_bytes = [
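Both tests keep the assertion on model.config._attn_implementation and drop the per-module class-name checks: the deleted mistral_attention_classes mapping (MistralAttention / MistralSdpaAttention / MistralFlashAttention2) suggests the attention class name no longer varies with the chosen implementation once dispatch goes through ALL_ATTENTION_FUNCTIONS. A rough sketch of what the surviving check exercises, again with an illustrative tiny Mistral checkpoint standing in for TINY_MISTRAL:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative checkpoint; the test file refers to it as TINY_MISTRAL.
checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"

config = AutoConfig.from_pretrained(checkpoint)
# The explicit argument overrides whatever the config carries; "sdpa" assumes
# a torch build with scaled_dot_product_attention.
model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")

# This is the assertion the tests keep: the choice is recorded on the config.
assert model.config._attn_implementation == "sdpa"

# The deleted assertions walked the modules looking for implementation-specific
# class names; after the refactor this set is presumably the same regardless of
# the requested implementation.
attention_classes = {m.__class__.__name__ for m in model.modules() if "Attention" in m.__class__.__name__}
print(attention_classes)
```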
