Skip to content

Commit

Permalink
Adding sequential inference
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed May 31, 2024
1 parent d12a0c1 commit fef5a45
Showing 1 changed file with 91 additions and 22 deletions.
113 changes: 91 additions & 22 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels, convert_labelmap_to_one_hot
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder

Expand Down Expand Up @@ -250,13 +250,87 @@ def predict_from_files(self,
if len(list_of_lists_or_source_folder) == 0:
return

if num_processes_preprocessing == 0 and num_processes_segmentation_export == 0:
return self._sequential_prediction(list_of_lists_or_source_folder, seg_from_prev_stage_files,
output_filename_truncated, save_probabilities)

data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
seg_from_prev_stage_files,
output_filename_truncated,
num_processes_preprocessing)

return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)

def _load_data_for_prediction(self, input_file, input_seg, properties, preprocessor, plans_manager,
configuration_manager, dataset_json, label_manager):
if properties is not None:
data, seg = preprocessor.run_case_npy(
input_file,
input_seg,
properties,
plans_manager,
configuration_manager,
dataset_json)
else:
data, seg, properties = preprocessor.run_case(
input_file,
input_seg,
plans_manager,
configuration_manager,
dataset_json)

if input_seg is not None:
seg_onehot = convert_labelmap_to_one_hot(input_seg[0], label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))

data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
if self.device.type == 'cuda':
data = data.pin_memory()
return data, properties

@torch.inference_mode()
def _sequential_prediction(self, input_list_of_lists, seg_from_prev_stage_files,
output_filename_truncated, save_probabilities):
ret = []
configuration_manager = self.configuration_manager
preprocessor = configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)
plans_manager = self.plans_manager
dataset_json = self.dataset_json
label_manager = plans_manager.get_label_manager(dataset_json)

for i in range(len(input_list_of_lists)):
ofile = output_filename_truncated[i] if output_filename_truncated is not None else None
if ofile is not None:
print(f'\nPredicting {os.path.basename(ofile)}:')
else:
print(f'\nPredicting image of shape {data.shape}:')

data, properties = self._load_data_for_prediction(input_list_of_lists[i],
seg_from_prev_stage_files[
i] if seg_from_prev_stage_files is not None else None,
None,
preprocessor, plans_manager, configuration_manager,
dataset_json, label_manager)

prediction = self.predict_logits_from_preprocessed_data(data)

if ofile is not None:
print('resampling and export')
export_prediction_from_logits(
prediction, properties, self.configuration_manager, self.plans_manager, self.dataset_json, ofile,
save_probabilities)
print(f'done with {os.path.basename(ofile)}')
else:
print('resampling')
ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
prediction, self.plans_manager, self.configuration_manager, self.label_manager, properties,
save_probabilities))
print(f'\nDone with image of shape {data.shape}:')

compute_gaussian.cache_clear()
empty_cache(self.device)
return ret

def _internal_get_data_iterator_from_lists_of_filenames(self,
input_list_of_lists: List[List[str]],
seg_from_prev_stage_files: Union[List[str], None],
Expand Down Expand Up @@ -418,6 +492,7 @@ def predict_from_data_iterator(self,
empty_cache(self.device)
return ret

@torch.inference_mode()
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
segmentation_previous_stage: np.ndarray = None,
output_file_truncated: str = None,
Expand All @@ -435,35 +510,29 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
image_properties must only have a 'spacing' key!
"""
ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
[output_file_truncated],
self.plans_manager, self.dataset_json, self.configuration_manager,
num_threads_in_multithreaded=1, verbose=self.verbose)
if self.verbose:
print('preprocessing')
dct = next(ppa)
data, properties = self._load_data_for_prediction(input_image, segmentation_previous_stage, image_properties,
self.configuration_manager.preprocessor_class(
verbose=self.verbose_preprocessing),
self.plans_manager,
self.configuration_manager,
self.dataset_json,
self.plans_manager.get_label_manager(self.dataset_json))

if self.verbose:
print('predicting')
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
predicted_logits = self.predict_logits_from_preprocessed_data(data)

if self.verbose:
print('resampling to original shape')
if output_file_truncated is not None:
export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,
export_prediction_from_logits(predicted_logits, properties, self.configuration_manager,
self.plans_manager, self.dataset_json, output_file_truncated,
save_or_return_probabilities)
else:
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
self.configuration_manager,
self.label_manager,
dct['data_properties'],
return_probabilities=
save_or_return_probabilities)
if save_or_return_probabilities:
return ret[0], ret[1]
else:
return ret
return convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
self.configuration_manager,
self.label_manager,
properties,
return_probabilities=
save_or_return_probabilities)

def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
"""
Expand Down

0 comments on commit fef5a45

Please sign in to comment.