Skip to content

Commit

Permalink
removed printing (again)
Browse files Browse the repository at this point in the history
  • Loading branch information
samdporter committed Sep 30, 2024
1 parent 4d971b0 commit 3955bc2
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def get_step_size(self, algorithm):
if self.counter < self.steps: # or algorithm.iteration == self.update_interval:
if self.f_x is None:
self.f_x = algorithm.f(algorithm.solution) + algorithm.g(algorithm.solution)
print(f"Current function value: {self.f_x}")
precond_grad = algorithm.preconditioner.apply(algorithm, algorithm.gradient_update)
g_norm = algorithm.gradient_update.dot(precond_grad)

Expand All @@ -137,14 +136,11 @@ def get_step_size(self, algorithm):
x_new = algorithm.solution.copy().sapyb(1, precond_grad, -step_size)
algorithm.g.proximal(x_new, step_size, out=x_new)
f_x_new = algorithm.f(x_new) + algorithm.g(x_new)
print("f_x_new: ", f_x_new)
# Armijo condition check
if f_x_new <= self.f_x - self.tol * step_size * g_norm:
self.f_x = f_x_new
break

print(f"step_size: {step_size}")

# Reduce step size
step_size *= self.beta

Expand All @@ -154,13 +150,11 @@ def get_step_size(self, algorithm):
if self.counter < self.steps:
self.counter += 1

print(f"Armijo step size: {step_size}")
return step_size

# Apply linear decay if Armijo steps are done
if self.linear_decay is None: # or algorithm.iteration == self.update_interval:
self.f_x = None
print(f"Linear decay with decay rate: {self.decay} and step size: {self.step_size}")
self.linear_decay = LinearDecayStepSizeRule(self.step_size, self.decay)

# Return decayed step size
Expand Down Expand Up @@ -231,7 +225,6 @@ def __init__(self, data: Dataset, update_objective_interval=10):
decay_perc = 0.1
decay = (1/(1-decay_perc) - 1)/update_interval
beta = 0.5
print(self.num_subsets)

data_subs, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, self.num_subsets, mode='staggered',
Expand All @@ -241,7 +234,7 @@ def __init__(self, data: Dataset, update_objective_interval=10):
data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
data.prior.set_up(data.OSEM_image)

for f, d in zip(obj_funs, data_subs): # add prior to every objective function
for f in obj_funs: # add prior evenly to every objective function
f.set_prior(data.prior)

sampler = Sampler.random_without_replacement(len(obj_funs))
Expand Down

0 comments on commit 3955bc2

Please sign in to comment.