diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 34ce22d694..f1f38e2f7d 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -256,23 +256,6 @@ def build_inner_model(
         False,  # Necessary due to https://github.com/huggingface/transformers/issues/28056
     )
 
-    # This is not ideal, however Hugging Face's _autoset_attn_implementation function
-    # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
-    # the model and then casting it back to fp32, we are monkeypatching their check.
-    # https://github.com/huggingface/transformers/issues/28052
-    def _autoset_attn_implementation_monkeypatch(
-        cls,  # type: ignore
-        config,  # type: ignore
-        *args,  # type: ignore
-        **kwargs,  # type: ignore
-    ):  # type: ignore
-        config._attn_implementation = requested_attention_implementation
-        return config
-
-    PreTrainedModel._autoset_attn_implementation = classmethod(
-        _autoset_attn_implementation_monkeypatch,
-    )
-
     set_config_overrides(config, config_overrides)
 
     # We need to have all non-zero local ranks be not-pretrained
@@ -293,6 +276,8 @@ def _autoset_attn_implementation_monkeypatch(
                 pretrained_model_name_or_path,
                 trust_remote_code=trust_remote_code,
                 use_auth_token=use_auth_token,
+                attn_implementation=
+                requested_attention_implementation,
                 config=config,
             )
         else:
@@ -300,6 +285,7 @@ def _autoset_attn_implementation_monkeypatch(
             AutoModelForCausalLM.from_config(
                 config,
                 trust_remote_code=trust_remote_code,
+                attn_implementation=requested_attention_implementation,
             )
 
     dist.barrier()
@@ -312,12 +298,14 @@ def _autoset_attn_implementation_monkeypatch(
                 trust_remote_code=trust_remote_code,
                 use_auth_token=use_auth_token,
                 load_in_8bit=load_in_8bit,
+                attn_implementation=requested_attention_implementation,
                 config=config,
             )
         else:
             model = AutoModelForCausalLM.from_config(
                 config,
                 trust_remote_code=trust_remote_code,
+                attn_implementation=requested_attention_implementation,
             )
     elif resolved_init_device == 'meta':
         if pretrained:
@@ -328,6 +316,7 @@ def _autoset_attn_implementation_monkeypatch(
             model = AutoModelForCausalLM.from_config(
                 config,
                 trust_remote_code=trust_remote_code,
+                attn_implementation=requested_attention_implementation,
             )
     else:
         raise ValueError(
diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py
index d0ec544de8..844ccd7fe5 100644
--- a/tests/models/hf/test_hf_config.py
+++ b/tests/models/hf/test_hf_config.py
@@ -7,9 +7,11 @@
 from unittest.mock import Mock, patch
 
 import pytest
+import torch
 from omegaconf import OmegaConf as om
 from transformers import PretrainedConfig
 
+from llmfoundry.models.hf.hf_fsdp import rgetattr
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
@@ -235,3 +237,45 @@ def test_nested_override():
     assert isinstance(model.config.ffn_config, PretrainedConfig)
     # Ensure the other values still exist and are not set back to their defaults
     assert model.config.ffn_config.moe_num_experts == 16
+
+
+@pytest.mark.gpu
+def test_use_flash():
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf',
+        'config_overrides': {
+            'num_hidden_layers': 2,
+            'hidden_size': 32,
+            'intermediate_size': 64,
+            'torch_dtype': 'bfloat16',
+        },
+        'pretrained': False,
+        'init_device': 'cpu',
+        'use_flash_attention_2': True,
+    }
+
+    name = model_cfg.pop('name')
+    model = build_composer_model(
+        name=name,
+        cfg=model_cfg,
+        tokenizer=None,  # type: ignore
+    )
+
+    from transformers.models.llama.modeling_llama import (
+        LlamaFlashAttention2,
+    )
+    flash_attn_class = LlamaFlashAttention2
+    attention_layers_attr = 'model.model.layers'
+    attention_attr = 'self_attn'
+
+    # check that it actually used flash attention 2
+    assert model.model.config._attn_implementation == ('flash_attention_2')
+    attention_layer = rgetattr(
+        rgetattr(model, attention_layers_attr)[0],
+        attention_attr,
+    )
+    assert isinstance(attention_layer, flash_attn_class)
+
+    # Make sure that HF has not cast the parameters to bf16
+    assert next(model.parameters()).dtype == torch.float32
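
For context (not part of the patch): a minimal standalone sketch of the transformers behavior the diff relies on. It assumes transformers >= 4.36 with flash-attn 2 installed on a CUDA machine, and uses an illustrative tiny `LlamaConfig` rather than the PR's test config. Passing `attn_implementation` directly lets Hugging Face select flash attention 2 without the removed `_autoset_attn_implementation` monkeypatch, and the parameters remain in fp32.

```python
import torch
from transformers import AutoModelForCausalLM, LlamaConfig

# Tiny illustrative config (assumption: any Llama-style config behaves the same way).
config = LlamaConfig(
    num_hidden_layers=2,
    hidden_size=32,
    intermediate_size=64,
    num_attention_heads=4,  # head_dim = 8, small but flash-attention compatible
    num_key_value_heads=4,
)

# attn_implementation is forwarded to the model constructor, so no monkeypatch of
# PreTrainedModel._autoset_attn_implementation is needed to request flash attention 2.
model = AutoModelForCausalLM.from_config(
    config,
    attn_implementation='flash_attention_2',
)

# The requested implementation is recorded on the config...
assert model.config._attn_implementation == 'flash_attention_2'
# ...and the weights are not silently cast to fp16/bf16.
assert next(model.parameters()).dtype == torch.float32
```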