Refactor SPDR for readability
fdamken committed Aug 15, 2023
1 parent 2a3451c commit a4dc9d0
Showing 1 changed file with 144 additions and 135 deletions.
Pyrado/pyrado/algorithms/meta/spdr.py: 279 changes (144 additions, 135 deletions)
@@ -28,7 +28,7 @@

import os.path
from csv import DictWriter
from typing import Iterator, Optional, Tuple
from typing import Iterator, Optional, Tuple, List, Callable

import numpy as np
import torch as to
@@ -283,38 +283,29 @@ def sample_count(self) -> int:
# Forward to subroutine
return self._subrtn.sample_count

@property
def dim(self) -> int:
return self._spl_parameter.target_mean.shape[0]

@property
def subrtn_sampler(self) -> RolloutSavingWrapper:
# It is checked in the constructor that the sampler is a RolloutSavingWrapper.
# noinspection PyTypeChecker
return self._subrtn.sampler

def step(self, snapshot_mode: str, meta_info: dict = None):
"""
Perform a step of SPDR. This includes training the subroutine and updating the context distribution accordingly.
For a description of the parameters see `pyrado.algorithms.base.Algorithm.step`.
"""
self.save_snapshot()

context_mean = self._spl_parameter.context_mean.double()
context_cov = self._spl_parameter.context_cov.double()
context_cov_chol = self._spl_parameter.context_cov_chol.double()
target_mean = self._spl_parameter.target_mean.double()
target_cov_chol = self._spl_parameter.target_cov_chol.double()

# Add these keys to the logger as dummy values.
self.logger.add_value("sprl number of particles", 0)
self.logger.add_value("spdr constraint kl", 0.0)
self.logger.add_value("spdr constraint performance", 0.0)
self.logger.add_value("spdr objective", 0.0)
for param_a_idx, param_a_name in enumerate(self._spl_parameter.name):
for param_b_idx, param_b_name in enumerate(self._spl_parameter.name):
self.logger.add_value(
f"context cov for {param_a_name}--{param_b_name}", context_cov[param_a_idx, param_b_idx].item()
)
self.logger.add_value(
f"context cov_chol for {param_a_name}--{param_b_name}",
context_cov_chol[param_a_idx, param_b_idx].item(),
)
if param_a_name == param_b_name:
self.logger.add_value(f"context mean for {param_a_name}", context_mean[param_a_idx].item())
break

dim = context_mean.shape[0]
self._log_context_distribution()

# If we are in the first iteration and have a bad performance,
# we want to completely reset the policy if training is unsuccessful
@@ -325,110 +316,34 @@ def step(self, snapshot_mode: str, meta_info: dict = None):
self._train_subroutine_and_evaluate_perf
)(snapshot_mode, meta_info, reset_policy)

# Update distribution
previous_distribution = MultivariateNormalWrapper(context_mean, context_cov_chol)
target_distribution = MultivariateNormalWrapper(target_mean, target_cov_chol)

def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
domain_param_dict = ro.rollout_info["domain_param"]
untransformed_param_name = param_name + DomainParamTransform.UNTRANSFORMED_DOMAIN_PARAMETER_SUFFIX
if untransformed_param_name in domain_param_dict:
return domain_param_dict[untransformed_param_name]
return domain_param_dict[param_name]
previous_distribution = MultivariateNormalWrapper(self._spl_parameter.context_mean.double(), self._spl_parameter.context_cov_chol.double())
target_distribution = MultivariateNormalWrapper(self._spl_parameter.target_mean.double(), self._spl_parameter.target_cov_chol.double())

rollouts_all = self._get_sampler().rollouts
contexts = to.tensor(
[
[to.from_numpy(get_domain_param_value(ro, name)) for rollouts in rollouts_all for ro in rollouts]
for name in self._spl_parameter.name
],
requires_grad=True,
).T

self.logger.add_value("sprl number of particles", contexts.shape[0])

contexts_old_log_prob = previous_distribution.distribution.log_prob(contexts.double())

values = to.tensor([ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts])

constraints = []

def kl_constraint_fn(x):
"""Compute the constraint for the KL-divergence between current and proposed distribution."""
distribution = MultivariateNormalWrapper.from_stacked(dim, x)
kl_divergence = to.distributions.kl_divergence(
previous_distribution.distribution, distribution.distribution
)
return kl_divergence.detach().numpy()

def kl_constraint_fn_prime(x):
"""Compute the derivative for the KL-constraint (used for scipy optimizer)."""
distribution = MultivariateNormalWrapper.from_stacked(dim, x)
kl_divergence = to.distributions.kl_divergence(
previous_distribution.distribution, distribution.distribution
)
grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
return np.concatenate([g.detach().numpy() for g in grads])

constraints.append(
NonlinearConstraint(
fun=kl_constraint_fn,
lb=-np.inf,
ub=self._kl_constraints_ub,
jac=kl_constraint_fn_prime,
)
)

def performance_constraint_fn(x):
"""Compute the constraint for the expected performance under the proposed distribution."""
distribution = MultivariateNormalWrapper.from_stacked(dim, x)
performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
return performance.detach().numpy()

def performance_constraint_fn_prime(x):
"""Compute the derivative for the performance-constraint (used for scipy optimizer)."""
distribution = MultivariateNormalWrapper.from_stacked(dim, x)
performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
grads = to.autograd.grad(performance, list(distribution.parameters()))
return np.concatenate([g.detach().numpy() for g in grads])

constraints.append(
NonlinearConstraint(
fun=performance_constraint_fn,
lb=self._performance_lower_bound,
ub=np.inf,
jac=performance_constraint_fn_prime,
)
)

# We now optimize based on the kl-divergence between target and context distribution by minimizing it
def objective_fn(x):
"""Tries to find the minimum kl divergence between the current and the update distribution, which
still satisfies the minimum update constraint and the performance constraint."""
distribution = MultivariateNormalWrapper.from_stacked(dim, x)
kl_divergence = to.distributions.kl_divergence(distribution.distribution, target_distribution.distribution)
grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))

return (
kl_divergence.detach().numpy(),
np.concatenate([g.detach().numpy() for g in grads]),
)
proposal_rollouts = self._sample_proposal_rollouts()
contexts, contexts_old_log_prob, values = self._extract_particles(proposal_rollouts, previous_distribution)

# Define the SPRL optimization problem
kl_constraint = self._make_kl_constraint(previous_distribution, self._kl_constraints_ub)
performance_constraint = self._make_performance_constraint(contexts, contexts_old_log_prob, values, self._performance_lower_bound)
constraints = [kl_constraint, performance_constraint]
objective_fn = self._make_objective_fn(target_distribution)
x0 = previous_distribution.get_stacked()
minimize_args = dict(
fun=objective_fn,
x0=x0,
method="trust-constr",
jac=True,
constraints=constraints,
options={"gtol": 1e-4, "xtol": 1e-6},
# bounds=bounds,
)

print("Performing SPDR update.")
try:
# noinspection PyTypeChecker
result = minimize(
objective_fn,
x0,
method="trust-constr",
jac=True,
constraints=constraints,
options={"gtol": 1e-4, "xtol": 1e-6},
# bounds=bounds,
)
result = minimize(**minimize_args)
new_x = result.x

# Reset parameters if optimization was not successful
if not result.success:
# If optimization process was not a success
old_f = objective_fn(previous_distribution.get_stacked())[0]
@@ -445,18 +360,20 @@ def objective_fn(x):
print(f"Update failed with error, keeping old SPDR parameters.", e)
new_x = x0

self._adapt_parameters(dim, new_x)
self.logger.add_value("spdr constraint kl", kl_constraint_fn(new_x).item())
self.logger.add_value("spdr constraint performance", performance_constraint_fn(new_x).item())
self.logger.add_value("spdr objective", objective_fn(new_x)[0].item())
self._adapt_parameters(new_x)

# we can't use the stored values here as new_x might not be result.x
self.logger.add_value("spdr constraint kl", kl_constraint.fun(new_x))
self.logger.add_value("spdr constraint performance", performance_constraint.fun(new_x))
self.logger.add_value("spdr objective", objective_fn(new_x)[0])

def reset(self, seed: int = None):
# Forward to subroutine
self._subrtn.reset(seed)
self._get_sampler().reset_rollouts()
self.subrtn_sampler.reset_rollouts()

def save_snapshot(self, meta_info: dict = None):
self._get_sampler().reset_rollouts()
self.subrtn_sampler.reset_rollouts()
super().save_snapshot(meta_info)

if meta_info is None:
@@ -475,17 +392,114 @@ def load_snapshot(self, parsed_args) -> Tuple[Env, Policy, dict]:

return env, policy, extra

def _compute_expected_performance(
self, distribution: MultivariateNormalWrapper, context: to.Tensor, old_log_prop: to.Tensor, values: to.Tensor
) -> to.Tensor:
def _make_objective_fn(self, target_distribution: MultivariateNormalWrapper) -> Callable[[np.ndarray], Tuple[float, np.ndarray]]:
def objective_fn(x):
"""Tries to find the minimum kl divergence between the current and the update distribution, which
still satisfies the minimum update constraint and the performance constraint."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
kl_divergence = to.distributions.kl_divergence(distribution.distribution, target_distribution.distribution)
grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))

return (
kl_divergence.detach().numpy().item(),
np.concatenate([g.detach().numpy() for g in grads]),
)

return objective_fn

def _make_kl_constraint(self, previous_distribution: MultivariateNormalWrapper, kl_constraint_ub: float) -> NonlinearConstraint:
def kl_constraint_fn(x):
"""Compute the constraint for the KL-divergence between current and proposed distribution."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
kl_divergence = to.distributions.kl_divergence(previous_distribution.distribution, distribution.distribution)
return kl_divergence.detach().numpy().item()

def kl_constraint_fn_prime(x):
"""Compute the derivative for the KL-constraint (used for scipy optimizer)."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
kl_divergence = to.distributions.kl_divergence(previous_distribution.distribution, distribution.distribution)
grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
return np.concatenate([g.detach().numpy() for g in grads])

return NonlinearConstraint(
fun=kl_constraint_fn,
lb=-np.inf,
ub=kl_constraint_ub,
jac=kl_constraint_fn_prime,
)

def _make_performance_constraint(self, contexts: to.Tensor, contexts_old_log_prob: to.Tensor, values: to.Tensor, performance_lower_bound: float) -> NonlinearConstraint:
def performance_constraint_fn(x):
"""Compute the constraint for the expected performance under the proposed distribution."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
return performance.detach().numpy().item()

def performance_constraint_fn_prime(x):
"""Compute the derivative for the performance-constraint (used for scipy optimizer)."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
grads = to.autograd.grad(performance, list(distribution.parameters()))
return np.concatenate([g.detach().numpy() for g in grads])

return NonlinearConstraint(
fun=performance_constraint_fn,
lb=performance_lower_bound,
ub=np.inf,
jac=performance_constraint_fn_prime,
)

def _log_context_distribution(self):
context_mean = self._spl_parameter.context_mean.double()
context_cov = self._spl_parameter.context_cov.double()
context_cov_chol = self._spl_parameter.context_cov_chol.double()
for param_a_idx, param_a_name in enumerate(self._spl_parameter.name):
for param_b_idx, param_b_name in enumerate(self._spl_parameter.name):
self.logger.add_value(
f"context cov for {param_a_name}--{param_b_name}",
context_cov[param_a_idx, param_b_idx].item(),
)
self.logger.add_value(
f"context cov_chol for {param_a_name}--{param_b_name}",
context_cov_chol[param_a_idx, param_b_idx].item(),
)
if param_a_name == param_b_name:
self.logger.add_value(f"context mean for {param_a_name}", context_mean[param_a_idx].item())
break

def _sample_proposal_rollouts(self) -> List[List[StepSequence]]:
return self.subrtn_sampler.rollouts

def _extract_particles(self, rollouts_all: List[List[StepSequence]], distribution: MultivariateNormalWrapper) -> Tuple[to.Tensor, to.Tensor, to.Tensor]:
def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
domain_param_dict = ro.rollout_info["domain_param"]
untransformed_param_name = param_name + DomainParamTransform.UNTRANSFORMED_DOMAIN_PARAMETER_SUFFIX
if untransformed_param_name in domain_param_dict:
return domain_param_dict[untransformed_param_name]
return domain_param_dict[param_name]

contexts = to.tensor(
[
[to.from_numpy(get_domain_param_value(ro, name)) for rollouts in rollouts_all for ro in rollouts]
for name in self._spl_parameter.name
],
requires_grad=True,
).T
self.logger.add_value("sprl number of particles", contexts.shape[0])
contexts_log_prob = distribution.distribution.log_prob(contexts.double())
values = to.tensor([ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts])
return contexts, contexts_log_prob, values

# noinspection PyMethodMayBeStatic
def _compute_expected_performance(self, distribution: MultivariateNormalWrapper, context: to.Tensor, old_log_prop: to.Tensor, values: to.Tensor) -> to.Tensor:
"""Calculate the expected performance after an update step."""
context_ratio = to.exp(distribution.distribution.log_prob(context) - old_log_prop)
return to.mean(context_ratio * values)

def _adapt_parameters(self, dim: int, result: np.ndarray) -> None:
def _adapt_parameters(self, result: np.ndarray) -> None:
"""Update the parameters of the distribution based on the result of
the optimization step and the general algorithm settings."""
context_distr = MultivariateNormalWrapper.from_stacked(dim, result)
context_distr = MultivariateNormalWrapper.from_stacked(self.dim, result)
self._spl_parameter.adapt("context_mean", context_distr.mean)
self._spl_parameter.adapt("context_cov_chol", context_distr.cov_chol)

@@ -504,10 +518,5 @@ def _train_subroutine_and_evaluate_perf(
self._subrtn.reset()

self._subrtn.train(snapshot_mode, None, meta_info)
rollouts_all = self._get_sampler().rollouts
rollouts_all = self.subrtn_sampler.rollouts
return np.median([[ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts]]).item()

def _get_sampler(self) -> RolloutSavingWrapper:
# It is checked in the constructor that the sampler is a RolloutSavingWrapper.
# noinspection PyTypeChecker
return self._subrtn.sampler
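
The refactored step() still assembles the same constrained problem as before, now built from the new _make_objective_fn, _make_kl_constraint, and _make_performance_constraint helpers: minimize the KL divergence from the proposed context distribution to the target distribution, subject to a KL trust region around the previous distribution and a lower bound on the importance-weighted expected return, roughly J(p) = (1/N) * sum_i [p(c_i) / p_old(c_i)] * R_i. Below is a minimal, self-contained sketch of that update, not taken from spdr.py: it assumes a one-dimensional Gaussian context distribution, synthetic returns in place of rollout values, and SciPy's finite-difference gradients instead of the autograd Jacobians provided by MultivariateNormalWrapper.

import numpy as np
from scipy.optimize import NonlinearConstraint, minimize
from scipy.stats import norm

rng = np.random.default_rng(0)

# Distributions are stacked as [mean, log_std], loosely mirroring the "stacked" layout.
prev_x = np.array([0.0, np.log(0.1)])    # current context distribution
target_x = np.array([1.0, np.log(0.5)])  # target context distribution

def unpack(x):
    return x[0], np.exp(x[1])

def kl_gauss(x_p, x_q):
    """KL(N_p || N_q) for univariate Gaussians."""
    m_p, s_p = unpack(x_p)
    m_q, s_q = unpack(x_q)
    return np.log(s_q / s_p) + (s_p ** 2 + (m_p - m_q) ** 2) / (2 * s_q ** 2) - 0.5

# Proposal particles: contexts drawn from the previous distribution, each paired
# with a synthetic undiscounted return standing in for the rollout values.
m0, s0 = unpack(prev_x)
contexts = rng.normal(m0, s0, size=200)
values = 1.0 - np.abs(contexts - 0.3)
old_log_prob = norm.logpdf(contexts, m0, s0)

def expected_performance(x):
    """Importance-weighted return estimate under the proposed distribution."""
    m, s = unpack(x)
    ratio = np.exp(norm.logpdf(contexts, m, s) - old_log_prob)
    return np.mean(ratio * values)

constraints = [
    # KL(previous || proposed) must stay within the trust-region bound.
    NonlinearConstraint(lambda x: kl_gauss(prev_x, x), -np.inf, 0.05),
    # The importance-weighted return must stay above the performance bound.
    NonlinearConstraint(expected_performance, 0.5, np.inf),
]

result = minimize(
    lambda x: kl_gauss(x, target_x),  # objective: KL from proposed to target
    prev_x,
    method="trust-constr",
    constraints=constraints,
    options={"gtol": 1e-4, "xtol": 1e-6},
)
new_mean, new_std = unpack(result.x)
print(f"success={result.success}, new mean={new_mean:.3f}, new std={new_std:.3f}")

Collecting the solver arguments in minimize_args, as the commit does, keeps the solver configuration in one place even though only a single minimize call uses it here.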
