From e40689f434a5bfa1ef5c261483fb77819324e0b9 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 30 Oct 2023 15:12:00 -0700 Subject: [PATCH] Fix attention patch compatibility for llama2 (#705) --- .../layers/llama_attention_monkeypatch.py | 4 ++ tests/test_huggingface_flash.py | 50 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py index 88f61e3fef..9ceeb0747e 100644 --- a/llmfoundry/models/layers/llama_attention_monkeypatch.py +++ b/llmfoundry/models/layers/llama_attention_monkeypatch.py @@ -78,6 +78,8 @@ def llama_attention_patch_torch( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1. + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if use_cache: raise NotImplementedError( @@ -186,6 +188,8 @@ def llama_attention_patch_triton( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1. + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if use_cache: raise NotImplementedError( diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py index a71217ea1f..834488bb6a 100644 --- a/tests/test_huggingface_flash.py +++ b/tests/test_huggingface_flash.py @@ -10,6 +10,7 @@ import transformers from composer.core.precision import get_precision_context from composer.utils import reproducibility +from omegaconf import DictConfig from omegaconf import OmegaConf as om from llmfoundry import COMPOSER_MODEL_REGISTRY @@ -107,6 +108,55 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool, assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol) +@pytest.mark.gpu +@pytest.mark.parametrize('patch', ['triton', 'torch']) +def test_attn_patch_integration(patch: str): + if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: + pytest.skip( + 'The CI cluster does not have access to the Llama models, so skip this test.' + ) + + # Save the original attention function to restore at the end of the test. + from transformers.models.llama.modeling_llama import LlamaAttention + original_attn = LlamaAttention.forward + + name = 'meta-llama/Llama-2-7b-hf' + model_cfg = DictConfig({ + 'name': 'hf_causal_lm', + 'pretrained_model_name_or_path': name, + 'config_overrides': { + 'num_hidden_layers': 2, + 'intermediate_size': 64, + }, + 'use_auth_token': True, + 'pretrained': False, + 'init_device': 'cpu', + 'attention_patch_type': patch + }) + + tokenizer = build_tokenizer(name, tokenizer_kwargs={}) + tokenizer.pad_token = tokenizer.eos_token + + model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer) + + tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'], + return_tensors='pt', + padding=True) + tokenized_input['labels'] = tokenized_input['input_ids'].clone() + + tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()} + model.to('cuda') + + with get_precision_context('amp_bf16'): + # We're just testing that the attention patch runs okay + outputs = model(tokenized_input) + loss = outputs.loss + loss.backward() + + # Ensure the patch does not persist beyond this test. + LlamaAttention.forward = original_attn + + @pytest.mark.gpu @pytest.mark.parametrize('model_name', ['llama2', 'mistral']) @pytest.mark.parametrize('use_flash_attention_2', [True, False])