diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index da98600a45..80acddb9c5 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -12,6 +12,7 @@
 import transformers.modelcard
 from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
+from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
@@ -134,6 +135,22 @@ def terminate_handler(_, __, model):
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
         trainer.save_model(cfg.output_dir)
+    elif cfg.deepspeed and is_deepspeed_zero3_enabled():
+        # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
+        trainer.accelerator.wait_for_everyone()
+        unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
+
+        # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
+        # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
+        # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
+        # For ZeRO Stages 1 and 2, models are saved as usual in the output directory.
+        # The model name saved is `pytorch_model.bin`
+        unwrapped_model.save_pretrained(
+            cfg.output_dir,
+            is_main_process=trainer.accelerator.is_main_process,
+            save_function=trainer.accelerator.save,
+            state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
+        )
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
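
For context (not part of the diff): whether `save_pretrained()` can write the full, unpartitioned weights under ZeRO-3 depends on the gather flag mentioned in the comment above. Below is a minimal sketch of the two ways that flag is typically enabled; the config dict and plugin wiring are illustrative assumptions, not files from this PR.

```python
# Illustrative sketch, not part of this PR: two ways to enable the 16-bit gather
# that the save path above relies on.

# Option 1: in the DeepSpeed JSON config referenced by cfg.deepspeed
# (only the relevant keys are shown).
ds_config = {
    "zero_optimization": {
        "stage": 3,
        # Gather full fp16 weights on save so an unpartitioned pytorch_model.bin is written.
        "stage3_gather_16bit_weights_on_model_save": True,
    }
}

# Option 2: the equivalent flag on accelerate's DeepSpeedPlugin,
# which would then be passed as Accelerator(deepspeed_plugin=plugin).
from accelerate import DeepSpeedPlugin

plugin = DeepSpeedPlugin(zero_stage=3, zero3_save_16bit_model=True)
```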