diff --git a/petric.py b/petric.py
index c282e80..182e588 100755
--- a/petric.py
+++ b/petric.py
@@ -22,7 +22,6 @@ from traceback import print_exc
 
 import numpy as np
-from skimage.metrics import mean_squared_error, peak_signal_noise_ratio
 from tensorboardX import SummaryWriter
 
 import sirf.STIR as STIR
 
@@ -88,26 +87,58 @@ def __call__(self, algo: Algorithm):
         log.debug("...logged")
 
 
+class QualityMetrics(ImageQualityCallback):
+    """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
+    def __init__(self, reference_image, background_mask, foreground_mask, **kwargs):
+        super().__init__(reference_image, **kwargs)
+        self.background = np.where(background_mask.as_array() == 1)  # voxel indices of the background VOI
+        self.foreground = np.where(foreground_mask.as_array() == 1)  # voxel indices of the foreground VOI
+
+    def __call__(self, algorithm):
+        from skimage import metrics as sm
+
+        iteration = algorithm.iteration
+        if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration:
+            return
+        test_image = algorithm.x  # CIL or SIRF ImageData
+
+        # # (0) objective value
+        # objective = algorithm.get_last_objective(return_all=False)
+        # self.tb_summary_writer.add_scalar('objective', objective, iteration)
+
+        test_im_arr, ref_im_arr = test_image.as_array(), self.reference_image.as_array()
+
+        for filter_name, filter_func in self.filter.items():
+            # fall back to the unfiltered arrays when no filter is set
+            test_im, ref_im = map(filter_func, (test_im_arr, ref_im_arr)) if filter_func else (test_im_arr, ref_im_arr)
+
+            # (1) global metrics & statistics
+            norm = ref_im[self.background].mean()
+            self.tb_summary_writer.add_scalar(
+                f"RMSE_foreground{filter_name}",
+                np.sqrt(sm.mean_squared_error(test_im[self.foreground], ref_im[self.foreground])) / norm, iteration)
+            self.tb_summary_writer.add_scalar(
+                f"RMSE_background{filter_name}",
+                np.sqrt(sm.mean_squared_error(test_im[self.background], ref_im[self.background])) / norm, iteration)
+
+            # (2) local metrics & statistics
+            for roi_name, roi_inds in self.roi_indices.items():
+                # AEM: absolute error of the VOI mean (not to be confused with MAE)
+                self.tb_summary_writer.add_scalar(f"AEM_VOI_{roi_name}{filter_name}",
+                                                  np.abs(test_im[roi_inds].mean() - ref_im[roi_inds].mean()) / norm,
+                                                  iteration)
+
+
 class MetricsWithTimeout(cbks.Callback):
     """Stops the algorithm after `seconds`"""
-    def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, reference_image=None,
-                 verbose=1):
+    def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1):
         super().__init__(verbose)
         self._seconds = seconds
         self.callbacks = [
             cbks.ProgressCallback(),
             SaveIters(outdir=outdir),
             (tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
-
-        if reference_image:
-            roi_image_dict = {f'S{i}': STIR.ImageData(f'S{i}.hv') for i in range(1, 8)}
-            # NB: these metrics are for testing only.
-            # The final evaluation will use metrics described in https://github.com/SyneRBI/PETRIC/wiki
-            self.callbacks.append(
-                ImageQualityCallback(
-                    reference_image, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={
-                        'MSE': mean_squared_error, 'MAE': self.mean_absolute_error, 'PSNR': peak_signal_noise_ratio},
-                    statistics_dict={'MEAN': np.mean, 'STDDEV': np.std, 'MAX': np.max}))
+        self.tb = tb_cbk.tb
         self.reset()
 
     def reset(self, seconds=None):
@@ -144,7 +175,9 @@ def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
     return prior
 
 
-Dataset = namedtuple('Dataset', ['acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa'])
+Dataset = namedtuple('Dataset', [
+    'acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa', 'reference_image',
+    'background_mask', 'foreground_mask', 'voi_masks'])
 
 
 def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
@@ -165,7 +198,18 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
     penalty_strength = 1 / 700  # default choice
     prior = construct_RDP(penalty_strength, OSEM_image, kappa)
 
-    return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa)
+    reference_image = STIR.ImageData(str(srcdir / 'reference_image.hv')) if (srcdir /
+                                                                             'reference_image.hv').is_file() else None
+    background_mask = STIR.ImageData(str(srcdir / 'VOI_background.hv')) if (srcdir /
+                                                                            'VOI_background.hv').is_file() else None
+    foreground_mask = STIR.ImageData(str(srcdir / 'VOI_foreground.hv')) if (srcdir /
+                                                                            'VOI_foreground.hv').is_file() else None
+    voi_masks = {
+        voi.stem: STIR.ImageData(str(voi))
+        for voi in srcdir.glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'foreground')}
+
+    return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image,
+                   background_mask, foreground_mask, voi_masks)
 
 
 if SRCDIR.is_dir():
@@ -194,7 +238,12 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
     assert issubclass(Submission, Algorithm)
     for srcdir, outdir, metrics in data_dirs_metrics:
         data = get_data(srcdir=srcdir, outdir=outdir)
-        metrics[0].reset()  # timeout from now
+        metrics_cbk = metrics[0]
+        if data.reference_image is not None:
+            metrics_cbk.callbacks.append(
+                QualityMetrics(data.reference_image, data.background_mask, data.foreground_mask,
+                               tb_summary_writer=metrics_cbk.tb, roi_mask_dict=data.voi_masks))
+        metrics_cbk.reset()  # timeout from now
         algo = Submission(data)
         try:
             algo.run(np.inf, callbacks=metrics + submission_callbacks)
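For context, the new `QualityMetrics` callback logs the normalized RMSE over the foreground/background VOIs (RMSE divided by the mean reference-image value in the background) and the absolute error of each VOI mean (AEM), matching the wiki's metric definitions. A minimal usage sketch mirroring the runner loop in the patch above; the dataset path is a hypothetical placeholder, and `main` providing `Submission`/`submission_callbacks` is assumed as elsewhere in petric.py:

# Usage sketch: wire QualityMetrics into a local run (hypothetical paths; assumes the patched petric.py).
import numpy as np

from main import Submission, submission_callbacks  # the user-provided submission, as assumed by petric.py
from petric import MetricsWithTimeout, QualityMetrics, get_data

data = get_data(srcdir="/path/to/dataset", outdir="output")  # hypothetical dataset location
metrics = MetricsWithTimeout(outdir="output")
if data.reference_image is not None:  # QualityMetrics requires a reference (ground-truth) image
    metrics.callbacks.append(
        QualityMetrics(data.reference_image, data.background_mask, data.foreground_mask,
                       tb_summary_writer=metrics.tb, roi_mask_dict=data.voi_masks))
metrics.reset()  # start the timeout from now
algo = Submission(data)
algo.run(np.inf, callbacks=[metrics] + submission_callbacks)  # runs until timeout or convergence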