diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index f1f38e2f7d..c1900646f2 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -22,6 +22,7 @@ from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    GenerationConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
 
@@ -337,6 +338,17 @@ def build_inner_model(
         if dist.get_local_rank() == 0:
             os.remove(signal_file_path)
 
+    # Use the pretrained generation config for the model if it exists.
+    try:
+        model.generation_config = GenerationConfig.from_pretrained(
+            pretrained_model_name_or_path,
+            use_auth_token=use_auth_token,
+        )
+    except OSError:
+        log.warning(
+            f'No existing generation config found for the model with name or path {pretrained_model_name_or_path}. Using default generation config.',
+        )
+
     # Hugging Face's weight tying does not succeed if the model is inited on meta device
     # so we manually apply the weight tying here
     if model.config.tie_word_embeddings and resolved_init_device == 'meta':
diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py
index 844ccd7fe5..e6bf73bcee 100644
--- a/tests/models/hf/test_hf_config.py
+++ b/tests/models/hf/test_hf_config.py
@@ -3,19 +3,28 @@
 
 import os
 from copy import deepcopy
+from pathlib import Path
 from typing import Any, Dict, Mapping
 from unittest.mock import Mock, patch
 
 import pytest
 import torch
 from omegaconf import OmegaConf as om
-from transformers import PretrainedConfig
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    PretrainedConfig,
+    PreTrainedModel,
+)
 
 from llmfoundry.models.hf.hf_fsdp import rgetattr
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
-from llmfoundry.utils.config_utils import to_dict_container
+from llmfoundry.utils.config_utils import (
+    set_config_overrides,
+    to_dict_container,
+)
 
 
 def test_remote_code_false_mpt(
@@ -279,3 +288,57 @@ def test_use_flash():
 
     # Make sure that HF has not cast the parameters to bf16
     assert next(model.parameters()).dtype == torch.float32
+
+
+def test_generation_config(tmp_path: Path):
+    # Create a small llama model to edit and save.
+    config = AutoConfig.from_pretrained('meta-llama/Llama-2-7b-hf')
+    set_config_overrides(
+        config,
+        config_overrides={
+            'num_hidden_layers': 2,
+            'hidden_size': 32,
+            'intermediate_size': 64,
+        },
+    )
+    model = AutoModelForCausalLM.from_config(config)
+
+    assert isinstance(model, PreTrainedModel)
+    assert model.generation_config is not None
+
+    new_bos_token_id = 100
+
+    # Set the bos_token_id to something else
+    model.generation_config.bos_token_id = new_bos_token_id
+
+    # Generation config and model config no longer match
+    assert model.generation_config.bos_token_id != model.config.bos_token_id
+
+    save_dir = tmp_path / 'model'
+
+    # Save the model.
+    model.save_pretrained(save_dir)
+
+    # Now load the model from the save directory and check that the bos_token_id is the same as what we set.
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': str(save_dir),
+        'use_auth_token': True,
+        'pretrained': False,
+        'init_device': 'cpu',
+    }
+
+    name = model_cfg.pop('name')
+    model = build_composer_model(
+        name=name,
+        cfg=model_cfg,
+        tokenizer=None,  # type: ignore
+    )
+
+    inner_model = model.model
+
+    assert isinstance(inner_model, PreTrainedModel)
+    assert inner_model.generation_config is not None
+
+    # save_pretrained and reloading with hf_causal_lm should use the bos_token_id we set from earlier.
+    assert inner_model.generation_config.bos_token_id == new_bos_token_id
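
For context, a minimal standalone sketch of the transformers behavior the new build_inner_model code and the test above rely on: save_pretrained writes a generation_config.json next to the weights, GenerationConfig.from_pretrained restores the edited bos_token_id from that directory, and a path with no such file raises OSError, which is why the patch falls back to the default generation config. The tiny Llama config built with AutoConfig.for_model, its override values, and the tempfile directory are illustrative assumptions, not taken from the patch.

import tempfile

from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

# Build a tiny Llama model purely for illustration; these sizes are arbitrary assumptions.
config = AutoConfig.for_model(
    'llama',
    num_hidden_layers=2,
    hidden_size=32,
    intermediate_size=64,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=128,
)
model = AutoModelForCausalLM.from_config(config)

# Edit the generation config so it no longer matches the model config.
model.generation_config.bos_token_id = 100

with tempfile.TemporaryDirectory() as save_dir:
    # save_pretrained writes generation_config.json alongside the weights.
    model.save_pretrained(save_dir)

    # from_pretrained reads that file back; a directory without it raises OSError,
    # which the patched build_inner_model catches to keep the default config.
    restored = GenerationConfig.from_pretrained(save_dir)
    assert restored.bos_token_id == 100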