-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding focal loss, distance map weighted dice loss, and functionality for loading/transforming distance maps
- Loading branch information
1 parent
ac79a61
commit 5685a27
Showing
21 changed files
with
2,003 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
156 changes: 156 additions & 0 deletions
156
nnunetv2/experiment_planning/experiment_planners/distance_transform_experiment_planner.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import shutil | ||
from copy import deepcopy | ||
from typing import List, Union, Tuple | ||
import argparse | ||
|
||
import numpy as np | ||
import torch | ||
from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p | ||
from nnunetv2.configuration import ANISO_THRESHOLD | ||
from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props | ||
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json | ||
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed | ||
from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme | ||
from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape | ||
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name | ||
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA | ||
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans | ||
from nnunetv2.utilities.json_export import recursive_fix_for_json_export | ||
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets | ||
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner | ||
|
||
class DistanceTransformExperimentPlanner(ExperimentPlanner):
    """Experiment planner that augments the default nnU-Net plans with a
    per-dataset directory of precomputed distance transforms.

    Behaves exactly like :class:`ExperimentPlanner` except that every generated
    configuration ('2d' and, for 3D datasets, '3d_fullres') additionally carries
    a ``distance_maps_dir`` entry and forces ``batch_dice=True``.
    """

    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 8,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'DistanceTransformPlans',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        """Same signature as the parent planner; only the plans name default differs.

        The distance transforms are expected under
        ``<nnUNet_preprocessed>/<dataset>/distance_transforms``.
        NOTE(review): this class does not create or verify that directory —
        presumably a separate preprocessing step writes the maps there; confirm.
        """
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose)

        # Custom properties for distance maps
        self.distance_maps_dir = join(nnUNet_preprocessed, self.dataset_name, 'distance_transforms')

    def plan_experiment(self):
        """
        Extend the default plan_experiment function to incorporate distance maps into the plans.

        This mirrors the parent class' plan_experiment step by step; the only
        additions are the ``distance_maps_dir`` and ``batch_dice`` keys injected
        into each configuration. Returns the plans dict after saving it.
        """
        # Scratch cache passed to get_plans_for_configuration (parent API).
        _tmp = {}

        # Get transpose
        transpose_forward, transpose_backward = self.determine_transpose()

        # Get fullres spacing and transpose it
        fullres_spacing = self.determine_fullres_target_spacing()
        fullres_spacing_transposed = fullres_spacing[transpose_forward]

        # Get transposed new median shape (what we would have after resampling)
        new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in
                      zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])]
        new_median_shape = np.median(new_shapes, 0)
        new_median_shape_transposed = new_median_shape[transpose_forward]

        approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) *
                                             self.dataset_json['numTraining'])
        # Only run 3d if this is a 3d dataset (a leading median shape of 1 marks 2D data)
        if new_median_shape_transposed[0] != 1:
            plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed,
                                                              new_median_shape_transposed,
                                                              self.generate_data_identifier('3d_fullres'),
                                                              approximate_n_voxels_dataset, _tmp)
            plan_3d_fullres['distance_maps_dir'] = self.distance_maps_dir  # Add distance maps directory to the plans
            plan_3d_fullres['batch_dice'] = True
        else:
            plan_3d_fullres = None

        # 2D configuration (drops the first spatial axis of spacing/shape)
        plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:],
                                                   new_median_shape_transposed[1:],
                                                   self.generate_data_identifier('2d'), approximate_n_voxels_dataset,
                                                   _tmp)
        plan_2d['distance_maps_dir'] = self.distance_maps_dir  # Add distance maps directory to the plans
        plan_2d['batch_dice'] = True

        print('2D U-Net configuration:')
        print(plan_2d)
        print()

        # Median spacing and shape, just for reference when printing the plans
        median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward]
        median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward]

        # Instead of writing all that into the plans we just copy the original file. More files, but less crowded
        # per file.
        shutil.copy(join(self.raw_dataset_folder, 'dataset.json'),
                    join(nnUNet_preprocessed, self.dataset_name, 'dataset.json'))

        # JSON serialization adjustment
        plans = {
            'dataset_name': self.dataset_name,
            'plans_name': self.plans_identifier,
            'original_median_spacing_after_transp': [float(i) for i in median_spacing],
            'original_median_shape_after_transp': [int(round(i)) for i in median_shape],
            'image_reader_writer': self.determine_reader_writer().__name__,
            'transpose_forward': [int(i) for i in transpose_forward],
            'transpose_backward': [int(i) for i in transpose_backward],
            'configurations': {'2d': plan_2d},
            'experiment_planner_used': self.__class__.__name__,
            'label_manager': 'LabelManager',
            'foreground_intensity_properties_per_channel': self.dataset_fingerprint['foreground_intensity_properties_per_channel']
        }

        if plan_3d_fullres is not None:
            plans['configurations']['3d_fullres'] = plan_3d_fullres
            print('3D fullres U-Net configuration:')
            print(plan_3d_fullres)
            print()

        self.plans = plans
        self.save_plans(plans)
        return plans

    def save_plans(self, plans):
        """Write *plans* to ``<nnUNet_preprocessed>/<dataset>/<plans_name>.json``.

        Configurations already present in an existing plans file but absent
        from *plans* are carried over, so custom configurations added by the
        user survive re-planning. NOTE: mutates *plans* in place (JSON fixing
        and the merged configurations).
        """
        recursive_fix_for_json_export(plans)

        plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')

        # Avoid overwriting existing custom configurations every time this is executed.
        if isfile(plans_file):
            old_plans = load_json(plans_file)
            old_configurations = old_plans['configurations']
            for c in plans['configurations'].keys():
                if c in old_configurations.keys():
                    # freshly planned configurations win over stale stored ones
                    del (old_configurations[c])
            plans['configurations'].update(old_configurations)

        maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name))
        save_json(plans, plans_file, sort_keys=False)
        print(f"Plans were saved to {join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')}")
