diff --git a/nnunetv2/run/run_training.py b/nnunetv2/run/run_training.py
index 93dd7598b..d5b1dd827 100644
--- a/nnunetv2/run/run_training.py
+++ b/nnunetv2/run/run_training.py
@@ -249,9 +249,17 @@ def run_training_entry():
                         help="Use this to set the device the training should run with. Available options are 'cuda' "
                              "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                              "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
+    parser.add_argument('--gpu_index', type=int, default=None, required=False,
+                        help="Index of the GPU to use.")
     args = parser.parse_args()
 
     assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
+    if args.gpu_index is not None:
+        assert args.device == 'cuda', 'The GPU index can only be selected when running on a GPU'
+
+        assert args.gpu_index < torch.cuda.device_count(), (f'Specified gpu index {args.gpu_index} is out of range of '
+                                                            f'available GPUs {torch.cuda.device_count()}')
+
     if args.device == 'cpu':
         # let's allow torch to use hella threads
         import multiprocessing
@@ -261,7 +269,7 @@ def run_training_entry():
         # multithreading in torch doesn't help nnU-Net if run on GPU
         torch.set_num_threads(1)
         torch.set_num_interop_threads(1)
-        device = torch.device('cuda')
+        device = torch.device('cuda' if args.gpu_index is None else f'cuda:{args.gpu_index}')
     else:
         device = torch.device('mps')
 
diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
index 97abdde0a..98a3da741 100644
--- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
+++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -97,11 +97,8 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
                   f"{dist.get_world_size()}."
                   f"Setting device to {self.device}")
             self.device = torch.device(type='cuda', index=self.local_rank)
-        else:
-            if self.device.type == 'cuda':
-                # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X
-                self.device = torch.device(type='cuda', index=0)
-            print(f"Using device: {self.device}")
+
+        print(f"Using device: {self.device}")
 
         # loading and saving this class for continuing from checkpoint should not happen based on pickling. This
         # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we