diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_Misalign.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_Misalign.py index 21ad286..daf06b3 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_Misalign.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_Misalign.py @@ -70,14 +70,13 @@ class nnUNetTrainer_Misalign(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) - self.num_epochs = 3 - self.print_to_log_file("\n#######################################################################\n" - "Please cite the following paper when using nnU-Net:\n" - "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " - "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " - "Nature methods, 18(2), 203-211.\n" + "Please cite the following paper when using misalignment augmentations:\n" + "Kovacs, Balint, et al.\n" + "Addressing image misalignments in multi-parametric prostate MRI\n" + "for enhanced computer-aided diagnosis of prostate cancer.\n" + "Scientific Reports 13.1 (2023): 19805.\n" "#######################################################################\n", also_print_to_console=True, add_timestamp=False) @@ -189,506 +188,3 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) return tr_transforms - - @staticmethod - def get_validation_transforms(deep_supervision_scales: Union[List, Tuple], - is_cascaded: bool = False, - foreground_labels: Union[Tuple[int, ...], List[int]] = None, - regions: List[Union[List[int], Tuple[int, ...], int]] = None, - ignore_label: int = None) -> AbstractTransform: - val_transforms = [] - val_transforms.append(RemoveLabelTransform(-1, 0)) - - if is_cascaded: - val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) - - val_transforms.append(RenameTransform('seg', 'target', True)) - - if regions is not None: - # the ignore label must also be converted - val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] - if ignore_label is not None else regions, - 'target', 'target')) - - if deep_supervision_scales is not None: - val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', - output_key='target')) - - val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) - val_transforms = Compose(val_transforms) - return val_transforms - - def set_deep_supervision_enabled(self, enabled: bool): - """ - This function is specific for the default architecture in nnU-Net. If you change the architecture, there are - chances you need to change this as well! - """ - if self.is_ddp: - self.network.module.decoder.deep_supervision = enabled - else: - self.network.decoder.deep_supervision = enabled - - def on_train_start(self): - if not self.was_initialized: - self.initialize() - - maybe_mkdir_p(self.output_folder) - - # make sure deep supervision is on in the network - self.set_deep_supervision_enabled(True) - - self.print_plans() - empty_cache(self.device) - - # maybe unpack - if self.unpack_dataset and self.local_rank == 0: - self.print_to_log_file('unpacking dataset...') - unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False, - num_processes=max(1, round(get_allowed_n_proc_DA() // 2))) - self.print_to_log_file('unpacking done...') - - if self.is_ddp: - dist.barrier() - - # dataloaders must be instantiated here because they need access to the training data which may not be present - # when doing inference - self.dataloader_train, self.dataloader_val = self.get_dataloaders() - - # copy plans and dataset.json so that they can be used for restoring everything we need for inference - save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False) - save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False) - - # we don't really need the fingerprint but its still handy to have it with the others - shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'), - join(self.output_folder_base, 'dataset_fingerprint.json')) - - # produces a pdf in output folder - self.plot_network_architecture() - - self._save_debug_information() - - # print(f"batch size: {self.batch_size}") - # print(f"oversample: {self.oversample_foreground_percent}") - - def on_train_end(self): - # dirty hack because on_epoch_end increments the epoch counter and this is executed afterwards. - # This will lead to the wrong current epoch to be stored - self.current_epoch -= 1 - self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth")) - self.current_epoch += 1 - - # now we can delete latest - if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")): - os.remove(join(self.output_folder, "checkpoint_latest.pth")) - - # shut down dataloaders - old_stdout = sys.stdout - with open(os.devnull, 'w') as f: - sys.stdout = f - if self.dataloader_train is not None: - self.dataloader_train._finish() - if self.dataloader_val is not None: - self.dataloader_val._finish() - sys.stdout = old_stdout - - empty_cache(self.device) - self.print_to_log_file("Training done.") - - def on_train_epoch_start(self): - self.network.train() - self.lr_scheduler.step(self.current_epoch) - self.print_to_log_file('') - self.print_to_log_file(f'Epoch {self.current_epoch}') - self.print_to_log_file( - f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}") - # lrs are the same for all workers so we don't need to gather them in case of DDP training - self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch) - - def train_step(self, batch: dict) -> dict: - data = batch['data'] - target = batch['target'] - - data = data.to(self.device, non_blocking=True) - if isinstance(target, list): - target = [i.to(self.device, non_blocking=True) for i in target] - else: - target = target.to(self.device, non_blocking=True) - - self.optimizer.zero_grad(set_to_none=True) - # Autocast is a little bitch. - # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. - # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) - # So autocast will only be active if we have a cuda device. - with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): - output = self.network(data) - # del data - l = self.loss(output, target) - - if self.grad_scaler is not None: - self.grad_scaler.scale(l).backward() - self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) - self.grad_scaler.step(self.optimizer) - self.grad_scaler.update() - else: - l.backward() - torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) - self.optimizer.step() - return {'loss': l.detach().cpu().numpy()} - - def on_train_epoch_end(self, train_outputs: List[dict]): - outputs = collate_outputs(train_outputs) - - if self.is_ddp: - losses_tr = [None for _ in range(dist.get_world_size())] - dist.all_gather_object(losses_tr, outputs['loss']) - loss_here = np.vstack(losses_tr).mean() - else: - loss_here = np.mean(outputs['loss']) - - self.logger.log('train_losses', loss_here, self.current_epoch) - - def on_validation_epoch_start(self): - self.network.eval() - - def validation_step(self, batch: dict) -> dict: - data = batch['data'] - target = batch['target'] - - data = data.to(self.device, non_blocking=True) - if isinstance(target, list): - target = [i.to(self.device, non_blocking=True) for i in target] - else: - target = target.to(self.device, non_blocking=True) - - # Autocast is a little bitch. - # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. - # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) - # So autocast will only be active if we have a cuda device. - with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): - output = self.network(data) - del data - l = self.loss(output, target) - - # we only need the output with the highest output resolution - output = output[0] - target = target[0] - - # the following is needed for online evaluation. Fake dice (green line) - axes = [0] + list(range(2, output.ndim)) - - if self.label_manager.has_regions: - predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() - else: - # no need for softmax - output_seg = output.argmax(1)[:, None] - predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) - predicted_segmentation_onehot.scatter_(1, output_seg, 1) - del output_seg - - if self.label_manager.has_ignore_label: - if not self.label_manager.has_regions: - mask = (target != self.label_manager.ignore_label).float() - # CAREFUL that you don't rely on target after this line! - target[target == self.label_manager.ignore_label] = 0 - else: - mask = 1 - target[:, -1:] - # CAREFUL that you don't rely on target after this line! - target = target[:, :-1] - else: - mask = None - - tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) - - tp_hard = tp.detach().cpu().numpy() - fp_hard = fp.detach().cpu().numpy() - fn_hard = fn.detach().cpu().numpy() - if not self.label_manager.has_regions: - # if we train with regions all segmentation heads predict some kind of foreground. In conventional - # (softmax training) there needs tobe one output for the background. We are not interested in the - # background Dice - # [1:] in order to remove background - tp_hard = tp_hard[1:] - fp_hard = fp_hard[1:] - fn_hard = fn_hard[1:] - - return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} - - def on_validation_epoch_end(self, val_outputs: List[dict]): - outputs_collated = collate_outputs(val_outputs) - tp = np.sum(outputs_collated['tp_hard'], 0) - fp = np.sum(outputs_collated['fp_hard'], 0) - fn = np.sum(outputs_collated['fn_hard'], 0) - - if self.is_ddp: - world_size = dist.get_world_size() - - tps = [None for _ in range(world_size)] - dist.all_gather_object(tps, tp) - tp = np.vstack([i[None] for i in tps]).sum(0) - - fps = [None for _ in range(world_size)] - dist.all_gather_object(fps, fp) - fp = np.vstack([i[None] for i in fps]).sum(0) - - fns = [None for _ in range(world_size)] - dist.all_gather_object(fns, fn) - fn = np.vstack([i[None] for i in fns]).sum(0) - - losses_val = [None for _ in range(world_size)] - dist.all_gather_object(losses_val, outputs_collated['loss']) - loss_here = np.vstack(losses_val).mean() - else: - loss_here = np.mean(outputs_collated['loss']) - - global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in - zip(tp, fp, fn)]] - mean_fg_dice = np.nanmean(global_dc_per_class) - self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) - self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) - self.logger.log('val_losses', loss_here, self.current_epoch) - - def on_epoch_start(self): - self.logger.log('epoch_start_timestamps', time(), self.current_epoch) - - def on_epoch_end(self): - self.logger.log('epoch_end_timestamps', time(), self.current_epoch) - - # todo find a solution for this stupid shit - self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) - self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) - self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in - self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]]) - self.print_to_log_file( - f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s") - - # handling periodic checkpointing - current_epoch = self.current_epoch - if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1): - self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth')) - - # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this - if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema: - self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1] - self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}") - self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth')) - - if self.local_rank == 0: - self.logger.plot_progress_png(self.output_folder) - - self.current_epoch += 1 - - def save_checkpoint(self, filename: str) -> None: - if self.local_rank == 0: - if not self.disable_checkpointing: - if self.is_ddp: - mod = self.network.module - else: - mod = self.network - if isinstance(mod, OptimizedModule): - mod = mod._orig_mod - - checkpoint = { - 'network_weights': mod.state_dict(), - 'optimizer_state': self.optimizer.state_dict(), - 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, - 'logging': self.logger.get_checkpoint(), - '_best_ema': self._best_ema, - 'current_epoch': self.current_epoch + 1, - 'init_args': self.my_init_kwargs, - 'trainer_name': self.__class__.__name__, - 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, - } - torch.save(checkpoint, filename) - else: - self.print_to_log_file('No checkpoint written, checkpointing is disabled') - - def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: - if not self.was_initialized: - self.initialize() - - if isinstance(filename_or_checkpoint, str): - checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) - # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not - # match. Use heuristic to make it match - new_state_dict = {} - for k, value in checkpoint['network_weights'].items(): - key = k - if key not in self.network.state_dict().keys() and key.startswith('module.'): - key = key[7:] - new_state_dict[key] = value - - self.my_init_kwargs = checkpoint['init_args'] - self.current_epoch = checkpoint['current_epoch'] - self.logger.load_checkpoint(checkpoint['logging']) - self._best_ema = checkpoint['_best_ema'] - self.inference_allowed_mirroring_axes = checkpoint[ - 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes - - # messing with state dict naming schemes. Facepalm. - if self.is_ddp: - if isinstance(self.network.module, OptimizedModule): - self.network.module._orig_mod.load_state_dict(new_state_dict) - else: - self.network.module.load_state_dict(new_state_dict) - else: - if isinstance(self.network, OptimizedModule): - self.network._orig_mod.load_state_dict(new_state_dict) - else: - self.network.load_state_dict(new_state_dict) - self.optimizer.load_state_dict(checkpoint['optimizer_state']) - if self.grad_scaler is not None: - if checkpoint['grad_scaler_state'] is not None: - self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) - - def perform_actual_validation(self, save_probabilities: bool = False): - self.set_deep_supervision_enabled(False) - self.network.eval() - - predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, device=self.device, verbose=False, - verbose_preprocessing=False, allow_tqdm=False) - predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, - self.dataset_json, self.__class__.__name__, - self.inference_allowed_mirroring_axes) - - with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool: - worker_list = [i for i in segmentation_export_pool._pool] - validation_output_folder = join(self.output_folder, 'validation') - maybe_mkdir_p(validation_output_folder) - - # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute - # the validation keys across the workers. - _, val_keys = self.do_split() - if self.is_ddp: - val_keys = val_keys[self.local_rank:: dist.get_world_size()] - - dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, - folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, - num_images_properties_loading_threshold=0) - - next_stages = self.configuration_manager.next_stage_names - - if next_stages is not None: - _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages] - - results = [] - - for k in dataset_val.keys(): - proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, - allowed_num_queued=2) - while not proceed: - sleep(0.1) - proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, - allowed_num_queued=2) - - self.print_to_log_file(f"predicting {k}") - data, seg, properties = dataset_val.load_case(k) - - if self.is_cascaded: - data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels, - output_dtype=data.dtype))) - with warnings.catch_warnings(): - # ignore 'The given NumPy array is not writable' warning - warnings.simplefilter("ignore") - data = torch.from_numpy(data) - - output_filename_truncated = join(validation_output_folder, k) - - try: - prediction = predictor.predict_sliding_window_return_logits(data) - except RuntimeError: - predictor.perform_everything_on_gpu = False - prediction = predictor.predict_sliding_window_return_logits(data) - predictor.perform_everything_on_gpu = True - - prediction = prediction.cpu() - - # this needs to go into background processes - results.append( - segmentation_export_pool.starmap_async( - export_prediction_from_logits, ( - (prediction, properties, self.configuration_manager, self.plans_manager, - self.dataset_json, output_filename_truncated, save_probabilities), - ) - ) - ) - # for debug purposes - # export_prediction(prediction_for_export, properties, self.configuration, self.plans, self.dataset_json, - # output_filename_truncated, save_probabilities) - - # if needed, export the softmax prediction for the next stage - if next_stages is not None: - for n in next_stages: - next_stage_config_manager = self.plans_manager.get_configuration(n) - expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name, - next_stage_config_manager.data_identifier) - - try: - # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented - tmp = nnUNetDataset(expected_preprocessed_folder, [k], - num_images_properties_loading_threshold=0) - d, s, p = tmp.load_case(k) - except FileNotFoundError: - self.print_to_log_file( - f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! " - f"Run the preprocessing for this configuration first!") - continue - - target_shape = d.shape[1:] - output_folder = join(self.output_folder_base, 'predicted_next_stage', n) - output_file = join(output_folder, k + '.npz') - - # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties, - # self.dataset_json) - results.append(segmentation_export_pool.starmap_async( - resample_and_save, ( - (prediction, target_shape, output_file, self.plans_manager, - self.configuration_manager, - properties, - self.dataset_json), - ) - )) - - _ = [r.get() for r in results] - - if self.is_ddp: - dist.barrier() - - if self.local_rank == 0: - metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'), - validation_output_folder, - join(validation_output_folder, 'summary.json'), - self.plans_manager.image_reader_writer_class(), - self.dataset_json["file_ending"], - self.label_manager.foreground_regions if self.label_manager.has_regions else - self.label_manager.foreground_labels, - self.label_manager.ignore_label, chill=True) - self.print_to_log_file("Validation complete", also_print_to_console=True) - self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) - - self.set_deep_supervision_enabled(True) - compute_gaussian.cache_clear() - - def run_training(self): - self.on_train_start() - - for epoch in range(self.current_epoch, self.num_epochs): - self.on_epoch_start() - - self.on_train_epoch_start() - train_outputs = [] - for batch_id in range(self.num_iterations_per_epoch): - train_outputs.append(self.train_step(next(self.dataloader_train))) - self.on_train_epoch_end(train_outputs) - - with torch.no_grad(): - self.on_validation_epoch_start() - val_outputs = [] - for batch_id in range(self.num_val_iterations_per_epoch): - val_outputs.append(self.validation_step(next(self.dataloader_val))) - self.on_validation_epoch_end(val_outputs) - - self.on_epoch_end() - - self.on_train_end()