Skip to content

Commit

Permalink
lint & comments
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jul 5, 2024
1 parent fffc8bd commit 970bb18
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ __pycache__/
*.bak
*.ahv
*.hv
*.v
*.v
34 changes: 18 additions & 16 deletions main_ISTA.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
from cil.optimisation.algorithms import ISTA, Algorithm
from cil.optimisation.functions import IndicatorBox, SGFunction
from cil.optimisation.utilities import ConstantStepSize, Sampler, callbacks, Preconditioner
from cil.optimisation.utilities import ConstantStepSize, Preconditioner, Sampler, callbacks
from petric import Dataset
from sirf.contrib.partitioner import partitioner

Expand All @@ -27,20 +27,24 @@ def __call__(self, algorithm: Algorithm):
if algorithm.iteration >= self.max_iteration:
raise StopIteration


class MyPreconditioner(Preconditioner):
    """
    Preconditioner based on the row-sum of the Hessian of the log-likelihood.

    See: Tsai et al., Fast Quasi-Newton Algorithms for Penalized Reconstruction in
    Emission Tomography and Further Improvements via Preconditioning,
    IEEE TMI, https://doi.org/10.1109/tmi.2017.2786865
    """
    def __init__(self, kappa):
        # add an epsilon to avoid division by zero (probably should make epsilon dependent on kappa)
        self.kappasq = kappa*kappa + 1e-6

    def apply(self, algorithm, gradient, out=None):
        # Element-wise division of the gradient by kappa^2 + eps.
        # `out`, if given, receives the result in place (avoids an allocation per iteration).
        return gradient.divide(self.kappasq, out=out)



class Submission(ISTA):
"""Stochastic subset version of preconditioned ISTA"""

# note that `issubclass(ISTA, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6,
update_objective_interval: int = 10):
Expand All @@ -56,14 +60,12 @@ def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6,
f.set_prior(data.prior)

sampler = Sampler.random_without_replacement(len(obj_funs))
F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L
f = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L
g = IndicatorBox(lower=0, accelerated=False) # non-negativity constraint

my_preconditioner = MyPreconditioner(data.kappa)
super().__init__(initial=data.OSEM_image,
f=F, g=g, step_size=step_size_rule,
preconditioner=my_preconditioner,
preconditioner = MyPreconditioner(data.kappa)
super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule, preconditioner=preconditioner,
update_objective_interval=update_objective_interval)


Expand Down
16 changes: 7 additions & 9 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems
STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets()

_ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'),
str(outdir / 'errors.txt'))
_ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt'))
acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs'))
additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs'))
mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs'))
Expand All @@ -169,13 +168,12 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):


# Build the (source dir, output dir, callbacks) triples for each dataset; fall back to a
# single empty entry (with a warning) when the data directory is missing, so downstream
# iteration still works.
if SRCDIR.is_dir():
    data_dirs_metrics = [(SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA",
                          [MetricsWithTimeout(outdir=OUTDIR / "mMR_NEMA", transverse_slice=72, coronal_slice=109)]),
                         (SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman",
                          [MetricsWithTimeout(outdir=OUTDIR / "NeuroLF_Hoffman", transverse_slice=72)]),
                         (SRCDIR / "Siemens_Vision600_thorax", OUTDIR / "Vision600_thorax",
                          [MetricsWithTimeout(outdir=OUTDIR / "Vision600_thorax")])]
else:
    log.warning("Source directory does not exist: %s", SRCDIR)
    data_dirs_metrics = [(None, None, [])]
Expand Down

0 comments on commit 970bb18

Please sign in to comment.