diff --git a/petric.py b/petric.py index 7dd3b5a..135db80 100755 --- a/petric.py +++ b/petric.py @@ -92,6 +92,7 @@ class MetricsWithTimeout(cbks.Callback): def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, reference_image=None, verbose=1): super().__init__(verbose) + self._seconds = seconds self.callbacks = [ cbks.ProgressCallback(), SaveIters(outdir=outdir), @@ -106,8 +107,10 @@ def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_sl reference_image, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={ 'MSE': mean_squared_error, 'MAE': self.mean_absolute_error, 'PSNR': peak_signal_noise_ratio}, statistics_dict={'MEAN': np.mean, 'STDDEV': np.std, 'MAX': np.max})) + self.reset() - self.limit = time() + seconds + def reset(self, seconds=None): + self.limit = time() + (self._seconds if seconds is None else seconds) def __call__(self, algorithm: Algorithm): if (now := time()) > self.limit: @@ -186,6 +189,7 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): from main import Submission, submission_callbacks assert issubclass(Submission, Algorithm) for data, metrics in data_metrics_pairs: + metrics[0].reset() # timeout from now algo = Submission(data) try: algo.run(np.inf, callbacks=metrics + submission_callbacks)