Skip to content

Commit

Permalink
Merge pull request #35 from SyneRBI/ISTA
Browse files Browse the repository at this point in the history
ISTA
  • Loading branch information
casperdcl authored Jul 4, 2024
2 parents a78fb4a + a4b75e7 commit e3f5300
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion main.py
19 changes: 10 additions & 9 deletions main_SGD.py → main_ISTA.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
>>> algorithm = Submission(data)
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
"""
from cil.optimisation.algorithms import GD, Algorithm
from cil.optimisation.functions import SGFunction
from cil.optimisation.algorithms import ISTA, Algorithm
from cil.optimisation.functions import IndicatorBox, SGFunction
from cil.optimisation.utilities import ConstantStepSize, Sampler, callbacks
from petric import Dataset
from sirf.contrib.partitioner import partitioner

assert issubclass(GD, Algorithm)
assert issubclass(ISTA, Algorithm)


class MaxIteration(callbacks.Callback):
Expand All @@ -28,9 +28,9 @@ def __call__(self, algorithm: Algorithm):
raise StopIteration


class Submission(GD):
# note that `issubclass(GD, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-10,
class Submission(ISTA):
# note that `issubclass(ISTA, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6,
update_objective_interval: int = 10):
"""
Initialisation function, setting up data & (hyper)parameters.
Expand All @@ -45,10 +45,11 @@ def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-10
f.set_prior(data.prior)

sampler = Sampler.random_without_replacement(len(obj_funs))
F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L
F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L
g = IndicatorBox(lower=1e-6, accelerated=False) # "non-negativity" constraint

super().__init__(initial=data.OSEM_image, objective_function=F, step_size=step_size_rule,
super().__init__(initial=data.OSEM_image, f=F, g=g, step_size=step_size_rule,
update_objective_interval=update_objective_interval)


Expand Down

0 comments on commit e3f5300

Please sign in to comment.