Add no_resampling function + remove unnecessary imports #2547

Open · wants to merge 1 commit into master
1 change: 0 additions & 1 deletion nnunetv2/dataset_conversion/Dataset137_BraTS21.py
@@ -1,6 +1,5 @@
 import multiprocessing
 import shutil
-from multiprocessing import Pool
 
 import SimpleITK as sitk
 import numpy as np
1 change: 0 additions & 1 deletion nnunetv2/dataset_conversion/convert_MSD_dataset.py
@@ -1,7 +1,6 @@
 import argparse
 import multiprocessing
 import shutil
-from multiprocessing import Pool
 from typing import Optional
 import SimpleITK as sitk
 from batchgenerators.utilities.file_and_folder_operations import *
1 change: 0 additions & 1 deletion nnunetv2/ensembling/ensemble.py
@@ -2,7 +2,6 @@
 import multiprocessing
 import shutil
 from copy import deepcopy
-from multiprocessing import Pool
 from typing import List, Union, Tuple
 
 import numpy as np
3 changes: 1 addition & 2 deletions nnunetv2/evaluation/evaluate_predictions.py
@@ -1,8 +1,7 @@
 import multiprocessing
 import os
 from copy import deepcopy
-from multiprocessing import Pool
-from typing import Tuple, List, Union, Optional
+from typing import Tuple, List, Union
 
 import numpy as np
 from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \
@@ -1,12 +1,11 @@
 import argparse
 from typing import Union
 
-from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, subfiles, save_json
+from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, save_json
 
 from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
 from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
 from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name
 from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
 from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
 
 
5 changes: 1 addition & 4 deletions nnunetv2/experiment_planning/verify_dataset_integrity.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import multiprocessing
-import re
-from multiprocessing import Pool
 from typing import Type
 
 import numpy as np
@@ -25,8 +23,7 @@
 from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
 from nnunetv2.paths import nnUNet_raw
 from nnunetv2.utilities.label_handling.label_handling import LabelManager
-from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
-    get_filenames_of_train_images_and_targets
+from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
 
 
 def verify_labels(label_file: str, readerclass: Type[BaseReaderWriter], expected_labels: List[int]) -> bool:
2 changes: 1 addition & 1 deletion nnunetv2/inference/data_iterators.py
@@ -1,6 +1,6 @@
 import multiprocessing
 import queue
-from torch.multiprocessing import Event, Process, Queue, Manager
+from torch.multiprocessing import Event, Queue, Manager
 
 from time import sleep
 from typing import Union, List
17 changes: 1 addition & 16 deletions nnunetv2/inference/export_prediction.py
@@ -1,5 +1,3 @@
-import os
-from copy import deepcopy
 from typing import Union, List
 
 import numpy as np
@@ -74,13 +72,6 @@ def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, tor
                                       plans_manager: PlansManager,
                                       dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
                                       save_probabilities: bool = False):
-    # if isinstance(predicted_array_or_file, str):
-    #     tmp = deepcopy(predicted_array_or_file)
-    #     if predicted_array_or_file.endswith('.npy'):
-    #         predicted_array_or_file = np.load(predicted_array_or_file)
-    #     elif predicted_array_or_file.endswith('.npz'):
-    #         predicted_array_or_file = np.load(predicted_array_or_file)['softmax']
-    #     os.remove(tmp)
 
     if isinstance(dataset_json_dict_or_file, str):
         dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)
@@ -111,13 +102,7 @@ def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape:
                       plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict,
                       dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes) \
         -> None:
-    # # needed for cascade
-    # if isinstance(predicted, str):
-    #     assert isfile(predicted), "If isinstance(segmentation_softmax, str) then " \
-    #                               "isfile(segmentation_softmax) must be True"
-    #     del_file = deepcopy(predicted)
-    #     predicted = np.load(predicted)
-    #     os.remove(del_file)
+
     old_threads = torch.get_num_threads()
     torch.set_num_threads(num_threads_torch)
 
1 change: 0 additions & 1 deletion nnunetv2/postprocessing/remove_connected_components.py
@@ -1,7 +1,6 @@
 import argparse
 import multiprocessing
 import shutil
-from multiprocessing import Pool
 from typing import Union, Tuple, List, Callable
 
 import numpy as np
18 changes: 15 additions & 3 deletions nnunetv2/preprocessing/resampling/default_resampling.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 import pandas as pd
-import sklearn
 import torch
 from batchgenerators.augmentations.utils import resize_segmentation
 from scipy.ndimage import map_coordinates
@@ -143,7 +142,6 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L
     if np.any(shape != new_shape):
         data = data.astype(float, copy=False)
         if do_separate_z:
-            # print("separate z, order in z is", order_z, "order inplane is", order)
             assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic'
             if axis == 0:
                 new_shape_2d = new_shape[1:]
@@ -191,13 +189,27 @@
                 else:
                     reshaped_final[c] = reshaped_here
         else:
             # print("no separate z, order", order)
             for c in range(data.shape[0]):
                 reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
         return reshaped_final
     else:
         # print("no resampling necessary")
         return data
 
 
+def no_resampling_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
+                                       new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
+                                       current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+                                       new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+                                       is_seg: bool = False,
+                                       order: int = 3, order_z: int = 0,
+                                       force_separate_z: Union[bool, None] = False,
+                                       separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
+    """
+    A simplified resampling function that bypasses actual resampling.
+    This approach makes it much easier to implement no-resampling training and inference.
+    """
+    return data
+
+
 if __name__ == '__main__':
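For orientation, here is a minimal usage sketch of the new no_resampling_data_or_seg_to_shape function. It takes the same arguments as the existing resampling functions but simply hands the input back, so no interpolation takes place. The array shape, spacings and target shape below are made up for illustration.

```python
import numpy as np
from nnunetv2.preprocessing.resampling.default_resampling import no_resampling_data_or_seg_to_shape

# a fake (c, z, y, x) image; shape, spacings and target shape are illustrative only
data = np.random.rand(1, 24, 128, 128).astype(np.float32)

out = no_resampling_data_or_seg_to_shape(data,
                                         new_shape=(32, 160, 160),
                                         current_spacing=(3.0, 1.0, 1.0),
                                         new_spacing=(1.0, 1.0, 1.0))

assert out is data  # the input is returned untouched; new_shape and new_spacing are ignored
```

To actually train or predict without resampling, one would presumably point the resampling-function entries of a plans configuration (for example resampling_fn_data / resampling_fn_seg / resampling_fn_probabilities in nnUNetPlans.json) at this function by name; those key names are stated here as an assumption and are not part of this diff.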
1 change: 0 additions & 1 deletion nnunetv2/run/run_training.py
@@ -261,7 +261,6 @@ def run_training_entry():
     assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
     if args.device == 'cpu':
         # let's allow torch to use hella threads
-        import multiprocessing
         torch.set_num_threads(multiprocessing.cpu_count())
         device = torch.device('cpu')
     elif args.device == 'cuda':
1 change: 0 additions & 1 deletion nnunetv2/training/loss/deep_supervision.py
@@ -1,4 +1,3 @@
-import torch
 from torch import nn
 
 