[docs] MPS #28016
# PyTorch training on Apple silicon
Previously, training models on a Mac was limited to the CPU only. With the release of PyTorch v1.12, you can take advantage of Apple silicon GPUs for significantly faster training. This is powered in PyTorch by the integration of Apple's Metal Performance Shaders (MPS) as a backend. The [MPS backend](https://pytorch.org/docs/stable/notes/mps.html) implements PyTorch operations as custom Metal shaders and places these modules on an `mps` device.
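
In plain PyTorch, using the backend looks roughly like the sketch below (a minimal illustration assuming PyTorch 1.12+ with MPS support, not part of the original guide); the `mps` device is used like any other `torch.device`:

```python
import torch

# Assumes PyTorch >= 1.12 on macOS 12.3+ with an Apple silicon GPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")  # fall back if the MPS backend isn't available

# Tensors and modules are moved to the device like with any other backend.
model = torch.nn.Linear(8, 2).to(device)
x = torch.randn(4, 8, device=device)
print(model(x).device)  # mps:0 when the MPS backend is active
```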

<Tip warning={true}>

Some PyTorch operations are not implemented in MPS yet and will throw an error. To avoid this, set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU kernels instead (you'll still see a `UserWarning`); see the short example after this tip.
Review comment: Is there a way to have the trainer just use the CPU entirely and ignore the MPS backend?
Reply: I think you can set

<br>

If you run into any other errors, please open an issue in the [PyTorch](https://github.com/pytorch/pytorch/issues) repository because the [`Trainer`] only integrates the MPS backend.

</Tip>
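
As a minimal sketch of the fallback mentioned in the tip above, you can set the variable in Python before importing `torch` (or simply run `export PYTORCH_ENABLE_MPS_FALLBACK=1` in your shell before launching training):

```python
import os

# Set before importing torch so the MPS backend is sure to pick it up.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # ops without an MPS kernel now fall back to the CPU (with a UserWarning)
```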

With the `mps` device set, you can:

* train larger networks or batch sizes locally
* reduce data retrieval latency because the GPU's unified memory architecture allows direct access to the full memory store
* reduce costs because you don't need to train on cloud-based GPUs or add additional local GPUs

Get started by making sure you have PyTorch installed. MPS acceleration is supported on macOS 12.3+.

```bash
pip install torch torchvision torchaudio
```
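
To confirm the install can actually see the GPU, a quick check along these lines (an illustrative snippet, not part of the original guide) distinguishes a PyTorch build without MPS support from an unsupported macOS version or machine:

```python
import torch

# is_built(): was this PyTorch binary compiled with MPS support?
# is_available(): is the MPS backend usable here (Apple silicon, macOS 12.3+)?
print("MPS built:    ", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())
```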

[`TrainingArguments`] uses the `mps` device by default if it's available, which means you don't need to explicitly set the device. For example, you can run the [run_glue.py](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py) script with the MPS backend automatically enabled without making any changes.

```diff
export TASK_NAME=mrpc

python examples/pytorch/text-classification/run_glue.py \
  --model_name_or_path bert-base-cased \
  --task_name $TASK_NAME \
- --use_mps_device \
  --do_train \
  --do_eval \
  --max_seq_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir /tmp/$TASK_NAME/ \
  --overwrite_output_dir
```
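
The same default applies when you build a `Trainer` yourself. Here is a minimal sketch (the `output_dir` value is just a placeholder) showing that the device is picked up without any MPS-specific arguments:

```python
from transformers import TrainingArguments

# No device flag needed: on Apple silicon with PyTorch >= 1.12 the `mps`
# device is selected automatically when it is available.
args = TrainingArguments(output_dir="/tmp/mrpc")  # placeholder output directory
print(args.device)  # device(type='mps') on a supported Mac, otherwise CPU/CUDA as appropriate
```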

Backends for [distributed setups](https://pytorch.org/docs/stable/distributed.html#backends) like `gloo` and `nccl` are not supported by the `mps` device, which means you can only train on a single GPU with the MPS backend.

You can learn more about the MPS backend in the [Introducing Accelerated PyTorch Training on Mac](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/) blog post.
Review comment: Is this no longer the case?
Reply: I believe it's still true; I didn't see `mps` among the supported backends for `torch.distributed` (included in the second-to-last paragraph of the new doc).