Remove flash patching for HF (#1436)
dakinggg authored Aug 7, 2024
1 parent 84cb2ed commit c262341
Showing 2 changed files with 50 additions and 17 deletions.
23 changes: 6 additions & 17 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -256,23 +256,6 @@ def build_inner_model(
False, # Necessary due to https://github.com/huggingface/transformers/issues/28056
)

- # This is not ideal, however Hugging Face's _autoset_attn_implementation function
- # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
- # the model and then casting it back to fp32, we are monkeypatching their check.
- # https://github.com/huggingface/transformers/issues/28052
- def _autoset_attn_implementation_monkeypatch(
- cls, # type: ignore
- config, # type: ignore
- *args, # type: ignore
- **kwargs, # type: ignore
- ): # type: ignore
- config._attn_implementation = requested_attention_implementation
- return config
-
- PreTrainedModel._autoset_attn_implementation = classmethod(
- _autoset_attn_implementation_monkeypatch,
- )
-
set_config_overrides(config, config_overrides)

# We need to have all non-zero local ranks be not-pretrained
@@ -293,13 +276,16 @@ def _autoset_attn_implementation_monkeypatch(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
+ attn_implementation=
+ requested_attention_implementation,
config=config,
)
else:
with init_empty_weights(include_buffers=False):
AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
+ attn_implementation=requested_attention_implementation,
)

dist.barrier()
@@ -312,12 +298,14 @@ def _autoset_attn_implementation_monkeypatch(
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
load_in_8bit=load_in_8bit,
+ attn_implementation=requested_attention_implementation,
config=config,
)
else:
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
+ attn_implementation=requested_attention_implementation,
)
elif resolved_init_device == 'meta':
if pretrained:
@@ -328,6 +316,7 @@ def _autoset_attn_implementation_monkeypatch(
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
+ attn_implementation=requested_attention_implementation,
)
else:
raise ValueError(
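Context for the change above (not part of the commit): the monkeypatch is removed because recent transformers releases accept an attn_implementation argument directly in from_pretrained and from_config without forcing an fp16/bf16 cast, and the lines added in this file simply pass that argument through. A minimal sketch of the resulting pattern, assuming a mid-2024 or later transformers release with the flash-attn package and a CUDA GPU available; the model name and tiny config overrides are illustrative only:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Request the attention backend up front instead of monkeypatching
# PreTrainedModel._autoset_attn_implementation.
requested_attention_implementation = 'flash_attention_2'  # or 'eager' / 'sdpa'

# Tiny illustrative config so this can run without downloading full weights.
config = AutoConfig.from_pretrained(
    'codellama/CodeLlama-7b-hf',
    num_hidden_layers=2,
    hidden_size=32,
    intermediate_size=64,
)
model = AutoModelForCausalLM.from_config(
    config,
    attn_implementation=requested_attention_implementation,
)

# Recent transformers honors the request without casting the weights,
# so the parameters stay in fp32 (this is what the new test below asserts).
assert model.config._attn_implementation == 'flash_attention_2'
assert next(model.parameters()).dtype == torch.float32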
44 changes: 44 additions & 0 deletions tests/models/hf/test_hf_config.py
@@ -7,9 +7,11 @@
from unittest.mock import Mock, patch

import pytest
+ import torch
from omegaconf import OmegaConf as om
from transformers import PretrainedConfig

+ from llmfoundry.models.hf.hf_fsdp import rgetattr
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils import build_tokenizer
from llmfoundry.utils.builders import build_composer_model
@@ -235,3 +237,45 @@ def test_nested_override():
assert isinstance(model.config.ffn_config, PretrainedConfig)
# Ensure the other values still exist and are not set back to their defaults
assert model.config.ffn_config.moe_num_experts == 16


+ @pytest.mark.gpu
+ def test_use_flash():
+ model_cfg = {
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf',
+ 'config_overrides': {
+ 'num_hidden_layers': 2,
+ 'hidden_size': 32,
+ 'intermediate_size': 64,
+ 'torch_dtype': 'bfloat16',
+ },
+ 'pretrained': False,
+ 'init_device': 'cpu',
+ 'use_flash_attention_2': True,
+ }
+
+ name = model_cfg.pop('name')
+ model = build_composer_model(
+ name=name,
+ cfg=model_cfg,
+ tokenizer=None, # type: ignore
+ )
+
+ from transformers.models.llama.modeling_llama import (
+ LlamaFlashAttention2,
+ )
+ flash_attn_class = LlamaFlashAttention2
+ attention_layers_attr = 'model.model.layers'
+ attention_attr = 'self_attn'
+
+ # check that it actually used flash attention 2
+ assert model.model.config._attn_implementation == ('flash_attention_2')
+ attention_layer = rgetattr(
+ rgetattr(model, attention_layers_attr)[0],
+ attention_attr,
+ )
+ assert isinstance(attention_layer, flash_attn_class)
+
+ # Make sure that HF has not cast the parameters to bf16
+ assert next(model.parameters()).dtype == torch.float32
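For reference (not part of the commit): rgetattr, imported above from llmfoundry.models.hf.hf_fsdp, resolves a dotted attribute path on an object, which is how the test reaches the first decoder layer's self_attn module. A minimal stand-in with the assumed behavior, for illustration only:

import functools

def rgetattr_equivalent(obj, attr_path: str):
    # Resolve a dotted path such as 'model.model.layers' one attribute at a time.
    return functools.reduce(getattr, attr_path.split('.'), obj)

# Roughly what the test above does:
#   layers = rgetattr(model, 'model.model.layers')
#   attention_layer = rgetattr(layers[0], 'self_attn')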