|
||
def test_distance_transform_experiment_planner(dataset_name_or_id=101):
    """Smoke-test the DistanceTransformExperimentPlanner on an existing dataset.

    Runs the planner end to end and asserts that every generated configuration
    carries the 'distance_maps_dir' key. Raises (AssertionError or whatever the
    planner throws) on failure so the script exits non-zero.
    """
    # Create an instance of the planner
    planner = DistanceTransformExperimentPlanner(dataset_name_or_id=dataset_name_or_id,
                                                 gpu_memory_target_in_gb=8)

    # Run the planner
    try:
        plans = planner.plan_experiment()
        print("Experiment planner executed successfully.")

        # Check if the plans have the distance_maps_dir property in their configurations
        assert plans['configurations'] is not None, "Configurations are missing from the plans."
        assert '2d' in plans['configurations'], "2D plan configuration is missing."
        assert 'distance_maps_dir' in plans['configurations']['2d'], "Distance maps directory is missing in 2D configuration."

        if '3d_fullres' in plans['configurations']:
            assert 'distance_maps_dir' in plans['configurations']['3d_fullres'], "Distance maps directory is missing in 3D fullres configuration."

    except Exception as e:
        print(f"Experiment planner test failed: {e}")
        # BUG FIX: previously the exception was swallowed here, so assertion
        # failures were reported as plain prints and the process exited 0.
        # Re-raise so failures are actually visible to callers/CI.
        raise
|
||
if __name__ == "__main__":
    # CLI entry point: pick the dataset to test the planner against.
    arg_parser = argparse.ArgumentParser(description="Test the DistanceTransformExperimentPlanner.")
    arg_parser.add_argument(
        '-d', '--dataset_id',
        type=str,
        default="test_dataset",
        help='Dataset name or ID to be used for testing.',
    )
    cli_args = arg_parser.parse_args()

    # Run the test function with the provided dataset ID
    test_distance_transform_experiment_planner(dataset_name_or_id=cli_args.dataset_id)
