diff --git a/petric.py b/petric.py index f05665f..d86ce77 100755 --- a/petric.py +++ b/petric.py @@ -178,7 +178,7 @@ def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int: return np.where(res)[0][0] -class MetricsWithTimeout(cil_callbacks.Callback): +class MetricsWithTimeout(Callback): """Stops the algorithm after `seconds`""" def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None, **kwargs): @@ -186,9 +186,9 @@ def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_s self._seconds = seconds self.callbacks = [ cil_callbacks.ProgressCallback(), - SaveIters(outdir=outdir), + SaveIters(outdir=outdir, **kwargs), (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice, - sagittal_slice=sagittal_slice))] + sagittal_slice=sagittal_slice, **kwargs))] self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter self.reset()