diff --git a/llmfoundry/models/hf/__init__.py b/llmfoundry/models/hf/__init__.py
index b34281cd81..57892e0dda 100644
--- a/llmfoundry/models/hf/__init__.py
+++ b/llmfoundry/models/hf/__init__.py
@@ -9,8 +9,10 @@
 )
 from llmfoundry.models.hf.hf_t5 import ComposerHFT5
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
+from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel
 
 __all__ = [
+    'BaseHuggingFaceModel',
    'ComposerHFCausalLM',
    'ComposerHFT5',
    'prepare_hf_causal_lm_model_for_fsdp',
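Note (illustrative, not part of the diff): with the re-export above, downstream code can import the new base class from the package root rather than the private module path. A minimal sketch, assuming an llm-foundry install that includes this change:

```python
# Both paths resolve to the same class once this diff is applied;
# the package-root import is the one the new __all__ entry advertises.
from llmfoundry.models.hf import BaseHuggingFaceModel
from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel as _Direct

assert BaseHuggingFaceModel is _Direct
```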
+ """ + + model_cls: Union[_BaseAutoModelClass, + PreTrainedModel] = AutoModelForCausalLM + default_train_metrics: list = [] + default_eval_metrics: list = [] + + def __init__( + self, + pretrained_model_name_or_path: str, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + pretrained: bool = True, + pretrained_lora_id_or_path: Optional[str] = None, + trust_remote_code: bool = True, + use_auth_token: bool = False, + use_flash_attention_2: bool = False, + load_in_8bit: bool = False, + init_device: str = 'cpu', + config_overrides: Optional[dict[str, Any]] = None, + use_logits: bool = True, + shift_labels: bool = False, + peft_config: Optional['PeftConfig'] = None, + allow_embedding_resizing: bool = False, + use_train_metrics: bool = True, + additional_train_metrics: Optional[list] = None, + additional_eval_metrics: Optional[list] = None, + should_save_peft_only: bool = True, + ): + config_overrides = config_overrides or {} + + model = self.build_inner_model( + pretrained_model_name_or_path=pretrained_model_name_or_path, + pretrained_lora_id_or_path=pretrained_lora_id_or_path, + trust_remote_code=trust_remote_code, + init_device=init_device, + use_flash_attention_2=use_flash_attention_2, + use_auth_token=use_auth_token, + config_overrides=config_overrides, + load_in_8bit=load_in_8bit, + pretrained=pretrained, + ) + + model = self.transform_model(model) + + self.prepare_inner_model(model, init_device) + + metrics, eval_metrics = self.build_metrics( + use_train_metrics=use_train_metrics, + additional_train_metrics=additional_train_metrics, + additional_eval_metrics=additional_eval_metrics, + ) + + super().__init__( + model=model, + tokenizer=tokenizer, + use_logits=use_logits, + metrics=metrics, + eval_metrics=eval_metrics, + shift_labels=shift_labels, + allow_embedding_resizing=allow_embedding_resizing, + peft_config=peft_config, + should_save_peft_only=should_save_peft_only, + ) + + def forward(self, batch: Mapping): + if isinstance(batch, dict) or isinstance(batch, UserDict): + # Further input validation is left to the huggingface forward call + batch = { + k: v for k, v in batch.items() if k in self.model_forward_args + } + output = self.model(**batch) # type: ignore (thirdparty) + else: + raise ValueError( + 'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model', + ) + return output + + def loss(self, outputs: ModelOutput, batch: Mapping): + if self.config.use_return_dict: + return outputs['loss'] + # loss is at index 0 in the output tuple, logits are at index 1 + return outputs[:2] + + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + """Transforms the model after initialization. + + Args: + model (PreTrainedModel): The model to transform. + + Returns: + PreTrainedModel: The transformed model. 
+ """ + return model + + @classmethod + def build_config( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: bool, + use_auth_token: bool, + attn_implementation: str, + config_overrides: dict[str, Any], + **kwargs: Any, + ) -> PretrainedConfig: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + attn_implementation=attn_implementation, + use_cache= + False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 + ) + + set_config_overrides(config, config_overrides) + + return config + + @classmethod + def build_metrics( + cls, + use_train_metrics: bool, + additional_train_metrics: Optional[list[str]] = None, + additional_eval_metrics: Optional[list[str]] = None, + ) -> tuple[list[Metric], list[Metric]]: + """Builds the training and evaluation metrics for the model. + + Args: + use_train_metrics (bool): Whether to use training metrics. + additional_train_metrics (Optional[List[str]]): Additional training metrics to include. + additional_eval_metrics (Optional[List[str]]): Additional evaluation metrics to include. + + Returns: + Tuple[List[Metric], List[Metric]]: A tuple containing the list of training metrics and evaluation metrics. + """ + from llmfoundry.utils.builders import build_metric + + train_metric_names = cls.default_train_metrics + ( + additional_train_metrics or [] + ) + train_metrics = [ + build_metric(metric, {}) for metric in train_metric_names + ] if use_train_metrics else [] + eval_metric_names = cls.default_eval_metrics + ( + additional_eval_metrics or [] + ) + eval_metrics = [ + build_metric(metric, {}) for metric in eval_metric_names + ] + + return train_metrics, eval_metrics + + @classmethod + def build_inner_model( + cls, + pretrained_model_name_or_path: str, + pretrained_lora_id_or_path: Optional[str], + trust_remote_code: bool, + init_device: str, + use_flash_attention_2: bool, + use_auth_token: bool, + config_overrides: dict[str, Any], + load_in_8bit: bool, + pretrained: bool, + model_cls: Union[_BaseAutoModelClass, + PreTrainedModel] = AutoModelForCausalLM, + prepare_for_fsdp: bool = False, + ) -> Union[PreTrainedModel, 'PeftModel']: + """Builds the inner model for the ComposerHFCausalLM. + + Args: + pretrained_model_name_or_path (str): The pretrained model name or path. + pretrained_lora_id_or_path (Optional[str]): The pretrained LORA ID or path. + trust_remote_code (bool): Whether to trust remote code. + init_device (str): The initialization device. + use_flash_attention_2 (bool): Whether to use flash attention 2. + use_auth_token (bool): Whether to use an authentication token. + config_overrides (Dict[str, Any]): The configuration overrides. + load_in_8bit (bool): Whether to load in 8-bit. + pretrained (bool): Whether the model is pretrained. + model_cls (Union[Type, Type[PreTrainedModel]]): Kept for backwards compatibility. + prepare_for_fsdp (bool, optional): Kept for backwards compatilbility. + + Returns: + Union[PreTrainedModel, 'PeftModel']: The built inner model. + """ + if not trust_remote_code and pretrained_model_name_or_path.startswith( + 'mosaicml/mpt', + ): + raise ValueError( + 'trust_remote_code must be set to True for MPT models. 
Without this, the MPT model code will come from the transformers library, ' + + + 'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.', + ) + # Resolve "mixed" init device to either "cpu" or "meta" + resolved_init_device = hf_get_init_device(init_device) + requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager' + + if use_flash_attention_2 and not is_flash_v2_installed(): + raise ValueError( + 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. ' + + 'Please `pip install llm-foundry[gpu]`.', + ) + + if not ( + hasattr(cls.model_cls, 'from_pretrained') and + hasattr(cls.model_cls, 'from_config') + ): + raise AttributeError( + f'{cls.model_cls=} is missing `from_pretrained` and `from_config` support.', + ) + + # Hugging Face copies the modules into the + # transformers modules cache. On particular systems, this operation seems to cause contention between + # the different processes. To avoid this contention, we first create the config and generation config on local rank + # zero. This will set up the transformers module cache and avoid the future contention. + if dist.get_local_rank() == 0: + AutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + attn_implementation=requested_attention_implementation, + use_cache= + False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 + ) + try: + GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + ) + except OSError: + pass + + dist.barrier() + + # Construct the Hugging Face config to use + config = cls.build_config( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + attn_implementation=requested_attention_implementation, + config_overrides=config_overrides, + ) + + # We need to have all non-zero local ranks be not-pretrained + # Rank 0 will still be pretrained, and distribute the weights appropriately + if dist.get_local_rank() != 0 and init_device == 'mixed': + pretrained = False + + # Hugging Face copies the modules into the + # transformers modules cache. On particular systems, this operation seems to cause contention between + # the different processes. To avoid this contention, we first create the model (on meta device) on local rank + # zero. This will set up the transformers model cache and avoid the future contention. 
+ if dist.get_local_rank() == 0: + if pretrained and os.path.isdir(pretrained_model_name_or_path): + with init_empty_weights(include_buffers=False): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + cls.model_cls.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + attn_implementation= + requested_attention_implementation, + config=config, + ) + else: + with init_empty_weights(include_buffers=False): + cls.model_cls.from_config( + config, + trust_remote_code=trust_remote_code, + attn_implementation=requested_attention_implementation, + ) + + dist.barrier() + + # initialize the model on the correct device + if resolved_init_device == 'cpu': + if pretrained: + model = cls.model_cls.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + load_in_8bit=load_in_8bit, + attn_implementation=requested_attention_implementation, + config=config, + ) + else: + model = cls.model_cls.from_config( + config, + trust_remote_code=trust_remote_code, + attn_implementation=requested_attention_implementation, + ) + elif resolved_init_device == 'meta': + if pretrained: + raise ValueError( + 'Setting cfg.pretrained=True is not supported when init_device="meta".', + ) + with init_empty_weights(include_buffers=False): + model = cls.model_cls.from_config( + config, + trust_remote_code=trust_remote_code, + attn_implementation=requested_attention_implementation, + ) + else: + raise ValueError( + f'init_device="{init_device}" must be either "cpu" or "meta".', + ) + + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed' + if dist.get_local_rank() == 0: + with open(signal_file_path, 'wb') as f: + f.write(b'local_rank0_completed_download') + + # Avoid the collective call until the local rank zero has finished trying to download the checkpoint + # so that we don't timeout for large downloads. This syncs all processes on the node + with dist.local_rank_zero_download_and_wait(signal_file_path): + # Then, wait to ensure every node has finished downloading the checkpoint + dist.barrier() + + if dist.get_local_rank() == 0: + os.remove(signal_file_path) + + # Use the pretrained generation config for the model if it exists. + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + ) + except OSError: + log.warning( + f'No existing generation config found for the model with name or path {pretrained_model_name_or_path}. Using default generation config.', + ) + + # Hugging Face's weight tying does not succeed if the model is inited on meta device + # so we manually apply the weight tying here + if model.config.tie_word_embeddings and resolved_init_device == 'meta': + model.tie_weights() + + if pretrained_lora_id_or_path is not None: + if not peft_installed: + raise ValueError( + 'PEFT is not installed, but lora_id_or_path was passed. 
Please install LLM Foundry with the peft extra to use lora_id_or_path.', + ) + from peft import PeftModelForCausalLM + model = PeftModelForCausalLM.from_pretrained( + model, + pretrained_lora_id_or_path, + ) + + return model + + def get_peft_config(self, peft_config_dict: dict[str, Any]) -> 'PeftConfig': + if peft_installed: + from peft import LoraConfig + peft_type = peft_config_dict.get('peft_type', '') + if peft_type.upper() != 'LORA': + raise ValueError( + f'Only LORA is supported for peft_type, but got {peft_type}.', + ) + task_type = peft_config_dict.get('task_type', '') + if task_type.upper() != 'CAUSAL_LM': + raise ValueError( + f'Only CAUSAL_LM is supported for task_type, but got {task_type}.', + ) + return LoraConfig(**peft_config_dict) + else: + raise ValueError( + 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.', + ) + + @staticmethod + def prepare_inner_model( + model: Union[transformers.PreTrainedModel, 'PeftModel'], + init_device: Optional[str] = None, + ): + """Prepare the inner model for FSDP wrapping. + + Args: + model: The model to prepare. + init_device: The device to initialize the model on. + """ + # Note: We need to add the FSDP related attributes to the model AFTER the super init, + # so that the (possible) embedding resizing doesn't destroy them + prepare_hf_model_for_fsdp(model, init_device) + + # This provides support for meta initialization when using FSDP + model.param_init_fn = lambda module: model._init_weights(module) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index fab79cd742..66b8389738 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -4,22 +4,14 @@ """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`.""" import logging -import os -import warnings from typing import ( - TYPE_CHECKING, Any, Optional, Union, ) -from composer.models.huggingface import peft_installed -from composer.utils import dist -from torchmetrics import Metric from transformers import ( - AutoConfig, AutoModelForCausalLM, - GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, ) @@ -29,21 +21,14 @@ DEFAULT_CAUSAL_LM_EVAL_METRICS, DEFAULT_CAUSAL_LM_TRAIN_METRICS, ) -from llmfoundry.models.hf.hf_fsdp import hf_get_init_device -from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP -from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.utils import init_empty_weights -from llmfoundry.utils.config_utils import set_config_overrides - -if TYPE_CHECKING: - from peft import PeftConfig, PeftModel +from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel __all__ = ['ComposerHFCausalLM'] log = logging.getLogger(__name__) -class ComposerHFCausalLM(HuggingFaceModelWithFSDP): +class ComposerHFCausalLM(BaseHuggingFaceModel): """Configures a :class:`.HuggingFaceModel` around a Causal LM. 
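Note (illustrative, not part of the diff): `BaseHuggingFaceModel` concentrates its extension points in the `model_cls`/default-metric class attributes, `transform_model`, and `build_config`. A minimal sketch of a custom subclass; the class name and the parameter-freezing rule are hypothetical, and the metric name assumes llm-foundry's metric registry:

```python
# Hypothetical subclass; `MyCausalLM` and its freezing rule are
# illustrative only.
from transformers import AutoModelForCausalLM, PreTrainedModel

from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel


class MyCausalLM(BaseHuggingFaceModel):
    model_cls = AutoModelForCausalLM
    # Metric names are resolved via llmfoundry.utils.builders.build_metric
    default_train_metrics: list = ['language_cross_entropy']
    default_eval_metrics: list = []

    def transform_model(self, model: PreTrainedModel) -> PreTrainedModel:
        # Runs after build_inner_model and before prepare_inner_model;
        # e.g. freeze everything except the LM head.
        for name, param in model.named_parameters():
            if 'lm_head' not in name:
                param.requires_grad = False
        return model
```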
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index fab79cd742..66b8389738 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -4,22 +4,14 @@
 """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
 
 import logging
-import os
-import warnings
 from typing import (
-    TYPE_CHECKING,
     Any,
     Optional,
     Union,
 )
 
-from composer.models.huggingface import peft_installed
-from composer.utils import dist
-from torchmetrics import Metric
 from transformers import (
-    AutoConfig,
     AutoModelForCausalLM,
-    GenerationConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
@@ -29,21 +21,14 @@
     DEFAULT_CAUSAL_LM_EVAL_METRICS,
     DEFAULT_CAUSAL_LM_TRAIN_METRICS,
 )
-from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
-from llmfoundry.models.layers.attention import is_flash_v2_installed
-from llmfoundry.models.utils import init_empty_weights
-from llmfoundry.utils.config_utils import set_config_overrides
-
-if TYPE_CHECKING:
-    from peft import PeftConfig, PeftModel
+from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel
 
 __all__ = ['ComposerHFCausalLM']
 
 log = logging.getLogger(__name__)
 
 
-class ComposerHFCausalLM(HuggingFaceModelWithFSDP):
+class ComposerHFCausalLM(BaseHuggingFaceModel):
     """Configures a :class:`.HuggingFaceModel` around a Causal LM.
 
     Args:
@@ -106,7 +91,6 @@ def __init__(
             load_in_8bit=load_in_8bit,
             init_device=init_device,
             config_overrides=config_overrides,
-            use_logits=True,
             shift_labels=True,
             peft_config=peft_config,
             allow_embedding_resizing=allow_embedding_resizing,
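Note (illustrative, not part of the diff): from a caller's perspective the refactor is behavior-preserving; `use_logits=True` is now simply inherited as the base-class default rather than passed explicitly. A sketch of construction, where the checkpoint name is only an example:

```python
# 'gpt2' stands in for any causal LM checkpoint on the Hub or on disk.
from transformers import AutoTokenizer

from llmfoundry.models.hf import ComposerHFCausalLM

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = ComposerHFCausalLM(
    tokenizer=tokenizer,
    pretrained_model_name_or_path='gpt2',
    pretrained=True,
    init_device='cpu',
)
```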
diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py
index 23b826a79f..1188090979 100644
--- a/llmfoundry/models/hf/hf_t5.py
+++ b/llmfoundry/models/hf/hf_t5.py
@@ -5,26 +5,26 @@
 
 from __future__ import annotations
 
-from typing import Mapping, Optional
+from typing import Any, Mapping, Optional, Union
 
-from composer.utils import dist
 from transformers import (
     AutoConfig,
+    AutoModelForSeq2SeqLM,
+    PreTrainedModel,
+    PretrainedConfig,
     PreTrainedTokenizerBase,
-    T5ForConditionalGeneration,
 )
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from llmfoundry.metrics import DEFAULT_ENC_DEC_METRICS
-from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
-from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
-from llmfoundry.models.utils import init_empty_weights
+from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel
 from llmfoundry.utils.warnings import experimental_class
 
 __all__ = ['ComposerHFT5']
 
 
 @experimental_class('ComposerHFT5')
-class ComposerHFT5(HuggingFaceModelWithFSDP):
+class ComposerHFT5(BaseHuggingFaceModel):
     """Configures a :class:`.HuggingFaceModel` around a T5.
 
     Note: This function uses `transformers.T5ForConditionalGeneration`. Future releases
@@ -45,23 +45,45 @@ class ComposerHFT5(HuggingFaceModelWithFSDP):
         tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
     """
 
+    model_cls: Union[_BaseAutoModelClass,
+                     PreTrainedModel] = AutoModelForSeq2SeqLM
+    default_train_metrics: list = DEFAULT_ENC_DEC_METRICS
+    default_eval_metrics: list = []
+
     def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
         pretrained_model_name_or_path: str,
-        pretrained: Optional[bool] = True,
+        pretrained: bool = True,
         trust_remote_code: bool = True,
         use_auth_token: bool = False,
-        config_overrides: Optional[Mapping] = None,
+        config_overrides: Optional[dict[str, Any]] = None,
         init_device: str = 'cpu',
         additional_train_metrics: Optional[list] = None,
         name: Optional[str] = None,
    ):
-        from llmfoundry.utils.builders import build_metric
-
-        config_overrides = config_overrides or {}
-        additional_train_metrics = additional_train_metrics or []
+        super().__init__(
+            pretrained_model_name_or_path,
+            tokenizer=tokenizer,
+            pretrained=pretrained,
+            trust_remote_code=trust_remote_code,
+            use_auth_token=use_auth_token,
+            init_device=init_device,
+            config_overrides=config_overrides,
+            shift_labels=True,
+            additional_train_metrics=additional_train_metrics,
+        )
 
+    @classmethod
+    def build_config(
+        cls,
+        pretrained_model_name_or_path: str,
+        trust_remote_code: bool,
+        use_auth_token: bool,
+        attn_implementation: str,
+        config_overrides: dict[str, Any],
+        **kwargs: Any,
+    ) -> PretrainedConfig:
         config = AutoConfig.from_pretrained(
             pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
@@ -92,43 +114,4 @@ def __init__(
         raise ValueError(f'Model type "hf_t5" currently only supports T5 models ' +\
                          f'using configs where `is_encoder_decoder` is ``True``.')
 
-        # Get the device we want to initialize, and use the
-        # resolved version to initialize the HF model
-        resolved_init_device = hf_get_init_device(init_device)
-
-        # We need to have all non-zero local ranks be not-pretrained
-        # Rank 0 will still be pretrained, and distribute the weights appropriately
-        if dist.get_local_rank() != 0 and init_device == 'mixed':
-            pretrained = False
-
-        if resolved_init_device == 'cpu':
-            if pretrained:
-                model = T5ForConditionalGeneration.from_pretrained(
-                    pretrained_model_name_or_path,
-                    config=config,
-                )
-            else:
-                model = T5ForConditionalGeneration(config)
-        elif resolved_init_device == 'meta':
-            if pretrained:
-                raise ValueError(
-                    'Setting cfg.pretrained=True is not supported when init_device="meta".',
-                )
-            with init_empty_weights(include_buffers=False):
-                model = T5ForConditionalGeneration(config)
-        else:
-            raise ValueError(
-                f'init_device="{init_device}" must be either "cpu" or "meta".',
-            )
-
-        metrics = [
-            build_metric(metric, {})
-            for metric in DEFAULT_ENC_DEC_METRICS + additional_train_metrics
-        ]
-
-        super().__init__(
-            model=model,
-            tokenizer=tokenizer,
-            metrics=metrics,
-            init_device=init_device,
-        )
+        return config
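Note (illustrative, not part of the diff): the T5 wrapper shows the pattern for model families that need extra config validation: override `build_config` and keep `__init__` thin. A condensed sketch of that pattern, where the `is_encoder_decoder` check mirrors the real one and the class name is hypothetical:

```python
# Sketch of the build_config override pattern; `MyEncDecModel` is
# a hypothetical subclass, not part of this PR.
from typing import Any

from transformers import AutoConfig, PretrainedConfig

from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel


class MyEncDecModel(BaseHuggingFaceModel):

    @classmethod
    def build_config(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: bool,
        use_auth_token: bool,
        attn_implementation: str,
        config_overrides: dict[str, Any],
        **kwargs: Any,
    ) -> PretrainedConfig:
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            use_auth_token=use_auth_token,
        )
        # Reject decoder-only checkpoints before any weights are loaded
        if not getattr(config, 'is_encoder_decoder', False):
            raise ValueError('This wrapper requires an encoder-decoder config.')
        return config
```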
""" - - model_cls: Union[_BaseAutoModelClass, - PreTrainedModel] = AutoModelForCausalLM - default_train_metrics: list = [] - default_eval_metrics: list = [] def __init__( self, - pretrained_model_name_or_path: str, + model: Union[transformers.PreTrainedModel, 'PeftModel'], tokenizer: Optional[PreTrainedTokenizerBase] = None, - pretrained: bool = True, - pretrained_lora_id_or_path: Optional[str] = None, - trust_remote_code: bool = True, - use_auth_token: bool = False, - use_flash_attention_2: bool = False, - load_in_8bit: bool = False, - init_device: str = 'cpu', - config_overrides: Optional[dict[str, Any]] = None, - use_logits: bool = True, + metrics: Optional[list[Metric]] = None, + eval_metrics: Optional[list[Metric]] = None, shift_labels: bool = False, - peft_config: Optional['PeftConfig'] = None, allow_embedding_resizing: bool = False, - use_train_metrics: bool = True, - additional_train_metrics: Optional[list] = None, - additional_eval_metrics: Optional[list] = None, + init_device: Optional[str] = None, + peft_config: Optional['PeftConfig'] = None, should_save_peft_only: bool = True, ): - config_overrides = config_overrides or {} - - model = self.build_inner_model( - pretrained_model_name_or_path=pretrained_model_name_or_path, - pretrained_lora_id_or_path=pretrained_lora_id_or_path, - trust_remote_code=trust_remote_code, - init_device=init_device, - use_flash_attention_2=use_flash_attention_2, - use_auth_token=use_auth_token, - config_overrides=config_overrides, - load_in_8bit=load_in_8bit, - pretrained=pretrained, + warnings.warn( + VersionedDeprecationWarning( + "`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.", + remove_version="0.12.0", + ), ) - - model = self.transform_model(model) - - self.prepare_inner_model(model, init_device) - - metrics, eval_metrics = self.build_metrics( - use_train_metrics=use_train_metrics, - additional_train_metrics=additional_train_metrics, - additional_eval_metrics=additional_eval_metrics, - ) - super().__init__( - model=model, - tokenizer=tokenizer, - use_logits=use_logits, + model, + tokenizer, + use_logits=True, metrics=metrics, eval_metrics=eval_metrics, shift_labels=shift_labels, @@ -112,6 +63,8 @@ def __init__( should_save_peft_only=should_save_peft_only, ) + self.prepare_inner_model(self.model, init_device) + def forward(self, batch: Mapping): if isinstance(batch, dict) or isinstance(batch, UserDict): # Further input validation is left to the huggingface forward call @@ -131,285 +84,6 @@ def loss(self, outputs: ModelOutput, batch: Mapping): # loss is at index 0 in the output tuple, logits are at index 1 return outputs[:2] - def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: - """Transforms the model after initialization. - - Args: - model (PreTrainedModel): The model to transform. - - Returns: - PreTrainedModel: The transformed model. 
- """ - return model - - @classmethod - def build_config( - cls, - pretrained_model_name_or_path: str, - trust_remote_code: bool, - use_auth_token: bool, - attn_implementation: str, - **kwargs: Any, - ) -> PretrainedConfig: - config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - use_auth_token=use_auth_token, - attn_implementation=attn_implementation, - use_cache= - False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 - ) - return config - - @classmethod - def build_metrics( - cls, - use_train_metrics: bool, - additional_train_metrics: Optional[list[str]] = None, - additional_eval_metrics: Optional[list[str]] = None, - ) -> tuple[list[Metric], list[Metric]]: - """Builds the training and evaluation metrics for the model. - - Args: - use_train_metrics (bool): Whether to use training metrics. - additional_train_metrics (Optional[List[str]]): Additional training metrics to include. - additional_eval_metrics (Optional[List[str]]): Additional evaluation metrics to include. - - Returns: - Tuple[List[Metric], List[Metric]]: A tuple containing the list of training metrics and evaluation metrics. - """ - from llmfoundry.utils.builders import build_metric - - train_metric_names = cls.default_train_metrics + ( - additional_train_metrics or [] - ) - train_metrics = [ - build_metric(metric, {}) for metric in train_metric_names - ] if use_train_metrics else [] - eval_metric_names = cls.default_eval_metrics + ( - additional_eval_metrics or [] - ) - eval_metrics = [ - build_metric(metric, {}) for metric in eval_metric_names - ] - - return train_metrics, eval_metrics - - @classmethod - def build_inner_model( - cls, - pretrained_model_name_or_path: str, - pretrained_lora_id_or_path: Optional[str], - trust_remote_code: bool, - init_device: str, - use_flash_attention_2: bool, - use_auth_token: bool, - config_overrides: dict[str, Any], - load_in_8bit: bool, - pretrained: bool, - ) -> Union[PreTrainedModel, 'PeftModel']: - """Builds the inner model for the ComposerHFCausalLM. - - Args: - pretrained_model_name_or_path (str): The pretrained model name or path. - pretrained_lora_id_or_path (Optional[str]): The pretrained LORA ID or path. - trust_remote_code (bool): Whether to trust remote code. - init_device (str): The initialization device. - use_flash_attention_2 (bool): Whether to use flash attention 2. - use_auth_token (bool): Whether to use an authentication token. - config_overrides (Dict[str, Any]): The configuration overrides. - load_in_8bit (bool): Whether to load in 8-bit. - pretrained (bool): Whether the model is pretrained. - - Returns: - Union[PreTrainedModel, 'PeftModel']: The built inner model. - """ - if not trust_remote_code and pretrained_model_name_or_path.startswith( - 'mosaicml/mpt', - ): - raise ValueError( - 'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, ' - + - 'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.', - ) - # Resolve "mixed" init device to either "cpu" or "meta" - resolved_init_device = hf_get_init_device(init_device) - requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager' - - if use_flash_attention_2 and not is_flash_v2_installed(): - raise ValueError( - 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. 
' - + 'Please `pip install llm-foundry[gpu]`.', - ) - - if not ( - hasattr(cls.model_cls, 'from_pretrained') and - hasattr(cls.model_cls, 'from_config') - ): - raise AttributeError( - f'{cls.model_cls=} is missing `from_pretrained` and `from_config` support.', - ) - - # Hugging Face copies the modules into the - # transformers modules cache. On particular systems, this operation seems to cause contention between - # the different processes. To avoid this contention, we first create the config and generation config on local rank - # zero. This will set up the transformers module cache and avoid the future contention. - if dist.get_local_rank() == 0: - AutoConfig.from_pretrained( - pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - use_auth_token=use_auth_token, - attn_implementation=requested_attention_implementation, - use_cache= - False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 - ) - try: - GenerationConfig.from_pretrained( - pretrained_model_name_or_path, - use_auth_token=use_auth_token, - ) - except OSError: - pass - - dist.barrier() - - # Construct the Hugging Face config to use - config = cls.build_config( - pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - use_auth_token=use_auth_token, - attn_implementation=requested_attention_implementation, - ) - - set_config_overrides(config, config_overrides) - - # We need to have all non-zero local ranks be not-pretrained - # Rank 0 will still be pretrained, and distribute the weights appropriately - if dist.get_local_rank() != 0 and init_device == 'mixed': - pretrained = False - - # Hugging Face copies the modules into the - # transformers modules cache. On particular systems, this operation seems to cause contention between - # the different processes. To avoid this contention, we first create the model (on meta device) on local rank - # zero. This will set up the transformers model cache and avoid the future contention. 
- if dist.get_local_rank() == 0: - if pretrained and os.path.isdir(pretrained_model_name_or_path): - with init_empty_weights(include_buffers=False): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', UserWarning) - cls.model_cls.from_pretrained( - pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - use_auth_token=use_auth_token, - attn_implementation= - requested_attention_implementation, - config=config, - ) - else: - with init_empty_weights(include_buffers=False): - cls.model_cls.from_config( - config, - trust_remote_code=trust_remote_code, - attn_implementation=requested_attention_implementation, - ) - - dist.barrier() - - # initialize the model on the correct device - if resolved_init_device == 'cpu': - if pretrained: - model = cls.model_cls.from_pretrained( - pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - use_auth_token=use_auth_token, - load_in_8bit=load_in_8bit, - attn_implementation=requested_attention_implementation, - config=config, - ) - else: - model = cls.model_cls.from_config( - config, - trust_remote_code=trust_remote_code, - attn_implementation=requested_attention_implementation, - ) - elif resolved_init_device == 'meta': - if pretrained: - raise ValueError( - 'Setting cfg.pretrained=True is not supported when init_device="meta".', - ) - with init_empty_weights(include_buffers=False): - model = cls.model_cls.from_config( - config, - trust_remote_code=trust_remote_code, - attn_implementation=requested_attention_implementation, - ) - else: - raise ValueError( - f'init_device="{init_device}" must be either "cpu" or "meta".', - ) - - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed' - if dist.get_local_rank() == 0: - with open(signal_file_path, 'wb') as f: - f.write(b'local_rank0_completed_download') - - # Avoid the collective call until the local rank zero has finished trying to download the checkpoint - # so that we don't timeout for large downloads. This syncs all processes on the node - with dist.local_rank_zero_download_and_wait(signal_file_path): - # Then, wait to ensure every node has finished downloading the checkpoint - dist.barrier() - - if dist.get_local_rank() == 0: - os.remove(signal_file_path) - - # Use the pretrained generation config for the model if it exists. - try: - model.generation_config = GenerationConfig.from_pretrained( - pretrained_model_name_or_path, - use_auth_token=use_auth_token, - ) - except OSError: - log.warning( - f'No existing generation config found for the model with name or path {pretrained_model_name_or_path}. Using default generation config.', - ) - - # Hugging Face's weight tying does not succeed if the model is inited on meta device - # so we manually apply the weight tying here - if model.config.tie_word_embeddings and resolved_init_device == 'meta': - model.tie_weights() - - if pretrained_lora_id_or_path is not None: - if not peft_installed: - raise ValueError( - 'PEFT is not installed, but lora_id_or_path was passed. 
Please install LLM Foundry with the peft extra to use lora_id_or_path.', - ) - from peft import PeftModelForCausalLM - model = PeftModelForCausalLM.from_pretrained( - model, - pretrained_lora_id_or_path, - ) - - return model - - def get_peft_config(self, peft_config_dict: dict[str, Any]) -> 'PeftConfig': - if peft_installed: - from peft import LoraConfig - peft_type = peft_config_dict.get('peft_type', '') - if peft_type.upper() != 'LORA': - raise ValueError( - f'Only LORA is supported for peft_type, but got {peft_type}.', - ) - task_type = peft_config_dict.get('task_type', '') - if task_type.upper() != 'CAUSAL_LM': - raise ValueError( - f'Only CAUSAL_LM is supported for task_type, but got {task_type}.', - ) - return LoraConfig(**peft_config_dict) - else: - raise ValueError( - 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.', - ) - @staticmethod def prepare_inner_model( model: Union[transformers.PreTrainedModel, 'PeftModel'],
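Note (illustrative, not part of the diff): `HuggingFaceModelWithFSDP` is kept only as a thin, deprecated shim around an already-constructed model, so migrating is a base-class swap that lets `BaseHuggingFaceModel` construct the inner model itself. A sketch under that assumption:

```python
# Before: subclasses built the model through HuggingFaceModelWithFSDP,
# which now only wraps an already-constructed `model` and emits a
# VersionedDeprecationWarning (removal slated for 0.12.0 per the diff).
# After: subclass BaseHuggingFaceModel and pass the checkpoint name;
# the base class handles construction, FSDP prep, and metrics.
from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel


class MyWrapper(BaseHuggingFaceModel):  # hypothetical migration target
    pass
```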