diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 753d3a11e8..b0e192d33b 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -13,7 +13,6 @@ List, Optional, Tuple, - Type, Union, ) @@ -27,6 +26,7 @@ PreTrainedModel, PreTrainedTokenizerBase, ) +from transformers.models.auto.auto_factory import _BaseAutoModelClass from llmfoundry.metrics import ( DEFAULT_CAUSAL_LM_EVAL_METRICS, @@ -194,7 +194,8 @@ def build_inner_model( config_overrides: Dict[str, Any], load_in_8bit: bool, pretrained: bool, - model_cls: Union[Type, Type[PreTrainedModel]] = AutoModelForCausalLM, + model_cls: Union[_BaseAutoModelClass, + PreTrainedModel] = AutoModelForCausalLM, prepare_for_fsdp: bool = False, ) -> Union[PreTrainedModel, 'PeftModel']: """Builds the inner model for the ComposerHFCausalLM.