diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 652903d82e..753d3a11e8 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -13,6 +13,7 @@
     List,
     Optional,
     Tuple,
+    Type,
     Union,
 )
 
@@ -21,7 +22,6 @@
 from torchmetrics import Metric
 from transformers import (
     AutoConfig,
-    AutoModel,
     AutoModelForCausalLM,
     GenerationConfig,
     PreTrainedModel,
@@ -194,7 +194,7 @@ def build_inner_model(
         config_overrides: Dict[str, Any],
         load_in_8bit: bool,
         pretrained: bool,
-        model_cls: Union[AutoModel, PreTrainedModel] = AutoModelForCausalLM,
+        model_cls: Union[Type, Type[PreTrainedModel]] = AutoModelForCausalLM,
         prepare_for_fsdp: bool = False,
     ) -> Union[PreTrainedModel, 'PeftModel']:
         """Builds the inner model for the ComposerHFCausalLM.
@@ -209,7 +209,7 @@ def build_inner_model(
             config_overrides (Dict[str, Any]): The configuration overrides.
             load_in_8bit (bool): Whether to load in 8-bit.
             pretrained (bool): Whether the model is pretrained.
-            model_cls (Union[AutoModel, PreTrainedModel]): HF class for models. Default: ``AutoModelForCausalLM``.
+            model_cls (Union[Type, Type[PreTrainedModel]]): HF class for models. Default: ``AutoModelForCausalLM``.
             prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: ``False``.
 
         Returns:
@@ -239,7 +239,7 @@ def build_inner_model(
             hasattr(model_cls, 'from_config')
         ):
             raise AttributeError(
-                f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
+                f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
             )
 
         # Hugging Face copies the modules into the
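
For context on why the annotation changes: `build_inner_model` receives the Hugging Face model *class* itself (e.g. `AutoModelForCausalLM` or a concrete `PreTrainedModel` subclass), not an instance, so `Union[AutoModel, PreTrainedModel]` described the wrong thing and the `AutoModel` import becomes unused once `Type` is used instead. Below is a minimal sketch of the capability check this diff touches; `validate_model_cls` is a hypothetical standalone helper for illustration, not part of llmfoundry.

```python
from typing import Type, Union

from transformers import AutoModelForCausalLM, PreTrainedModel


def validate_model_cls(
    model_cls: Union[Type, Type[PreTrainedModel]] = AutoModelForCausalLM,
) -> None:
    # Same capability check as in build_inner_model: the argument is a
    # *class* (an Auto* factory or a PreTrainedModel subclass), which is
    # why the annotation uses Type rather than AutoModel.
    if not (
        hasattr(model_cls, 'from_pretrained') and
        hasattr(model_cls, 'from_config')
    ):
        raise AttributeError(
            f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
        )


validate_model_cls(AutoModelForCausalLM)  # OK: both classmethods exist

try:
    validate_model_cls(dict)  # not an HF model class -> AttributeError
except AttributeError as err:
    print(err)
```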