From e207a021fb1b46b8b577b34739eba00661ffbef8 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 13 Aug 2024 20:20:36 +0000
Subject: [PATCH] fixing comments

---
 llmfoundry/models/hf/hf_causal_lm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 652903d82e..753d3a11e8 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -13,6 +13,7 @@
     List,
     Optional,
     Tuple,
+    Type,
     Union,
 )
 
@@ -21,7 +22,6 @@
 from torchmetrics import Metric
 from transformers import (
     AutoConfig,
-    AutoModel,
     AutoModelForCausalLM,
     GenerationConfig,
     PreTrainedModel,
@@ -194,7 +194,7 @@ def build_inner_model(
         config_overrides: Dict[str, Any],
         load_in_8bit: bool,
         pretrained: bool,
-        model_cls: Union[AutoModel, PreTrainedModel] = AutoModelForCausalLM,
+        model_cls: Union[Type, Type[PreTrainedModel]] = AutoModelForCausalLM,
         prepare_for_fsdp: bool = False,
     ) -> Union[PreTrainedModel, 'PeftModel']:
         """Builds the inner model for the ComposerHFCausalLM.
@@ -209,7 +209,7 @@ def build_inner_model(
             config_overrides (Dict[str, Any]): The configuration overrides.
             load_in_8bit (bool): Whether to load in 8-bit.
             pretrained (bool): Whether the model is pretrained.
-            model_cls (Union[AutoModel, PreTrainedModel]): HF class for models. Default: ``AutoModelForCausalLM``.
+            model_cls (Union[Type, Type[PreTrainedModel]]): HF class for models. Default: ``AutoModelForCausalLM``.
             prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: ``False``.
 
         Returns:
@@ -239,7 +239,7 @@ def build_inner_model(
             hasattr(model_cls, 'from_config')
         ):
             raise AttributeError(
                f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
-                f'{model_cls=} has missing `from_pretrained` and `from_config` support.',
+                f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
             )
 
         # Hugging Face copies the modules into the
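
Context for the type-hint change, as a minimal standalone sketch (not part of
the patch): `AutoModelForCausalLM` is a factory class that is never
instantiated directly and does not subclass `PreTrainedModel`, so
`build_inner_model` receives a class object rather than a model instance;
hence `Type` in the annotation instead of `AutoModel`/`PreTrainedModel`. The
helper `check_model_cls` below is hypothetical and only mirrors the
duck-typed guard shown in the last hunk.

from typing import Type, Union

from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel


def check_model_cls(model_cls: Union[Type, Type[PreTrainedModel]]) -> None:
    # Any class exposing both constructors is acceptable, whether or not it
    # subclasses PreTrainedModel (AutoModelForCausalLM does not).
    if not (
        hasattr(model_cls, 'from_pretrained') and
        hasattr(model_cls, 'from_config')
    ):
        raise AttributeError(
            f'{model_cls=} is missing `from_pretrained` and `from_config` support.',
        )


check_model_cls(AutoModelForCausalLM)  # passes: both classmethods exist

# The class, not an instance, flows through; its classmethods construct and
# return an instance of a concrete PreTrainedModel subclass.
config = AutoConfig.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_config(config)
assert isinstance(model, PreTrainedModel)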