Skip to content

Commit

Permalink
allow resampling with torch
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Jun 7, 2024
1 parent 8bd1922 commit a737753
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,63 +294,6 @@ def __init__(self, dataset_name_or_id: Union[str, int],
self.max_dataset_covered = 1


class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
    """ResEnc-L experiment planner variant that routes all resampling through
    `resample_torch_fornnunet` (torch-based) instead of the parent's default
    resampling functions. Writes its plans under a distinct plans identifier so
    preprocessed data from the two planners never collide.
    """

    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 24,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        # Same construction as the parent; only the default plans_name differs
        # so that this planner's output is stored separately.
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """Return a data identifier unique across plans files.

        Configuration names are only unique *within* a plans file, so the
        identifier combines the plans identifier with the configuration name.
        """
        return f"{self.plans_identifier}_{configuration_name}"

    def determine_resampling(self, *args, **kwargs):
        """Select the resampling functions and kwargs for data and seg.

        Each returned function must be callable(data, current_spacing,
        new_spacing, **kwargs). Called from get_plans_for_configuration so the
        choice could vary per configuration; here it is torch resampling for
        both data and segmentation, differing only in the `is_seg` flag.
        """
        shared = {
            'force_separate_z': False,
            'memefficient_seg_resampling': False,
        }
        # Two separate dicts (not aliases) so downstream mutation of one
        # cannot affect the other.
        data_kwargs = {"is_seg": False, **shared}
        seg_kwargs = {"is_seg": True, **shared}
        return resample_torch_fornnunet, data_kwargs, resample_torch_fornnunet, seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """Select the function used to resample softmax predictions at export.

        The function must be callable(data, new_shape, current_spacing,
        new_spacing, **kwargs); new_shape is the target, the spacings are
        provided for optional use. Called from get_plans_for_configuration so
        it could differ per configuration.
        """
        export_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False,
        }
        return resample_torch_fornnunet, export_kwargs


if __name__ == '__main__':
# we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
pbar.update()
remaining = [i for i in remaining if i not in done]
sleep(0.1)
_ = [i.get() for i in r]

def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,
configuration_manager: ConfigurationManager) -> np.ndarray:
Expand Down

0 comments on commit a737753

Please sign in to comment.