Enable user to specify torch device in run_training #1943

Open · wants to merge 1 commit into master
10 changes: 9 additions & 1 deletion nnunetv2/run/run_training.py
@@ -249,9 +249,17 @@ def run_training_entry():
                         help="Use this to set the device the training should run with. Available options are 'cuda' "
                              "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                              "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
+    parser.add_argument('--gpu_index', type=int, default=None, required=False,
+                        help="Index of the GPU to use.")
     args = parser.parse_args()
 
     assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
+    if args.gpu_index is not None:
+        assert args.device == 'cuda', 'The GPU index can only be selected when running on a GPU'
+
+        assert args.gpu_index < torch.cuda.device_count(), (f'Specified gpu index {args.gpu_index} is out of range of '
+                                                            f'available GPUs {torch.cuda.device_count()}')
+
     if args.device == 'cpu':
         # let's allow torch to use hella threads
         import multiprocessing
@@ -261,7 +269,7 @@ def run_training_entry():
         # multithreading in torch doesn't help nnU-Net if run on GPU
         torch.set_num_threads(1)
         torch.set_num_interop_threads(1)
-        device = torch.device('cuda')
+        device = torch.device('cuda' if args.gpu_index is None else f'cuda:{args.gpu_index}')
     else:
         device = torch.device('mps')

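Taken together, the run_training.py changes implement the device selection sketched below. This is a self-contained rework for illustration only: select_device is a hypothetical helper name (the PR keeps the logic inline in run_training_entry), the argparse setup is abbreviated, and falling back to a plain 'cuda' device when --gpu_index is omitted is an assumption, since the flag defaults to None.

import argparse

import torch


def select_device(device_name: str, gpu_index=None) -> torch.device:
    # hypothetical helper; mirrors the PR's checks but is not part of nnU-Net's API
    assert device_name in ['cpu', 'cuda', 'mps'], \
        f"-device must be either cpu, mps or cuda. Got: {device_name}."
    if gpu_index is not None:
        # the index is only meaningful for CUDA and must refer to a visible GPU
        assert device_name == 'cuda', 'The GPU index can only be selected when running on a GPU'
        assert gpu_index < torch.cuda.device_count(), \
            f'Specified gpu index {gpu_index} is out of range of available GPUs {torch.cuda.device_count()}'
    if device_name == 'cpu':
        return torch.device('cpu')
    if device_name == 'cuda':
        # assumption: no index means the default CUDA device (usually GPU 0)
        return torch.device('cuda' if gpu_index is None else f'cuda:{gpu_index}')
    return torch.device('mps')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-device', type=str, default='cuda', required=False)
    parser.add_argument('--gpu_index', type=int, default=None, required=False)
    args = parser.parse_args()
    print(select_device(args.device, args.gpu_index))

With this merged, a call along the lines of nnUNetv2_train DATASET_ID 3d_fullres 0 -device cuda --gpu_index 1 would pin training to the second GPU without having to export CUDA_VISIBLE_DEVICES.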
7 changes: 2 additions & 5 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -97,11 +97,8 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
                           f"{dist.get_world_size()}."
                           f"Setting device to {self.device}")
             self.device = torch.device(type='cuda', index=self.local_rank)
-        else:
-            if self.device.type == 'cuda':
-                # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X
-                self.device = torch.device(type='cuda', index=0)
-            print(f"Using device: {self.device}")
+
+        print(f"Using device: {self.device}")
 
         # loading and saving this class for continuing from checkpoint should not happen based on pickling. This
         # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we
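The nnUNetTrainer change is the necessary counterpart: in the non-DDP branch the constructor used to rewrite any CUDA device to index 0, which would silently override a --gpu_index chosen in run_training; it now uses the device object it was given. A minimal sketch of the torch.device semantics this relies on (constructing the device objects does not require the GPUs to exist, so this runs on any machine with PyTorch installed):

import torch

# 'cuda' without an index leaves .index as None; work placed on it goes to
# the current default CUDA device (GPU 0 unless changed via torch.cuda.set_device)
unpinned = torch.device('cuda')
print(unpinned.index)  # None

# an explicit index pins everything allocated on this device object to that GPU
pinned = torch.device(type='cuda', index=1)
print(pinned.index)  # 1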