Skip to content

Commit

Permalink
adding focal loss, distance map weighted dice loss, functionality for…
Browse files Browse the repository at this point in the history
… loading/transforming distance maps
  • Loading branch information
andy-s-ding committed Dec 2, 2024
1 parent ac79a61 commit 5685a27
Show file tree
Hide file tree
Showing 21 changed files with 2,003 additions and 5 deletions.
9 changes: 5 additions & 4 deletions nnunetv2/evaluation/evaluate_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ def evaluate_simple_entry_point():


if __name__ == '__main__':
folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr'
folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation'
output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json'
folder_ref = '/home/andyding/tbone-seg-nnunetv2/00_nnUNetv2_baseline_retrain/nnUNet_raw/Dataset101_TemporalBone/labelsTs'
folder_pred = '/home/andyding/tbone-seg-nnunetv2/00_nnUNetv2_baseline_retrain/nnUNet_trained_models/Dataset101_TemporalBone/nnUNetTrainer_300epochs__nnUNetPlans__3d_fullres/test_results/postprocessed'
output_file = '/home/andyding/tbone-seg-nnunetv2/00_nnUNetv2_baseline_retrain/nnUNet_trained_models/Dataset101_TemporalBone/nnUNetTrainer_300epochs__nnUNetPlans__3d_fullres/test_results/postprocessed/summary.json'

image_reader_writer = SimpleITKIO()
file_ending = '.nii.gz'
regions = labels_to_list_of_regions([1, 2])
regions = labels_to_list_of_regions(list(range(17)))
ignore_label = None
num_processes = 12
compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions, ignore_label,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import shutil
from copy import deepcopy
from typing import List, Union, Tuple
import argparse

import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p
from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme
from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner

class DistanceTransformExperimentPlanner(ExperimentPlanner):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 8,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'DistanceTransformPlans',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose)

# Custom properties for distance maps
self.distance_maps_dir = join(nnUNet_preprocessed, self.dataset_name, 'distance_transforms')

def plan_experiment(self):
"""
Extend the default plan_experiment function to incorporate distance maps into the plans.
"""
_tmp = {}

# Get transpose
transpose_forward, transpose_backward = self.determine_transpose()

# Get fullres spacing and transpose it
fullres_spacing = self.determine_fullres_target_spacing()
fullres_spacing_transposed = fullres_spacing[transpose_forward]

# Get transposed new median shape (what we would have after resampling)
new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in
zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])]
new_median_shape = np.median(new_shapes, 0)
new_median_shape_transposed = new_median_shape[transpose_forward]

approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) *
self.dataset_json['numTraining'])
# Only run 3d if this is a 3d dataset
if new_median_shape_transposed[0] != 1:
plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed,
new_median_shape_transposed,
self.generate_data_identifier('3d_fullres'),
approximate_n_voxels_dataset, _tmp)
plan_3d_fullres['distance_maps_dir'] = self.distance_maps_dir # Add distance maps directory to the plans
plan_3d_fullres['batch_dice'] = True
else:
plan_3d_fullres = None

# 2D configuration
plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:],
new_median_shape_transposed[1:],
self.generate_data_identifier('2d'), approximate_n_voxels_dataset,
_tmp)
plan_2d['distance_maps_dir'] = self.distance_maps_dir # Add distance maps directory to the plans
plan_2d['batch_dice'] = True

print('2D U-Net configuration:')
print(plan_2d)
print()

# Median spacing and shape, just for reference when printing the plans
median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward]
median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward]

# Instead of writing all that into the plans we just copy the original file. More files, but less crowded
# per file.
shutil.copy(join(self.raw_dataset_folder, 'dataset.json'),
join(nnUNet_preprocessed, self.dataset_name, 'dataset.json'))

# JSON serialization adjustment
plans = {
'dataset_name': self.dataset_name,
'plans_name': self.plans_identifier,
'original_median_spacing_after_transp': [float(i) for i in median_spacing],
'original_median_shape_after_transp': [int(round(i)) for i in median_shape],
'image_reader_writer': self.determine_reader_writer().__name__,
'transpose_forward': [int(i) for i in transpose_forward],
'transpose_backward': [int(i) for i in transpose_backward],
'configurations': {'2d': plan_2d},
'experiment_planner_used': self.__class__.__name__,
'label_manager': 'LabelManager',
'foreground_intensity_properties_per_channel': self.dataset_fingerprint['foreground_intensity_properties_per_channel']
}

if plan_3d_fullres is not None:
plans['configurations']['3d_fullres'] = plan_3d_fullres
print('3D fullres U-Net configuration:')
print(plan_3d_fullres)
print()

