Skip to content

Commit

Permalink
added memory storage for acquistion ratio ot avoid fault
Browse files Browse the repository at this point in the history
  • Loading branch information
samdporter committed Sep 30, 2024
1 parent e6732ce commit 5f9664c
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
"""
#%%
from sirf.STIR import AcquisitionData
AcquisitionData.set_storage_scheme('memory')
from cil.optimisation.algorithms import ISTA, Algorithm
from cil.optimisation.functions import IndicatorBox, SVRGFunction
from cil.optimisation.utilities import (Preconditioner, Sampler,
Expand Down Expand Up @@ -67,21 +69,22 @@ class BSREMPreconditioner(Preconditioner):
Preconditioner for BSREM
'''

def __init__(self, acq_models, freeze_iter = np.inf, epsilon=1e-6):
def __init__(self, obj_funs, freeze_iter = np.inf, epsilon=1e-6):

self.epsilon = epsilon
self.freeze_iter = freeze_iter
self.freeze = None

for i,el in enumerate(acq_models):
if i == 0:
self.s_sum_inv = el.domain_geometry().get_uniform_copy(0.)
ones = el.range_geometry().allocate(1.)
s_inv = el.adjoint(ones)
for i,el in enumerate(obj_funs):
s_inv = el.get_subset_sensitivity(0)
s_inv.maximum(0, out=s_inv)
arr = s_inv.as_array()
np.reciprocal(arr, out=arr, where=arr!=0)
s_inv.fill(arr)
self.s_sum_inv += s_inv
if i == 0:
self.s_sum_inv = s_inv
else:
self.s_sum_inv += s_inv

def apply(self, algorithm, gradient, out=None):
if algorithm.iteration < self.freeze_iter:
Expand Down Expand Up @@ -145,8 +148,7 @@ def get_step_size(self, algorithm):
# Armijo step size search
for _ in range(self.max_iter):
# Proximal step
x_new = algorithm.solution.copy().sapyb(1, precond_grad, -step_size)
algorithm.g.proximal(x_new, step_size, out=x_new)
x_new = algorithm.g.proximal(algorithm.solution.copy() - step_size * precond_grad, step_size)
f_x_new = algorithm.f(x_new) + algorithm.g(x_new)
# Armijo condition check
if f_x_new <= self.f_x - self.tol * step_size * g_norm:
Expand Down Expand Up @@ -239,10 +241,9 @@ def __init__(self, data: Dataset, update_objective_interval=10):
decay = (1/(1-decay_perc) - 1)/update_interval
beta = 0.5

data_subs, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
_, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, self.num_subsets, mode='staggered',
initial_image=data.OSEM_image)


data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
data.prior.set_up(data.OSEM_image)
Expand All @@ -253,12 +254,12 @@ def __init__(self, data: Dataset, update_objective_interval=10):
sampler = Sampler.random_without_replacement(len(obj_funs))
f = -FullGradientInitialiserFunction(obj_funs, sampler=sampler, init_steps=5, store_gradients=True, snapshot_update_interval=update_interval)

preconditioner = BSREMPreconditioner(acq_models, epsilon=data.OSEM_image.max()/1e6, freeze_iter=10*update_interval+5)
g = IndicatorBox(lower=0, accelerated=True) # non-negativity constraint
preconditioner = BSREMPreconditioner(obj_funs, epsilon=data.OSEM_image.max()/1e6, freeze_iter=10*update_interval+5)
g = IndicatorBox(lower=0, accelerated=False) # non-negativity constraint

step_size_rule = ArmijoStepSearchRule(0.08, beta, decay, max_iter=100, tol=0.2, init_steps=5, update_interval=10*update_interval+5)

super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule,
preconditioner=preconditioner, update_objective_interval=update_objective_interval)

submission_callbacks = [] # 10000 iterations max
submission_callbacks = []

0 comments on commit 5f9664c

Please sign in to comment.