Skip to content

Commit

Permalink
final cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jul 4, 2024
1 parent fcd5f71 commit a4b75e7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 66 deletions.
2 changes: 1 addition & 1 deletion main.py
18 changes: 8 additions & 10 deletions main_ista.py → main_ISTA.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
>>> algorithm = Submission(data)
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
"""
from cil.optimisation.algorithms import GD, ISTA, Algorithm
from cil.optimisation.functions import SGFunction, IndicatorBox
from cil.optimisation.algorithms import ISTA, Algorithm
from cil.optimisation.functions import IndicatorBox, SGFunction
from cil.optimisation.utilities import ConstantStepSize, Sampler, callbacks
from petric import Dataset
from sirf.contrib.partitioner import partitioner

assert issubclass(GD, Algorithm)
assert issubclass(ISTA, Algorithm)


class MaxIteration(callbacks.Callback):
Expand All @@ -29,7 +29,7 @@ def __call__(self, algorithm: Algorithm):


class Submission(ISTA):
# note that `issubclass(GD, Algorithm) == True`
# note that `issubclass(ISTA, Algorithm) == True`
def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6,
update_objective_interval: int = 10):
"""
Expand All @@ -45,13 +45,11 @@ def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6,
f.set_prior(data.prior)

sampler = Sampler.random_without_replacement(len(obj_funs))
F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L
F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser
step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L
g = IndicatorBox(lower=1e-6, accelerated=False) # "non-negativity" constraint

g = IndicatorBox(lower=1e-5, accelerated=False)

super().__init__(initial=data.OSEM_image.get_uniform_copy(1e-5),
f=F, g=g, step_size=step_size_rule,
super().__init__(initial=data.OSEM_image, f=F, g=g, step_size=step_size_rule,
update_objective_interval=update_objective_interval)


Expand Down
55 changes: 0 additions & 55 deletions main_SGD.py

This file was deleted.

0 comments on commit a4b75e7

Please sign in to comment.