From b1e31da2e74f203212f1f335fa242a9dff391cae Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 11 Sep 2024 15:29:17 +0100 Subject: [PATCH] TensorBoard: add sagittal slice --- petric.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/petric.py b/petric.py index bef5936..eaa45cd 100755 --- a/petric.py +++ b/petric.py @@ -73,10 +73,12 @@ def __call__(self, algo: Algorithm): class StatsLog(Callback): """Log image slices & objective value""" - def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs): + def __init__(self, transverse_slice=None, coronal_slice=None, sagittal_slice=None, vmax=None, logdir=OUTDIR, + **kwargs): super().__init__(**kwargs) self.transverse_slice = transverse_slice self.coronal_slice = coronal_slice + self.sagittal_slice = sagittal_slice self.vmax = vmax self.x_prev = None self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir)) @@ -89,6 +91,7 @@ def __call__(self, algo: Algorithm): # initialise `None` values self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice + self.sagittal_slice = algo.x.dimensions()[2] // 2 if self.sagittal_slice is None else self.sagittal_slice self.vmax = algo.x.max() if self.vmax is None else self.vmax self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t) @@ -97,9 +100,11 @@ def __call__(self, algo: Algorithm): self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t) self.x_prev = algo.x.clone() x_arr = algo.x.as_array() - self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0, - 1), algo.iteration, t) + self.tb.add_image("transverse", np.clip(x_arr[None, self.transverse_slice] / self.vmax, 0, 1), algo.iteration, + t) self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t) + self.tb.add_image("sagittal", np.clip(x_arr[None, :, :, self.sagittal_slice] / self.vmax, 0, 1), algo.iteration, + t) log.debug("...logged") @@ -148,7 +153,8 @@ def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_sl self.callbacks = [ cil_callbacks.ProgressCallback(), SaveIters(outdir=outdir), - (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))] + (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice, + sagittal_slice=sagittal_slice))] self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter self.reset()