Skip to content

Commit

Permalink
Add logging of number of particles
Browse files: browse the repository at this point in the history
  • Loading branch information
fdamken committed Aug 15, 2023
1 parent 3c5d16c commit dfd13f1
Showing 1 changed file with 30 additions and 22 deletions.
52 changes: 30 additions & 22 deletions Pyrado/pyrado/algorithms/meta/spdr.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -297,6 +297,7 @@ def step(self, snapshot_mode: str, meta_info: dict = None):
target_cov_chol = self._spl_parameter.target_cov_chol.double()

# Add these keys to the logger as dummy values.
self.logger.add_value("sprl number of particles", 0)
self.logger.add_value("spdr constraint kl", 0.0)
self.logger.add_value("spdr constraint performance", 0.0)
self.logger.add_value("spdr objective", 0.0)
Expand Down Expand Up @@ -344,6 +345,8 @@ def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
requires_grad=True,
).T

self.logger.add_value("sprl number of particles", contexts.shape[0])

contexts_old_log_prob = previous_distribution.distribution.log_prob(contexts.double())
# kl_divergence = to.distributions.kl_divergence(previous_distribution.distribution, target_distribution.distribution)

Expand Down Expand Up @@ -438,28 +441,33 @@ def objective_fn(x):
x0 = previous_distribution.get_stacked()

print("Performing SPDR update.")
# noinspection PyTypeChecker
result = minimize(
objective_fn,
x0,
method="trust-constr",
jac=True,
constraints=constraints,
options={"gtol": 1e-4, "xtol": 1e-6},
# bounds=bounds,
)
new_x = result.x
if not result.success:
# If optimization process was not a success
old_f = objective_fn(previous_distribution.get_stacked())[0]
constraints_satisfied = all((const.lb <= const.fun(result.x) <= const.ub for const in constraints))

# std_ok = bounds is None or (np.all(bounds.lb <= result.x)) and np.all(result.x <= bounds.ub)
std_ok = True

if not (constraints_satisfied and std_ok and result.fun < old_f):
print(f"Update unsuccessful, keeping old values spl parameters.")
new_x = x0
try:
# noinspection PyTypeChecker
result = minimize(
objective_fn,
x0,
method="trust-constr",
jac=True,
constraints=constraints,
options={"gtol": 1e-4, "xtol": 1e-6},
# bounds=bounds,
)
new_x = result.x
if not result.success:
# If optimization process was not a success
old_f = objective_fn(previous_distribution.get_stacked())[0]
constraints_satisfied = all((const.lb <= const.fun(result.x) <= const.ub for const in constraints))

# std_ok = bounds is None or (np.all(bounds.lb <= result.x)) and np.all(result.x <= bounds.ub)
std_ok = True

update_successful = constraints_satisfied and std_ok and result.fun < old_f
if not update_successful:
print(f"Update unsuccessful, keeping old SPDR parameters.")
new_x = x0
except ValueError as e:
print(f"Update failed with error, keeping old SPDR parameters.", e)
new_x = x0

self._adapt_parameters(dim, new_x)
self.logger.add_value("spdr constraint kl", kl_constraint_fn(new_x).item())
Expand Down

0 comments on commit dfd13f1

Please sign in to comment.