diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index f1b07d840d..b0e192d33b 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -26,6 +26,7 @@
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from llmfoundry.metrics import (
     DEFAULT_CAUSAL_LM_EVAL_METRICS,
@@ -193,6 +194,8 @@ def build_inner_model(
         config_overrides: Dict[str, Any],
         load_in_8bit: bool,
         pretrained: bool,
+        model_cls: Union[_BaseAutoModelClass,
+                         PreTrainedModel] = AutoModelForCausalLM,
         prepare_for_fsdp: bool = False,
     ) -> Union[PreTrainedModel, 'PeftModel']:
         """Builds the inner model for the ComposerHFCausalLM.
@@ -207,7 +210,8 @@ def build_inner_model(
             config_overrides (Dict[str, Any]): The configuration overrides.
             load_in_8bit (bool): Whether to load in 8-bit.
             pretrained (bool): Whether the model is pretrained.
-            prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: False.
+            model_cls (Union[_BaseAutoModelClass, PreTrainedModel], optional): The HF class to instantiate the model with. Default: ``AutoModelForCausalLM``.
+            prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: ``False``.
 
         Returns:
             Union[PreTrainedModel, 'PeftModel']: The built inner model.
@@ -231,6 +235,14 @@
                 + 'Please `pip install llm-foundry[gpu]`.',
             )
 
+        if not (
+            hasattr(model_cls, 'from_pretrained') and
+            hasattr(model_cls, 'from_config')
+        ):
+            raise AttributeError(
+                f'{model_cls=} is missing `from_pretrained` or `from_config` support.',
+            )
+
         # Hugging Face copies the modules into the
         # transformers modules cache. On particular systems, this operation seems to cause contention between
         # the different processes. To avoid this contention, we first create the config and generation config on local rank
@@ -280,7 +292,7 @@
                 with init_empty_weights(include_buffers=False):
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore', UserWarning)
-                        AutoModelForCausalLM.from_pretrained(
+                        model_cls.from_pretrained(
                             pretrained_model_name_or_path,
                             trust_remote_code=trust_remote_code,
                             use_auth_token=use_auth_token,
@@ -290,7 +302,7 @@
                         )
             else:
                 with init_empty_weights(include_buffers=False):
-                    AutoModelForCausalLM.from_config(
+                    model_cls.from_config(
                         config,
                         trust_remote_code=trust_remote_code,
                         attn_implementation=requested_attention_implementation,
@@ -301,7 +313,7 @@
         # initialize the model on the correct device
         if resolved_init_device == 'cpu':
             if pretrained:
-                model = AutoModelForCausalLM.from_pretrained(
+                model = model_cls.from_pretrained(
                     pretrained_model_name_or_path,
                     trust_remote_code=trust_remote_code,
                     use_auth_token=use_auth_token,
@@ -310,7 +322,7 @@
                     config=config,
                 )
             else:
-                model = AutoModelForCausalLM.from_config(
+                model = model_cls.from_config(
                     config,
                     trust_remote_code=trust_remote_code,
                     attn_implementation=requested_attention_implementation,
@@ -321,7 +333,7 @@
                 'Setting cfg.pretrained=True is not supported when init_device="meta".',
             )
             with init_empty_weights(include_buffers=False):
-                model = AutoModelForCausalLM.from_config(
+                model = model_cls.from_config(
                     config,
                     trust_remote_code=trust_remote_code,
                     attn_implementation=requested_attention_implementation,
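
A note on the new guard: it duck-types `model_cls` instead of using an `isinstance` check, so anything exposing both `from_pretrained` and `from_config` classmethods is accepted. A minimal sketch of that contract, using only stock `transformers` auto classes; the `NotAModel` stand-in is hypothetical:

```python
# Sketch of the duck-typed contract enforced by the new check in
# build_inner_model; NotAModel is a hypothetical invalid model_cls.
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM


def supports_hf_loading(model_cls: type) -> bool:
    """Mirrors the hasattr guard added in this diff."""
    return (
        hasattr(model_cls, 'from_pretrained') and
        hasattr(model_cls, 'from_config')
    )


class NotAModel:
    pass


assert supports_hf_loading(AutoModelForCausalLM)
assert supports_hf_loading(AutoModelForSeq2SeqLM)
assert not supports_hf_loading(NotAModel)
```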
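
For context on the call sites being rewired: on the deferred and meta paths, every `from_config`/`from_pretrained` call runs under accelerate's `init_empty_weights`, so whatever `model_cls` is supplied gets built without allocating real weight memory. A small standalone sketch of that interaction, independent of this diff; the checkpoint name is illustrative:

```python
# Standalone sketch of the meta-device path this diff parameterizes:
# any model_cls with a from_config hook builds shape-only parameters
# under init_empty_weights, as the rewired call sites do.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('gpt2')  # illustrative model name
with init_empty_weights(include_buffers=False):
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # meta: no real memory allocated
```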
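
Finally, a hedged usage sketch of the new hook. Only `config_overrides`, `load_in_8bit`, `pretrained`, `model_cls`, and `prepare_for_fsdp` are visible in the signature hunk; the other keyword arguments below are assumptions inferred from how the function body uses them, and `AutoModelForSeq2SeqLM`/`t5-small` are illustrative choices, not part of the diff:

```python
# Hypothetical usage sketch, not part of this diff. Argument names not
# shown in the hunks above are assumptions based on the function body.
from transformers import AutoModelForSeq2SeqLM

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM

model = ComposerHFCausalLM.build_inner_model(
    pretrained_model_name_or_path='t5-small',  # illustrative checkpoint
    pretrained_lora_id_or_path=None,           # assumed name
    trust_remote_code=False,
    init_device='cpu',                         # assumed name
    use_flash_attention_2=False,               # assumed name
    use_auth_token=False,
    config_overrides={},
    load_in_8bit=False,
    pretrained=True,
    model_cls=AutoModelForSeq2SeqLM,  # the new hook: any class passing the guard
)
```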