Skip to content

Commit

Permalink
Make tensor device correct when ACCELERATE_TORCH_DEVICE is defined (#…
Browse files Browse the repository at this point in the history
…31751)

return correct device when ACCELERATE_TORCH_DEVICE is defined
  • Loading branch information
kiszk authored Jul 5, 2024
1 parent 8c5c180 commit 2aa2a14
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2194,7 +2194,9 @@ def _setup_devices(self) -> "torch.device":
# trigger an error that a device index is missing. Index 0 takes into account the
# GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
# will use the first GPU in that env, i.e. GPU#1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu")
)
# Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
# the default value.
self._n_gpu = torch.cuda.device_count()
Expand Down

0 comments on commit 2aa2a14

Please sign in to comment.