Fix attention patch compatibility for llama2 (#705)
irenedea authored Oct 30, 2023
1 parent db9227a commit e40689f
Showing 2 changed files with 54 additions and 0 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -78,6 +78,8 @@ def llama_attention_patch_torch(
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_cache:
        raise NotImplementedError(
@@ -186,6 +188,8 @@ def llama_attention_patch_triton(
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_cache:
        raise NotImplementedError(
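For context, Hugging Face transformers 4.34.x passes an extra `padding_mask` keyword into `LlamaAttention.forward`, so both patch functions must accept (and can simply ignore) it. Below is a minimal, standalone sketch of the call-compatibility issue; the two toy functions are illustrative stand-ins for the patched forwards, not code from llm-foundry or transformers.

from typing import Optional

import torch


def old_patched_forward(hidden_states: torch.Tensor,
                        attention_mask: Optional[torch.Tensor] = None):
    # Pre-fix signature: no `padding_mask` parameter.
    return hidden_states


def new_patched_forward(hidden_states: torch.Tensor,
                        attention_mask: Optional[torch.Tensor] = None,
                        padding_mask: Optional[torch.LongTensor] = None):
    # `padding_mask` is accepted purely for call compatibility with
    # transformers 4.34.x and is unused here.
    return hidden_states


x = torch.zeros(1, 4, 8)
pad = torch.ones(1, 4, dtype=torch.long)
new_patched_forward(x, padding_mask=pad)  # works with the extra keyword
try:
    old_patched_forward(x, padding_mask=pad)
except TypeError as err:
    print(f'pre-fix signature rejects the keyword: {err}')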
50 changes: 50 additions & 0 deletions tests/test_huggingface_flash.py
@@ -10,6 +10,7 @@
import transformers
from composer.core.precision import get_precision_context
from composer.utils import reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

from llmfoundry import COMPOSER_MODEL_REGISTRY
@@ -107,6 +108,55 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
    assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)


@pytest.mark.gpu
@pytest.mark.parametrize('patch', ['triton', 'torch'])
def test_attn_patch_integration(patch: str):
    if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
        pytest.skip(
            'The CI cluster does not have access to the Llama models, so skip this test.'
        )

    # Save the original attention function to restore at the end of the test.
    from transformers.models.llama.modeling_llama import LlamaAttention
    original_attn = LlamaAttention.forward

    name = 'meta-llama/Llama-2-7b-hf'
    model_cfg = DictConfig({
        'name': 'hf_causal_lm',
        'pretrained_model_name_or_path': name,
        'config_overrides': {
            'num_hidden_layers': 2,
            'intermediate_size': 64,
        },
        'use_auth_token': True,
        'pretrained': False,
        'init_device': 'cpu',
        'attention_patch_type': patch
    })

    tokenizer = build_tokenizer(name, tokenizer_kwargs={})
    tokenizer.pad_token = tokenizer.eos_token

    model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)

    tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
                                return_tensors='pt',
                                padding=True)
    tokenized_input['labels'] = tokenized_input['input_ids'].clone()

    tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
    model.to('cuda')

    with get_precision_context('amp_bf16'):
        # We're just testing that the attention patch runs okay
        outputs = model(tokenized_input)
        loss = outputs.loss
        loss.backward()

    # Ensure the patch does not persist beyond this test.
    LlamaAttention.forward = original_attn


@pytest.mark.gpu
@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
@pytest.mark.parametrize('use_flash_attention_2', [True, False])
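For reference, the integration path the new test exercises is the `attention_patch_type` field of the model config, which selects the torch or triton patched attention. The sketch below mirrors the test's setup as a standalone usage example; the `llmfoundry.utils.builders` import path for `build_tokenizer` is assumed from the test file's existing imports (not shown in this diff), and the config values are illustrative.

from omegaconf import DictConfig

from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.utils.builders import build_tokenizer

# 'torch' or 'triton' chooses which patched attention implementation is used.
model_cfg = DictConfig({
    'name': 'hf_causal_lm',
    'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
    'config_overrides': {
        'num_hidden_layers': 2,
        'intermediate_size': 64,
    },
    'use_auth_token': True,
    'pretrained': False,
    'init_device': 'cpu',
    'attention_patch_type': 'triton',
})

tokenizer = build_tokenizer(model_cfg.pretrained_model_name_or_path,
                            tokenizer_kwargs={})
model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)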
