Use pretrained generation config if possible
irenedea committed Aug 9, 2024
1 parent 44b09f0 commit bc9d225
Showing 2 changed files with 77 additions and 2 deletions.
12 changes: 12 additions & 0 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -22,6 +22,7 @@
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    GenerationConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
@@ -337,6 +338,17 @@ def build_inner_model(
         if dist.get_local_rank() == 0:
             os.remove(signal_file_path)
 
+        # Use the pretrained generation config for the model if it exists.
+        try:
+            model.generation_config = GenerationConfig.from_pretrained(
+                pretrained_model_name_or_path,
+                use_auth_token=use_auth_token,
+            )
+        except OSError:
+            log.warning(
+                f'No existing generation config found for the model with name or path {pretrained_model_name_or_path}. Using default generation config.',
+            )
+
         # Hugging Face's weight tying does not succeed if the model is inited on meta device
         # so we manually apply the weight tying here
         if model.config.tie_word_embeddings and resolved_init_device == 'meta':
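For context (not part of this diff): a minimal sketch of the fallback behavior the new try/except relies on, assuming only that transformers is installed; the checkpoint path below is hypothetical. GenerationConfig.from_pretrained raises an OSError when no generation_config.json can be found for the given name or path, in which case the model simply keeps the default generation config it was built with.

from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained('path/to/checkpoint')  # hypothetical path
try:
    # Succeeds only if generation_config.json exists at the path (or on the Hub).
    model.generation_config = GenerationConfig.from_pretrained('path/to/checkpoint')
except OSError:
    # No generation_config.json found; keep the model's default generation config.
    pass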
67 changes: 65 additions & 2 deletions tests/models/hf/test_hf_config.py
@@ -3,19 +3,28 @@
 
 import os
 from copy import deepcopy
+from pathlib import Path
 from typing import Any, Dict, Mapping
 from unittest.mock import Mock, patch
 
 import pytest
 import torch
 from omegaconf import OmegaConf as om
-from transformers import PretrainedConfig
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    PretrainedConfig,
+    PreTrainedModel,
+)
 
 from llmfoundry.models.hf.hf_fsdp import rgetattr
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
-from llmfoundry.utils.config_utils import to_dict_container
+from llmfoundry.utils.config_utils import (
+    set_config_overrides,
+    to_dict_container,
+)
 
 
 def test_remote_code_false_mpt(
@@ -279,3 +288,57 @@ def test_use_flash():
 
     # Make sure that HF has not cast the parameters to bf16
     assert next(model.parameters()).dtype == torch.float32
+
+
+def test_generation_config(tmp_path: Path):
+    # Create a small llama model to edit and save.
+    config = AutoConfig.from_pretrained('meta-llama/Llama-2-7b-hf')
+    set_config_overrides(
+        config,
+        config_overrides={
+            'num_hidden_layers': 2,
+            'hidden_size': 32,
+            'intermediate_size': 64,
+        },
+    )
+    model = AutoModelForCausalLM.from_config(config)
+
+    assert isinstance(model, PreTrainedModel)
+    assert model.generation_config is not None
+
+    new_bos_token_id = 100
+
+    # Set the bos_token_id to something else
+    model.generation_config.bos_token_id = new_bos_token_id
+
+    # Generation config and model config no longer match
+    assert model.generation_config.bos_token_id != model.config.bos_token_id
+
+    save_dir = tmp_path / 'model'
+
+    # Save the model.
+    model.save_pretrained(save_dir)
+
+    # Now load the model from the save directory and check that the bos_token_id is the same as what we set.
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': str(save_dir),
+        'use_auth_token': True,
+        'pretrained': False,
+        'init_device': 'cpu',
+    }
+
+    name = model_cfg.pop('name')
+    model = build_composer_model(
+        name=name,
+        cfg=model_cfg,
+        tokenizer=None,  # type: ignore
+    )
+
+    inner_model = model.model
+
+    assert isinstance(inner_model, PreTrainedModel)
+    assert inner_model.generation_config is not None
+
+    # save_pretrained and reloading with hf_causal_lm should use the bos_token_id we set from earlier.
+    assert inner_model.generation_config.bos_token_id == new_bos_token_id
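As a side note (not part of the commit), the round trip this test exercises boils down to a few plain transformers calls; a minimal sketch, assuming model is any generative PreTrainedModel (such as the tiny llama built above) and the save directory is hypothetical:

model.generation_config.bos_token_id = 100     # diverge from model.config
model.save_pretrained('/tmp/tiny-llama')       # also writes generation_config.json
reloaded = GenerationConfig.from_pretrained('/tmp/tiny-llama')
assert reloaded.bos_token_id == 100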
