From 97f06bee2fc9cb59478280a33ea8c95d1a9f726f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 Aug 2024 20:34:01 +0000 Subject: [PATCH] types --- llmfoundry/models/hf/hf_causal_lm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 753d3a11e8..b0e192d33b 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -13,7 +13,6 @@ List, Optional, Tuple, - Type, Union, ) @@ -27,6 +26,7 @@ PreTrainedModel, PreTrainedTokenizerBase, ) +from transformers.models.auto.auto_factory import _BaseAutoModelClass from llmfoundry.metrics import ( DEFAULT_CAUSAL_LM_EVAL_METRICS, @@ -194,7 +194,8 @@ def build_inner_model( config_overrides: Dict[str, Any], load_in_8bit: bool, pretrained: bool, - model_cls: Union[Type, Type[PreTrainedModel]] = AutoModelForCausalLM, + model_cls: Union[_BaseAutoModelClass, + PreTrainedModel] = AutoModelForCausalLM, prepare_for_fsdp: bool = False, ) -> Union[PreTrainedModel, 'PeftModel']: """Builds the inner model for the ComposerHFCausalLM.