Remove flash patching for HF #1436

Merged: 4 commits, Aug 7, 2024
23 changes: 6 additions & 17 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -256,23 +256,6 @@ def build_inner_model(
             False,  # Necessary due to https://github.com/huggingface/transformers/issues/28056
         )
 
-        # This is not ideal, however Hugging Face's _autoset_attn_implementation function
-        # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
-        # the model and then casting it back to fp32, we are monkeypatching their check.
-        # https://github.com/huggingface/transformers/issues/28052
-        def _autoset_attn_implementation_monkeypatch(
-            cls,  # type: ignore
-            config,  # type: ignore
-            *args,  # type: ignore
-            **kwargs,  # type: ignore
-        ):  # type: ignore
-            config._attn_implementation = requested_attention_implementation
-            return config
-
-        PreTrainedModel._autoset_attn_implementation = classmethod(
-            _autoset_attn_implementation_monkeypatch,
-        )
-
         set_config_overrides(config, config_overrides)
 
         # We need to have all non-zero local ranks be not-pretrained
@@ -293,13 +276,16 @@ def _autoset_attn_implementation_monkeypatch(
                         pretrained_model_name_or_path,
                         trust_remote_code=trust_remote_code,
                         use_auth_token=use_auth_token,
+                        attn_implementation=
+                        requested_attention_implementation,
                         config=config,
                     )
             else:
                 with init_empty_weights(include_buffers=False):
                     AutoModelForCausalLM.from_config(
                         config,
                         trust_remote_code=trust_remote_code,
+                        attn_implementation=requested_attention_implementation,
                     )
 
         dist.barrier()
@@ -312,12 +298,14 @@ def _autoset_attn_implementation_monkeypatch(
                     trust_remote_code=trust_remote_code,
                     use_auth_token=use_auth_token,
                     load_in_8bit=load_in_8bit,
+                    attn_implementation=requested_attention_implementation,
                     config=config,
                 )
             else:
                 model = AutoModelForCausalLM.from_config(
                     config,
                     trust_remote_code=trust_remote_code,
+                    attn_implementation=requested_attention_implementation,
                 )
         elif resolved_init_device == 'meta':
             if pretrained:
@@ -328,6 +316,7 @@ def _autoset_attn_implementation_monkeypatch(
                 model = AutoModelForCausalLM.from_config(
                     config,
                     trust_remote_code=trust_remote_code,
+                    attn_implementation=requested_attention_implementation,
                 )
         else:
             raise ValueError(
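For context, the pattern this PR switches to can be reproduced directly against the transformers API: rather than patching PreTrainedModel._autoset_attn_implementation, the requested implementation is forwarded through the attn_implementation keyword. The following is a minimal standalone sketch, not code from this PR; it assumes a transformers release in which huggingface/transformers#28052 is resolved (so requesting flash attention no longer forces an fp16/bf16 cast), plus an installed flash-attn package and a visible CUDA device.

# Minimal sketch (not code from this PR): request flash attention 2 via the
# attn_implementation kwarg instead of monkeypatching
# PreTrainedModel._autoset_attn_implementation. Assumes a transformers release
# where huggingface/transformers#28052 is resolved, the flash-attn package is
# installed, and a CUDA device is available.
import torch
from transformers import AutoModelForCausalLM, LlamaConfig

# A tiny Llama-style config purely for illustration.
config = LlamaConfig(
    num_hidden_layers=2,
    hidden_size=32,
    intermediate_size=64,
)

model = AutoModelForCausalLM.from_config(
    config,
    attn_implementation='flash_attention_2',
)

# The requested implementation is recorded on the config, and the weights stay
# in fp32 rather than being cast to fp16/bf16 by the attention check.
assert model.config._attn_implementation == 'flash_attention_2'
assert next(model.parameters()).dtype == torch.float32

This is the same keyword the diff above now passes as requested_attention_implementation in every from_pretrained/from_config call.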
44 changes: 44 additions & 0 deletions tests/models/hf/test_hf_config.py
@@ -7,9 +7,11 @@
 from unittest.mock import Mock, patch
 
 import pytest
+import torch
 from omegaconf import OmegaConf as om
 from transformers import PretrainedConfig
 
+from llmfoundry.models.hf.hf_fsdp import rgetattr
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
@@ -235,3 +237,45 @@ def test_nested_override():
     assert isinstance(model.config.ffn_config, PretrainedConfig)
     # Ensure the other values still exist and are not set back to their defaults
     assert model.config.ffn_config.moe_num_experts == 16
+
+
+@pytest.mark.gpu
+def test_use_flash():
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf',
+        'config_overrides': {
+            'num_hidden_layers': 2,
+            'hidden_size': 32,
+            'intermediate_size': 64,
+            'torch_dtype': 'bfloat16',
+        },
+        'pretrained': False,
+        'init_device': 'cpu',
+        'use_flash_attention_2': True,
+    }
+
+    name = model_cfg.pop('name')
+    model = build_composer_model(
+        name=name,
+        cfg=model_cfg,
+        tokenizer=None,  # type: ignore
+    )
+
+    from transformers.models.llama.modeling_llama import (
+        LlamaFlashAttention2,
+    )
+    flash_attn_class = LlamaFlashAttention2
+    attention_layers_attr = 'model.model.layers'
+    attention_attr = 'self_attn'
+
+    # check that it actually used flash attention 2
+    assert model.model.config._attn_implementation == ('flash_attention_2')
+    attention_layer = rgetattr(
+        rgetattr(model, attention_layers_attr)[0],
+        attention_attr,
+    )
+    assert isinstance(attention_layer, flash_attn_class)
+
+    # Make sure that HF has not cast the parameters to bf16
+    assert next(model.parameters()).dtype == torch.float32
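The new test resolves the attention module through rgetattr with the dotted path 'model.model.layers': the Composer wrapper's .model is the HF LlamaForCausalLM, whose .model is the inner LlamaModel holding .layers. A minimal sketch of equivalent dotted-attribute resolution follows; it uses the common reduce-over-getattr recipe and an illustrative name (dotted_getattr), not necessarily llm-foundry's exact implementation.

# Hedged sketch of a dotted-attribute lookup equivalent to the rgetattr calls
# above; llmfoundry.models.hf.hf_fsdp.rgetattr may differ in details.
import functools


def dotted_getattr(obj: object, path: str) -> object:
    """Resolve a dotted path like 'model.model.layers' via repeated getattr."""
    return functools.reduce(getattr, path.split('.'), obj)


# Usage mirroring the test (`model` here stands for the built Composer model):
# first_self_attn = dotted_getattr(model, 'model.model.layers')[0].self_attn
# assert isinstance(first_self_attn, LlamaFlashAttention2)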