NOTE(review): line breaks and indentation below were reconstructed from a whitespace-collapsed copy of this patch;
inline-comment padding on context/removed lines is best-effort. Regenerate with `git diff` before applying.

diff --git a/.gitignore b/.gitignore
index 8875dc4..0fb5791 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,4 @@ __pycache__/
 *.bak
 *.ahv
 *.hv
-*.v
\ No newline at end of file
+*.v
diff --git a/main_ISTA.py b/main_ISTA.py
index eeda971..17a577e 100644
--- a/main_ISTA.py
+++ b/main_ISTA.py
@@ -7,7 +7,7 @@
 """
 from cil.optimisation.algorithms import ISTA, Algorithm
 from cil.optimisation.functions import IndicatorBox, SGFunction
-from cil.optimisation.utilities import ConstantStepSize, Sampler, callbacks
+from cil.optimisation.utilities import ConstantStepSize, Preconditioner, Sampler, callbacks
 from petric import Dataset
 from sirf.contrib.partitioner import partitioner
 
@@ -28,14 +28,34 @@ def __call__(self, algorithm: Algorithm):
             raise StopIteration
 
 
+class MyPreconditioner(Preconditioner):
+    """
+    Example based on the row-sum of the Hessian of the log-likelihood. See: Tsai et al. Fast Quasi-Newton Algorithms
+    for Penalized Reconstruction in Emission Tomography and Further Improvements via Preconditioning,
+    IEEE TMI https://doi.org/10.1109/tmi.2017.2786865
+    """
+    def __init__(self, kappa, epsilon=1e-6):
+        """
+        kappa: operand squared (as `kappa*kappa`) to form the divisor; `Submission` passes `data.kappa`.
+        epsilon: offset added to `kappa*kappa` to avoid division by zero (default 1e-6, the previous fixed value).
+            Ideally this would be chosen relative to the scale of `kappa`.
+        """
+        self.kappasq = kappa*kappa + epsilon
+
+    def apply(self, algorithm, gradient, out=None):
+        """Return `gradient.divide(kappa*kappa + epsilon, out=out)`; `algorithm` is not used."""
+        return gradient.divide(self.kappasq, out=out)
+
+
 class Submission(ISTA):
+    """Stochastic subset version of preconditioned ISTA"""
+
     # note that `issubclass(ISTA, Algorithm) == True`
     def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6,
                  update_objective_interval: int = 10):
         """
         Initialisation function, setting up data & (hyper)parameters.
         NB: in practice, `num_subsets` should likely be determined from the data.
-        WARNING: we also currently ignore the non-negativity constraint here.
         This is just an example. Try to modify and improve it!
         """
         data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
@@ -45,11 +65,12 @@ def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6,
             f.set_prior(data.prior)
 
         sampler = Sampler.random_without_replacement(len(obj_funs))
-        F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
-        step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L
-        g = IndicatorBox(lower=1e-6, accelerated=False) # "non-negativity" constraint
+        f = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
+        step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/f.L
+        g = IndicatorBox(lower=0, accelerated=False) # non-negativity constraint
 
-        super().__init__(initial=data.OSEM_image, f=F, g=g, step_size=step_size_rule,
+        preconditioner = MyPreconditioner(data.kappa)
+        super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule, preconditioner=preconditioner,
                          update_objective_interval=update_objective_interval)
 
 
diff --git a/petric.py b/petric.py
index 61e12f4..be560f0 100755
--- a/petric.py
+++ b/petric.py
@@ -143,7 +143,7 @@ def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
     return prior
 
 
-Dataset = namedtuple('Dataset', ['acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior'])
+Dataset = namedtuple('Dataset', ['acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa'])
 
 
 def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
@@ -152,8 +152,7 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
     STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems
     STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets()
 
-    _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'),
-                               str(outdir / 'errors.txt'))
+    _ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt'))
     acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs'))
     additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs'))
     mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs'))
@@ -165,17 +164,16 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
     penalty_strength = 1 / 700 # default choice
     prior = construct_RDP(penalty_strength, OSEM_image, kappa)
 
-    return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior)
+    return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa)
 
 
 if SRCDIR.is_dir():
-    data_dirs_metrics = [
-        (SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA",
-         [MetricsWithTimeout(outdir=OUTDIR / "mMR_NEMA", transverse_slice=72, coronal_slice=109)]),
-        (SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman",
-         [MetricsWithTimeout(outdir=OUTDIR / "NeuroLF_Hoffman", transverse_slice=72)]),
-        (SRCDIR / "Siemens_Vision600_thorax", OUTDIR / "Vision600_thorax",
-         [MetricsWithTimeout(outdir=OUTDIR / "Vision600_thorax")])]
+    data_dirs_metrics = [(SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA",
+                          [MetricsWithTimeout(outdir=OUTDIR / "mMR_NEMA", transverse_slice=72, coronal_slice=109)]),
+                         (SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman",
+                          [MetricsWithTimeout(outdir=OUTDIR / "NeuroLF_Hoffman", transverse_slice=72)]),
+                         (SRCDIR / "Siemens_Vision600_thorax", OUTDIR / "Vision600_thorax",
+                          [MetricsWithTimeout(outdir=OUTDIR / "Vision600_thorax")])]
 else:
     log.warning("Source directory does not exist: %s", SRCDIR)
     data_dirs_metrics = [(None, None, [])]