Skip to content

Commit

Permalink
better progress output
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Oct 28, 2024
1 parent 67f4532 commit 5d2c136
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from scipy.ndimage import binary_erosion
from skimage.metrics import mean_squared_error as mse
from tensorboardX import SummaryWriter
from tqdm.auto import tqdm

import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm
Expand Down Expand Up @@ -157,7 +158,8 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) /
self.norm
for voi_name, voi_indices in sorted(self.voi_indices.items())}
return {**whole, **local}
self._evaluate_cache = {**whole, **local}
return self._evaluate_cache

def keys(self):
return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)]
Expand All @@ -181,11 +183,11 @@ def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int:
class MetricsWithTimeout(Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None,
**kwargs):
tqdm_class=tqdm, **kwargs):
super().__init__(**kwargs)
self._seconds = seconds
self.callbacks = [
cil_callbacks.ProgressCallback(),
cil_callbacks.ProgressCallback(desc=f"{TEAM}/{VERSION}/{outdir.name}"),
SaveIters(outdir=outdir, **kwargs),
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice,
sagittal_slice=sagittal_slice, **kwargs))]
Expand All @@ -205,6 +207,10 @@ def __call__(self, algo: Algorithm):
for c in self.callbacks:
c._time_ = time_excluding_metrics
c(algo)
if isinstance(self.callbacks[-1], QualityMetrics) and isinstance(self.callbacks[0],
cil_callbacks.ProgressCallback):
self.callbacks[0].pbar.set_postfix(
RMSE_whole_object=self.callbacks[-1]._evaluate_cache['RMSE_whole_object'], refresh=False)
self.offset += time() - now

@staticmethod
Expand Down

0 comments on commit 5d2c136

Please sign in to comment.