diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index 8cf6dbc28a..fce38cce29 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -582,6 +582,38 @@ Pay attention to the following best practices when training a model with that tr
 - For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
 - If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
 
+## Multi-GPU Training
+
+`Trainer` (and thus `SFTTrainer`) supports multi-GPU training. If you run your script with `python script.py`, it will default to DataParallel (DP) as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DistributedDataParallel (DDP), which is generally recommended (see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info), you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work, you must also check the following:
+- If you're using `gradient_checkpointing`, add `gradient_checkpointing_kwargs={'use_reentrant': False}` to the `TrainingArguments` (more info [here](https://github.com/huggingface/transformers/issues/26969)); a minimal sketch follows after this list.
+- Ensure that the model is placed on the correct device:
+```python
+from accelerate import PartialState
+from transformers import AutoModelForCausalLM
+
+device_string = PartialState().process_index
+model = AutoModelForCausalLM.from_pretrained(
+    ...,  # your model name or path
+    device_map={'': device_string}
+)
+```
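+
+Putting these together, here is a minimal sketch of the gradient checkpointing settings described above; the `output_dir` value and the `training_args` name are placeholders, not requirements:
+
+```python
+from transformers import TrainingArguments
+
+training_args = TrainingArguments(
+    output_dir="./output",  # placeholder output directory
+    gradient_checkpointing=True,
+    # Non-reentrant checkpointing is required for gradient checkpointing
+    # to work with DDP (see the issue linked above).
+    gradient_checkpointing_kwargs={"use_reentrant": False},
+)
+```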
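+
+You can then pass `training_args` to the [`SFTTrainer`] as its `args` and launch the script with `accelerate launch script.py`.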
+
 ## GPTQ Conversion
 
 You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.