Add fc to HF export (#1209)
dakinggg authored May 16, 2024
1 parent 8cd23d5 commit dc3212e
Showing 4 changed files with 108 additions and 34 deletions.
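
For context, the new test added in this commit loads the exported checkpoint back in through the standard transformers auto classes. A minimal sketch of that load path outside the test harness (the checkpoint directory is a placeholder; it stands for whatever save_pretrained plus edit_files_for_hf_compatibility wrote):

    import transformers

    # trust_remote_code lets transformers import the MPT modeling files that
    # were copied next to the exported weights.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        '/path/to/exported/checkpoint',  # placeholder path
        trust_remote_code=True,
    )
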
1 change: 1 addition & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -77,6 +77,7 @@
generic_param_init_fn_, # type: ignore (see note)
)
from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note)
from llmfoundry.models.layers.fc import fcs # type: ignore (see note)

from llmfoundry.models.utils.act_ckpt import (
pass_on_block_idx,
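
The only code change in modeling_mpt.py is the new import of fcs. Given the commit title, a plausible reading is that the Hugging Face export gathers source files by following the modeling module's llmfoundry imports, so importing fc here is what pulls fc.py into the exported checkpoint. A toy sketch of that kind of import walk (a hypothetical helper for illustration, not LLM Foundry's actual exporter):

    import ast
    from pathlib import Path

    def local_imports(py_file: str, package: str = 'llmfoundry') -> set:
        """Collect the `package` modules that `py_file` imports (toy example)."""
        tree = ast.parse(Path(py_file).read_text())
        found = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.ImportFrom) and node.module:
                if node.module.startswith(package):
                    found.add(node.module)
            elif isinstance(node, ast.Import):
                found.update(
                    alias.name for alias in node.names
                    if alias.name.startswith(package)
                )
        return found

With the new import in place, a walk like this starting from modeling_mpt.py would now also surface llmfoundry.models.layers.fc.
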
57 changes: 56 additions & 1 deletion tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -10,6 +10,7 @@
from typing import Any, Callable, Dict, Optional, cast
from unittest.mock import ANY, MagicMock, patch

import catalogue
import pytest
import torch
import transformers
@@ -25,7 +26,8 @@
from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename
from llmfoundry.data.finetuning import build_finetuning_dataloader
from llmfoundry.models.mpt import MPTConfig
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils import edit_files_for_hf_compatibility
from llmfoundry.utils.builders import (
build_composer_model,
build_optimizer,
@@ -1407,6 +1409,59 @@ def test_mptmoe_huggingface_conversion_callback(
delete_transformers_cache()


def test_mpt_convert_simple(
monkeypatch: pytest.MonkeyPatch,
tmp_path: pathlib.Path,
):
delete_transformers_cache()

from transformers.models.auto.configuration_auto import CONFIG_MAPPING
original_config_auto_class = MPTConfig._auto_class
original_model_auto_class = MPTForCausalLM._auto_class
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

model_cfg = {
'name': 'mpt_causal_lm',
'init_device': 'cpu',
'd_model': 64,
'n_heads': 2,
'n_layers': 2,
'expansion_ratio': 4,
'max_seq_len': 256,
'vocab_size': 50368,
'attn_config': {
'attn_impl': 'torch',
},
'loss_fn': 'torch_crossentropy',
'tie_word_embeddings': False,
}

original_model = build_composer_model(
name='mpt_causal_lm',
tokenizer=None,
cfg=model_cfg,
)

original_model.model.save_pretrained(str(tmp_path))

edit_files_for_hf_compatibility(str(tmp_path))

monkeypatch.setattr(catalogue, 'REGISTRY', {})

_ = transformers.AutoModelForCausalLM.from_pretrained(
tmp_path,
trust_remote_code=True,
)

delete_transformers_cache()

del CONFIG_MAPPING._extra_content['mpt']
MPTConfig._auto_class = original_config_auto_class
MPTForCausalLM._auto_class = original_model_auto_class

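
The save/restore of the auto-class registration in the test above could also be packaged as a pytest fixture; a hedged sketch that simply mirrors the test's own setup and teardown (hypothetical fixture name, not part of this commit):

    import pytest
    from transformers.models.auto.configuration_auto import CONFIG_MAPPING

    from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

    @pytest.fixture
    def mpt_auto_registration():
        """Register MPT with the transformers auto classes, then restore state."""
        original_config_auto_class = MPTConfig._auto_class
        original_model_auto_class = MPTForCausalLM._auto_class
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
        yield
        del CONFIG_MAPPING._extra_content['mpt']
        MPTConfig._auto_class = original_config_auto_class
        MPTForCausalLM._auto_class = original_model_auto_class
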

@pytest.mark.parametrize(
'license_file_name',
['LICENSE', 'LICENSE.txt', 'license', 'license.md', None],
2 changes: 1 addition & 1 deletion tests/eval/test_in_context_learning_datasets.py
@@ -12,7 +12,6 @@
import transformers
from composer import Evaluator
from composer.core import DataSpec
from composer.datasets.utils import MultiTokenEOSCriteria
from composer.loggers import InMemoryLogger
from composer.models import HuggingFaceModel
from composer.trainer import Trainer
@@ -24,6 +23,7 @@
InContextLearningGenerationTaskWithAnswersDataset,
InContextLearningMultipleChoiceTaskDataset,
InContextLearningSchemaTaskDataset,
MultiTokenEOSCriteria,
get_continuation_span,
get_fewshot_sample_idxs,
get_icl_task_dataloader,
82 changes: 50 additions & 32 deletions tests/models/hf/test_hf_config.py
@@ -2,16 +2,15 @@
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
import shutil
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Mapping
from unittest.mock import Mock, patch

import pytest
import torch
from composer.utils import dist
from omegaconf import OmegaConf as om
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers import PretrainedConfig

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils import build_tokenizer
@@ -84,6 +83,45 @@ def test_tie_weights(tie_word_embeddings: bool):
assert mpt.lm_head is not None


# TODO(GRT-2435): Change to fixture
def delete_transformers_cache():
# Only delete the files on local rank 0, otherwise race conditions are created
if not dist.get_local_rank() == 0:
return

hf_cache_home = os.path.expanduser(
os.getenv(
'HF_HOME',
os.path.join(
os.getenv('XDG_CACHE_HOME', '~/.cache'),
'huggingface',
),
),
)
HF_MODULES_CACHE = os.getenv(
'HF_MODULES_CACHE',
os.path.join(hf_cache_home, 'modules'),
)
if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE):
shutil.rmtree(HF_MODULES_CACHE)


# def test_mpt_convert_simple(
# monkeypatch: pytest.MonkeyPatch,
# tmp_path: pathlib.Path,
# ):
# delete_transformers_cache()

# from transformers.models.auto.configuration_auto import CONFIG_MAPPING
# original_config_auto_class = MPTConfig._auto_class
# original_model_auto_class = MPTForCausalLM._auto_class
# CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
# MPTConfig.register_for_auto_class()
# MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

# delete_transformers_cache()


@pytest.mark.parametrize(
'model_cfg_overrides',
[
@@ -92,7 +130,7 @@
},
{
'attn_config': {
'attn_impl': 'flash',
'attn_pdrop': 1.0,
},
},
{
@@ -103,7 +141,7 @@
{
'max_seq_len': 1024,
'attn_config': {
'attn_impl': 'flash',
'attn_pdrop': 1.0,
},
'init_config': {
'emb_init_std': 5,
@@ -131,49 +169,29 @@ def test_hf_config_override(
model_cfg_overrides: Dict[str, Any],
conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml',
):
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM)

with open(conf_path) as f:
test_cfg = om.load(f)

# Build Model
# For fast initialization, use `meta` device
print('Initializing model...')
device = 'cpu'
test_cfg.model.init_device = device
test_cfg.device = device
test_cfg.precision = 'fp16'
test_cfg.model.attn_config = {'attn_impl': 'torch', 'alibi': True}

tokenizer_cfg: Dict[str, Any] = om.to_container(
test_cfg.tokenizer,
resolve=True,
) # type: ignore
tokenizer_name = tokenizer_cfg['name']
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
name = test_cfg.model.pop('name')
model = build_composer_model(
name=name,
cfg=to_dict_container(test_cfg.model),
tokenizer=tokenizer,
)

# save model
tmp_dir = tempfile.TemporaryDirectory()
save_path = tmp_dir.name
tiny_overrides = {
'n_layers': 2,
'd_model': 128,
}

tokenizer.save_pretrained(save_path)
model.config.save_pretrained(save_path)
torch.save(model.state_dict(), Path(save_path) / 'pytorch_model.bin')
model_cfg_overrides.update(tiny_overrides)

# load hf causal lm model with config_overrides
hf_model_config = deepcopy(test_cfg)
model_cfg = om.create({
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': save_path,
'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
'pretrained': False,
'config_overrides': model_cfg_overrides,
})
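
For one of the parametrizations above, the effective config passed through to build_composer_model ends up roughly like this (values taken from the diff; with pretrained set to False, presumably only the mosaicml/mpt-7b config is used and the weights are freshly initialized):

    model_cfg = {
        'name': 'hf_causal_lm',
        'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
        'pretrained': False,
        'config_overrides': {
            'attn_config': {'attn_pdrop': 1.0},  # per-test override
            'n_layers': 2,                       # from tiny_overrides
            'd_model': 128,
        },
    }
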
