diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py index f1c283fd1..7250fb845 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py @@ -233,7 +233,7 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], tr_transforms.append( BrightnessGradientAdditiveTransform( - _transform_scale, + _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), max_strength=_brightness_gradient_additive_max_strength, mean_centered=False, @@ -245,7 +245,7 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], tr_transforms.append( LocalGammaTransform( - _transform_scale, + _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), _local_gamma_gamma, same_for_all_channels=False, @@ -353,15 +353,19 @@ def get_dataloaders(self): return mt_gen_train, mt_gen_val -def _transform_scale(x, y): + +def _brightnessadditive_localgamma_transform_scale(x, y): return np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))) + def _brightness_gradient_additive_max_strength(_x, _y): return np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5) + def _local_gamma_gamma(): return np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4) + class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5): def get_dataloaders(self): """