98 changes: 98 additions & 0 deletions
98
nnunetv2/training/data_augmentation/custom_transforms/distance_map_augmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import inspect | ||
import multiprocessing | ||
import os | ||
import shutil | ||
import sys | ||
import warnings | ||
from copy import deepcopy | ||
from datetime import datetime | ||
from time import time, sleep | ||
from typing import Tuple, Union, List | ||
|
||
import abc | ||
import numpy as np | ||
import torch | ||
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter | ||
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter | ||
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter | ||
from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p | ||
from batchgeneratorsv2.helpers.scalar_type import RandomScalar | ||
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform | ||
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform | ||
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast | ||
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform | ||
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform | ||
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform | ||
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \ | ||
RemoveRandomConnectedComponentFromOneHotEncodingTransform | ||
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform | ||
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform | ||
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform | ||
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform | ||
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform | ||
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms | ||
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform | ||
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform | ||
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform | ||
from batchgeneratorsv2.transforms.utils.random import RandomTransform | ||
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform | ||
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform | ||
from torch import autocast, nn | ||
from torch import distributed as dist | ||
from torch._dynamo import OptimizedModule | ||
from torch.cuda import device_count | ||
from torch.cuda.amp import GradScaler | ||
from torch.nn.parallel import DistributedDataParallel as DDP | ||
from torch.nn.functional import interpolate | ||
|
||
class Convert3DTo2DDistTransform(Convert3DTo2DTransform):
    """Pseudo-2D conversion that also handles the 'dist_map' entry.

    Records the distance map's channel count under 'nchannels_dist_map' so
    Convert2DTo3DDistTransform can undo the reshape later.
    """

    def apply(self, data_dict, **params):
        if 'dist_map' in data_dict.keys():
            # PERF FIX: the original deep-copied the entire tensor just to read
            # its first dimension; .shape is free to read on the original.
            data_dict['nchannels_dist_map'] = data_dict['dist_map'].shape[0]
        return super().apply(data_dict, **params)

    def _apply_to_dist_map(self, dist_map: torch.Tensor, **params) -> torch.Tensor:
        # Distance maps are reshaped exactly like images.
        return self._apply_to_image(dist_map, **params)
|
||
class Convert2DTo3DDistTransform(Convert2DTo3DTransform):
    """Inverse of Convert3DTo2DDistTransform: restores the 3D layout and
    cleans up the 'nchannels_dist_map' bookkeeping key."""

    def get_parameters(self, **data_dict) -> dict:
        # Forward only the channel-count keys that are actually present.
        wanted = ('nchannels_img', 'nchannels_seg', 'nchannels_regr_trg', 'nchannels_dist_map')
        return {key: data_dict[key] for key in wanted if key in data_dict}

    def apply(self, data_dict, **params):
        # Drop the bookkeeping entry before delegating; absent key is fine.
        data_dict.pop('nchannels_dist_map', None)
        return super().apply(data_dict, **params)
|
||
class SpatialDistTransform(SpatialTransform):
    """SpatialTransform that also warps the 'dist_map' entry.

    Distance maps are continuous-valued, so they are resampled with the same
    interpolation used for images (not the nearest-neighbour path used for segs).
    """

    def _apply_to_dist_map(self, dist_map, **params) -> torch.Tensor:
        warped = self._apply_to_image(dist_map, **params)
        return warped
|
||
class MirrorDistTransform(MirrorTransform):
    """MirrorTransform that also flips the 'dist_map' entry."""

    def _apply_to_dist_map(self, dist_map: torch.Tensor, **params) -> torch.Tensor:
        # Nothing to do when no mirror axes were sampled this call.
        if not params['axes']:
            return dist_map
        # params['axes'] indexes spatial axes; dim 0 of the tensor is channels,
        # hence the +1 offset.
        return torch.flip(dist_map, [ax + 1 for ax in params['axes']])
|
||
class DownsampleSegForDSDistTransform(DownsampleSegForDSTransform):
    """Deep-supervision downsampling that also produces scaled 'dist_map' copies.

    For each scale in self.ds_scales, either reuses the tensor (scale 1) or
    resamples it with torch interpolation, preserving the input dtype.
    """

    def _apply_to_dist_map(self, dist_map: torch.Tensor, **params) -> List[torch.Tensor]:
        results = []
        for s in self.ds_scales:
            # Scalar scale applies uniformly to all spatial axes.
            if not isinstance(s, (tuple, list)):
                s = [s] * (dist_map.ndim - 1)
            else:
                assert len(s) == dist_map.ndim - 1

            if all(i == 1 for i in s):
                results.append(dist_map)
            else:
                new_shape = [round(i * j) for i, j in zip(dist_map.shape[1:], s)]
                dtype = dist_map.dtype
                # BUG FIX: mode='bilinear' is only valid for 4D input to
                # F.interpolate; a 3D dist_map (c, x, y, z) becomes 5D after
                # [None] and previously raised. Pick the mode by rank
                # (unchanged behavior for 2D data).
                mode = 'bilinear' if dist_map.ndim == 3 else 'trilinear'
                # interpolate is not defined for short etc, so round-trip via float
                results.append(interpolate(dist_map[None].float(), new_shape, mode=mode)[0].to(dtype))
        return results
|
||
if __name__ == '__main__':
    # No standalone demo; these transforms are exercised through the training pipeline.
    pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .distance_transform_dataset import DistanceTransformDataset |
Oops, something went wrong.