diff --git a/petric.py b/petric.py index 42bb270..a240530 100755 --- a/petric.py +++ b/petric.py @@ -27,7 +27,7 @@ import sirf.STIR as STIR from cil.optimisation.algorithms import Algorithm -from cil.optimisation.utilities import callbacks as cbks +from cil.optimisation.utilities import callbacks as cil_callbacks from img_quality_cil_stir import ImageQualityCallback log = logging.getLogger('petric') @@ -38,17 +38,31 @@ SRCDIR = Path("./data") -class SaveIters(cbks.Callback): +class Callback(cil_callbacks.Callback): + """ + CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`. + TODO: backport this class to CIL. + """ + def __init__(self, interval: int = 1 << 31, **kwargs): + super().__init__(**kwargs) + self.interval = interval + + def skip_iteration(self, algo: Algorithm) -> bool: + return algo.iteration % min(self.interval, + algo.update_objective_interval) != 0 and algo.iteration != algo.max_iteration + + +class SaveIters(Callback): """Saves `algo.x` as "iter_{algo.iteration:04d}.hv" and `algo.loss` in `csv_file`""" - def __init__(self, verbose=1, outdir=OUTDIR, csv_file='objectives.csv'): - super().__init__(verbose) + def __init__(self, outdir=OUTDIR, csv_file='objectives.csv', **kwargs): + super().__init__(**kwargs) self.outdir = Path(outdir) self.outdir.mkdir(parents=True, exist_ok=True) self.csv = csv.writer((self.outdir / csv_file).open("w", buffering=1)) self.csv.writerow(("iter", "objective")) def __call__(self, algo: Algorithm): - if algo.iteration % algo.update_objective_interval == 0 or algo.iteration == algo.max_iteration: + if not self.skip_iteration(algo): log.debug("saving iter %d...", algo.iteration) algo.x.write(str(self.outdir / f'iter_{algo.iteration:04d}.hv')) self.csv.writerow((algo.iteration, algo.get_last_loss())) @@ -57,10 +71,10 @@ def __call__(self, algo: Algorithm): algo.x.write(str(self.outdir / 'iter_final.hv')) -class StatsLog(cbks.Callback): +class StatsLog(Callback): """Log image slices & objective value""" - def __init__(self, verbose=1, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR): - super().__init__(verbose) + def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs): + super().__init__(**kwargs) self.transverse_slice = transverse_slice self.coronal_slice = coronal_slice self.vmax = vmax @@ -68,7 +82,7 @@ def __init__(self, verbose=1, transverse_slice=None, coronal_slice=None, vmax=No self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir)) def __call__(self, algo: Algorithm): - if algo.iteration % algo.update_objective_interval != 0 and algo.iteration != algo.max_iteration: + if self.skip_iteration(algo): return log.debug("logging iter %d...", algo.iteration) # initialise `None` values @@ -89,21 +103,22 @@ def __call__(self, algo: Algorithm): log.debug("...logged") -class QualityMetrics(ImageQualityCallback): +class QualityMetrics(ImageQualityCallback, Callback): """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" - def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs): - super().__init__(reference_image, **kwargs) + def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1 << 31, **kwargs): + # TODO: drop multiple inheritance once `interval` included in CIL + Callback.__init__(self, interval=interval) + ImageQualityCallback.__init__(self, reference_image, **kwargs) self.whole_object_indices = np.where(whole_object_mask.as_array()) self.background_indices = np.where(background_mask.as_array()) self.ref_im_arr = reference_image.as_array() self.norm = self.ref_im_arr[self.background_indices].mean() def __call__(self, algo: Algorithm): - iteration = algo.iteration - if iteration % algo.update_objective_interval != 0 and iteration != algo.max_iteration: + if self.skip_iteration(algo): return for tag, value in self.evaluate(algo.x).items(): - self.tb_summary_writer.add_scalar(tag, value, iteration) + self.tb_summary_writer.add_scalar(tag, value, algo.iteration) def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: assert not any(self.filter.values()), "Filtering not implemented" @@ -120,16 +135,16 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: return {**whole, **local} -class MetricsWithTimeout(cbks.Callback): +class MetricsWithTimeout(cil_callbacks.Callback): """Stops the algorithm after `seconds`""" - def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1): - super().__init__(verbose) + def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, **kwargs): + super().__init__(**kwargs) self._seconds = seconds self.callbacks = [ - cbks.ProgressCallback(), + cil_callbacks.ProgressCallback(), SaveIters(outdir=outdir), (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))] - self.tb = tb_cbk.tb + self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter self.reset() def reset(self, seconds=None):