diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 8d4096482..aa2fceea0 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -29,7 +29,7 @@
 from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
 from nnunetv2.utilities.helpers import empty_cache, dummy_context
 from nnunetv2.utilities.json_export import recursive_fix_for_json_export
-from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels, convert_labelmap_to_one_hot
 from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
 from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder
 
@@ -250,6 +250,10 @@ def predict_from_files(self,
         if len(list_of_lists_or_source_folder) == 0:
             return
 
+        if num_processes_preprocessing == 0 and num_processes_segmentation_export == 0:
+            return self._sequential_prediction(list_of_lists_or_source_folder, seg_from_prev_stage_files,
+                                               output_filename_truncated, save_probabilities)
+
         data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
                                                                                  seg_from_prev_stage_files,
                                                                                  output_filename_truncated,
@@ -257,6 +261,76 @@ def predict_from_files(self,
 
         return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)
 
+    def _load_data_for_prediction(self, input_file, input_seg, properties, preprocessor, plans_manager,
+                                  configuration_manager, dataset_json, label_manager):
+        if properties is not None:
+            data, seg = preprocessor.run_case_npy(
+                input_file,
+                input_seg,
+                properties,
+                plans_manager,
+                configuration_manager,
+                dataset_json)
+        else:
+            data, seg, properties = preprocessor.run_case(
+                input_file,
+                input_seg,
+                plans_manager,
+                configuration_manager,
+                dataset_json)
+
+        if input_seg is not None:
+            seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
+            data = np.vstack((data, seg_onehot))
+
+        data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        if self.device.type == 'cuda':
+            data = data.pin_memory()
+        return data, properties
+
+    @torch.inference_mode()
+    def _sequential_prediction(self, input_list_of_lists, seg_from_prev_stage_files,
+                               output_filename_truncated, save_probabilities):
+        ret = []
+        configuration_manager = self.configuration_manager
+        preprocessor = configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)
+        plans_manager = self.plans_manager
+        dataset_json = self.dataset_json
+        label_manager = plans_manager.get_label_manager(dataset_json)
+
+        for i in range(len(input_list_of_lists)):
+            ofile = output_filename_truncated[i] if output_filename_truncated is not None else None
+            data, properties = self._load_data_for_prediction(input_list_of_lists[i],
+                                                              seg_from_prev_stage_files[
+                                                                  i] if seg_from_prev_stage_files is not None else None,
+                                                              None,
+                                                              preprocessor, plans_manager, configuration_manager,
+                                                              dataset_json, label_manager)
+
+            if ofile is not None:
+                print(f'\nPredicting {os.path.basename(ofile)}:')
+            else:
+                print(f'\nPredicting image of shape {data.shape}:')
+
+            prediction = self.predict_logits_from_preprocessed_data(data)
+
+            if ofile is not None:
+                print('resampling and export')
+                export_prediction_from_logits(
+                    prediction, properties, self.configuration_manager, self.plans_manager, self.dataset_json, ofile,
+                    save_probabilities)
+                print(f'done with {os.path.basename(ofile)}')
+            else:
+                print('resampling')
+                ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
+                    prediction, self.plans_manager, self.configuration_manager, self.label_manager, properties,
+                    save_probabilities))
+                print(f'\nDone with image of shape {data.shape}:')
+
+        compute_gaussian.cache_clear()
+        empty_cache(self.device)
+        return ret
+
     def _internal_get_data_iterator_from_lists_of_filenames(self, input_list_of_lists: List[List[str]],
                                                             seg_from_prev_stage_files: Union[List[str], None],
                                                             output_filenames_truncated: Union[List[str], None],
@@ -418,6 +492,7 @@ def predict_from_data_iterator(self,
             empty_cache(self.device)
         return ret
 
+    @torch.inference_mode()
     def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
                                  segmentation_previous_stage: np.ndarray = None,
                                  output_file_truncated: str = None,
@@ -435,35 +510,29 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
                      you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
         image_properties must only have a 'spacing' key!
         """
-        ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
-                                       [output_file_truncated],
-                                       self.plans_manager, self.dataset_json, self.configuration_manager,
-                                       num_threads_in_multithreaded=1, verbose=self.verbose)
-        if self.verbose:
-            print('preprocessing')
-        dct = next(ppa)
+        data, properties = self._load_data_for_prediction(input_image, segmentation_previous_stage, image_properties,
+                                                          self.configuration_manager.preprocessor_class(
+                                                              verbose=self.verbose_preprocessing),
+                                                          self.plans_manager,
+                                                          self.configuration_manager,
+                                                          self.dataset_json,
+                                                          self.plans_manager.get_label_manager(self.dataset_json))
 
-        if self.verbose:
-            print('predicting')
-        predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
+        predicted_logits = self.predict_logits_from_preprocessed_data(data)
 
         if self.verbose:
             print('resampling to original shape')
         if output_file_truncated is not None:
-            export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,
+            export_prediction_from_logits(predicted_logits, properties, self.configuration_manager,
                                           self.plans_manager, self.dataset_json, output_file_truncated,
                                           save_or_return_probabilities)
         else:
-            ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
-                                                                              self.configuration_manager,
-                                                                              self.label_manager,
-                                                                              dct['data_properties'],
-                                                                              return_probabilities=
-                                                                              save_or_return_probabilities)
-            if save_or_return_probabilities:
-                return ret[0], ret[1]
-            else:
-                return ret
+            return convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
+                                                                               self.configuration_manager,
+                                                                               self.label_manager,
+                                                                               properties,
+                                                                               return_probabilities=
+                                                                               save_or_return_probabilities)
 
     def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
         """
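
For reference, a minimal usage sketch of the new code path introduced above: passing 0 for both worker counts makes predict_from_files call the new _sequential_prediction method (preprocessing, inference and export all in the main process, one case at a time) instead of the multiprocessing data iterator. The model, input and output paths below are placeholders, and the predictor/init arguments shown are the standard nnUNetPredictor options, not something added by this diff.

import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# Placeholder paths - substitute a real trained model folder and data folders.
model_folder = '/path/to/nnUNet_results/DatasetXXX/nnUNetTrainer__nnUNetPlans__3d_fullres'
input_folder = '/path/to/imagesTs'
output_folder = '/path/to/predictions'

predictor = nnUNetPredictor(device=torch.device('cuda', 0), verbose=False)
predictor.initialize_from_trained_model_folder(model_folder, use_folds=(0,),
                                               checkpoint_name='checkpoint_final.pth')

# Both worker counts set to 0 trigger the new sequential (single-process) path.
predictor.predict_from_files(input_folder, output_folder,
                             save_probabilities=False, overwrite=True,
                             num_processes_preprocessing=0,
                             num_processes_segmentation_export=0)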