self.plans = plans
self.save_plans(plans)
return plans

def save_plans(self, plans):
recursive_fix_for_json_export(plans)

plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')

# Avoid overwriting existing custom configurations every time this is executed.
if isfile(plans_file):
old_plans = load_json(plans_file)
old_configurations = old_plans['configurations']
for c in plans['configurations'].keys():
if c in old_configurations.keys():
del (old_configurations[c])
plans['configurations'].update(old_configurations)

maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name))
save_json(plans, plans_file, sort_keys=False)
print(f"Plans were saved to {join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')}")

def test_distance_transform_experiment_planner(dataset_name_or_id=101):
# Create an instance of the planner
planner = DistanceTransformExperimentPlanner(dataset_name_or_id=dataset_name_or_id,
gpu_memory_target_in_gb=8)

# Run the planner
try:
plans = planner.plan_experiment()
print("Experiment planner executed successfully.")

# Check if the plans have the distance_maps_dir property in their configurations
assert plans['configurations'] is not None, "Configurations are missing from the plans."
assert '2d' in plans['configurations'], "2D plan configuration is missing."
assert 'distance_maps_dir' in plans['configurations']['2d'], "Distance maps directory is missing in 2D configuration."

if '3d_fullres' in plans['configurations']:
assert 'distance_maps_dir' in plans['configurations']['3d_fullres'], "Distance maps directory is missing in 3D fullres configuration."

except Exception as e:
print(f"Experiment planner test failed: {e}")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test the DistanceTransformExperimentPlanner.")
parser.add_argument('-d', '--dataset_id', type=str, default="test_dataset", help='Dataset name or ID to be used for testing.')
args = parser.parse_args()

# Run the test function with the provided dataset ID
test_distance_transform_experiment_planner(dataset_name_or_id=args.dataset_id)
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import inspect
import multiprocessing
import os
import shutil
import sys
import warnings
from copy import deepcopy
from datetime import datetime
from time import time, sleep
from typing import Tuple, Union, List

import abc
import numpy as np
import torch
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform
from batchgeneratorsv2.transforms.utils.random import RandomTransform
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
from torch import autocast, nn
from torch import distributed as dist
from torch._dynamo import OptimizedModule
from torch.cuda import device_count
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.functional import interpolate

class Convert3DTo2DDistTransform(Convert3DTo2DTransform):
def apply(self, data_dict, **params):
if 'dist_map' in data_dict.keys():
data_dict['nchannels_dist_map'] = deepcopy(data_dict['dist_map']).shape[0]
return super().apply(data_dict, **params)

def _apply_to_dist_map(self, dist_map: torch.Tensor, **params) -> torch.Tensor:
return self._apply_to_image(dist_map, **params)

class Convert2DTo3DDistTransform(Convert2DTo3DTransform):
def get_parameters(self, **data_dict) -> dict:
return {i: data_dict[i] for i in
['nchannels_img', 'nchannels_seg', 'nchannels_regr_trg', 'nchannels_dist_map']
if i in data_dict.keys()}

def apply(self, data_dict, **params):
if 'nchannels_dist_map' in data_dict.keys():
del data_dict['nchannels_dist_map']
return super().apply(data_dict, **params)

class SpatialDistTransform(SpatialTransform):
def _apply_to_dist_map(self, dist_map, **params) -> torch.Tensor:
return self._apply_to_image(dist_map, **params)

class MirrorDistTransform(MirrorTransform):
def _apply_to_dist_map(self, dist_map: torch.Tensor, **params) -> torch.Tensor:
if len(params['axes']) == 0:
return dist_map
axes = [i + 1 for i in params['axes']]
return torch.flip(dist_map, axes)

class DownsampleSegForDSDistTransform(DownsampleSegForDSTransform):
def _apply_to_dist_map(self, dist_map: torch.Tensor, **params) -> List[torch.Tensor]:
results = []
for s in self.ds_scales:
if not isinstance(s, (tuple, list)):
s = [s] * (dist_map.ndim - 1)
else:
assert len(s) == dist_map.ndim - 1

if all([i == 1 for i in s]):
results.append(dist_map)
else:
new_shape = [round(i * j) for i, j in zip(dist_map.shape[1:], s)]
dtype = dist_map.dtype
# interpolate is not defined for short etc
results.append(interpolate(dist_map[None].float(), new_shape, mode='bilinear')[0].to(dtype))
return results

if __name__ == '__main__':
pass
1 change: 1 addition & 0 deletions nnunetv2/training/dataloading/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .distance_transform_dataset import DistanceTransformDataset
Loading

0 comments on commit 5685a27

Please sign in to comment.