From a4dc9d0d58dbf3086a3a4bfd40bf47bc2645cb8e Mon Sep 17 00:00:00 2001
From: Fabian Damken
Date: Tue, 15 Aug 2023 15:22:28 +0200
Subject: [PATCH] Refactor SPDR for readability

---
 Pyrado/pyrado/algorithms/meta/spdr.py | 279 +++++++++++++-------------
 1 file changed, 144 insertions(+), 135 deletions(-)

diff --git a/Pyrado/pyrado/algorithms/meta/spdr.py b/Pyrado/pyrado/algorithms/meta/spdr.py
index 7357c6c6cbd..492439f6e75 100644
--- a/Pyrado/pyrado/algorithms/meta/spdr.py
+++ b/Pyrado/pyrado/algorithms/meta/spdr.py
@@ -28,7 +28,7 @@
 import os.path
 from csv import DictWriter
-from typing import Iterator, Optional, Tuple
+from typing import Callable, Iterator, List, Optional, Tuple
 
 import numpy as np
 import torch as to
@@ -283,6 +283,16 @@ def sample_count(self) -> int:
         # Forward to subroutine
         return self._subrtn.sample_count
 
+    @property
+    def dim(self) -> int:
+        return self._spl_parameter.target_mean.shape[0]
+
+    @property
+    def subrtn_sampler(self) -> RolloutSavingWrapper:
+        # It is checked in the constructor that the sampler is a RolloutSavingWrapper.
+        # noinspection PyTypeChecker
+        return self._subrtn.sampler
+
     def step(self, snapshot_mode: str, meta_info: dict = None):
         """
         Perform a step of SPDR. This includes training the subroutine and updating the context distribution accordingly.
@@ -290,31 +300,12 @@ def step(self, snapshot_mode: str, meta_info: dict = None):
         """
         self.save_snapshot()
 
-        context_mean = self._spl_parameter.context_mean.double()
-        context_cov = self._spl_parameter.context_cov.double()
-        context_cov_chol = self._spl_parameter.context_cov_chol.double()
-        target_mean = self._spl_parameter.target_mean.double()
-        target_cov_chol = self._spl_parameter.target_cov_chol.double()
-
         # Add these keys to the logger as dummy values.
         self.logger.add_value("sprl number of particles", 0)
         self.logger.add_value("spdr constraint kl", 0.0)
         self.logger.add_value("spdr constraint performance", 0.0)
         self.logger.add_value("spdr objective", 0.0)
-        for param_a_idx, param_a_name in enumerate(self._spl_parameter.name):
-            for param_b_idx, param_b_name in enumerate(self._spl_parameter.name):
-                self.logger.add_value(
-                    f"context cov for {param_a_name}--{param_b_name}", context_cov[param_a_idx, param_b_idx].item()
-                )
-                self.logger.add_value(
-                    f"context cov_chol for {param_a_name}--{param_b_name}",
-                    context_cov_chol[param_a_idx, param_b_idx].item(),
-                )
-                if param_a_name == param_b_name:
-                    self.logger.add_value(f"context mean for {param_a_name}", context_mean[param_a_idx].item())
-                    break
-
-        dim = context_mean.shape[0]
+        self._log_context_distribution()
 
         # If we are in the first iteration and have a bad performance,
         # we want to completely reset the policy if training is unsuccessful
@@ -325,110 +316,34 @@ def step(self, snapshot_mode: str, meta_info: dict = None):
             self._train_subroutine_and_evaluate_perf
         )(snapshot_mode, meta_info, reset_policy)
 
-        # Update distribution
-        previous_distribution = MultivariateNormalWrapper(context_mean, context_cov_chol)
-        target_distribution = MultivariateNormalWrapper(target_mean, target_cov_chol)
-
-        def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
-            domain_param_dict = ro.rollout_info["domain_param"]
-            untransformed_param_name = param_name + DomainParamTransform.UNTRANSFORMED_DOMAIN_PARAMETER_SUFFIX
-            if untransformed_param_name in domain_param_dict:
-                return domain_param_dict[untransformed_param_name]
-            return domain_param_dict[param_name]
+        previous_distribution = MultivariateNormalWrapper(
+            self._spl_parameter.context_mean.double(), self._spl_parameter.context_cov_chol.double()
+        )
+        target_distribution = MultivariateNormalWrapper(
+            self._spl_parameter.target_mean.double(), self._spl_parameter.target_cov_chol.double()
+        )
 
-        rollouts_all = self._get_sampler().rollouts
-        contexts = to.tensor(
-            [
-                [to.from_numpy(get_domain_param_value(ro, name)) for rollouts in rollouts_all for ro in rollouts]
-                for name in self._spl_parameter.name
-            ],
-            requires_grad=True,
-        ).T
-
-        self.logger.add_value("sprl number of particles", contexts.shape[0])
-
-        contexts_old_log_prob = previous_distribution.distribution.log_prob(contexts.double())
-
-        values = to.tensor([ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts])
-
-        constraints = []
-
-        def kl_constraint_fn(x):
-            """Compute the constraint for the KL-divergence between current and proposed distribution."""
-            distribution = MultivariateNormalWrapper.from_stacked(dim, x)
-            kl_divergence = to.distributions.kl_divergence(
-                previous_distribution.distribution, distribution.distribution
-            )
-            return kl_divergence.detach().numpy()
-
-        def kl_constraint_fn_prime(x):
-            """Compute the derivative for the KL-constraint (used for scipy optimizer)."""
-            distribution = MultivariateNormalWrapper.from_stacked(dim, x)
-            kl_divergence = to.distributions.kl_divergence(
-                previous_distribution.distribution, distribution.distribution
-            )
-            grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
-            return np.concatenate([g.detach().numpy() for g in grads])
-
-        constraints.append(
-            NonlinearConstraint(
-                fun=kl_constraint_fn,
-                lb=-np.inf,
-                ub=self._kl_constraints_ub,
-                jac=kl_constraint_fn_prime,
-            )
-        )
-
-        def performance_constraint_fn(x):
-            """Compute the constraint for the expected performance under the proposed distribution."""
-            distribution = MultivariateNormalWrapper.from_stacked(dim, x)
-            performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
-            return performance.detach().numpy()
-
-        def performance_constraint_fn_prime(x):
-            """Compute the derivative for the performance-constraint (used for scipy optimizer)."""
-            distribution = MultivariateNormalWrapper.from_stacked(dim, x)
-            performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
-            grads = to.autograd.grad(performance, list(distribution.parameters()))
-            return np.concatenate([g.detach().numpy() for g in grads])
-
-        constraints.append(
-            NonlinearConstraint(
-                fun=performance_constraint_fn,
-                lb=self._performance_lower_bound,
-                ub=np.inf,
-                jac=performance_constraint_fn_prime,
-            )
-        )
-
-        # We now optimize based on the kl-divergence between target and context distribution by minimizing it
-        def objective_fn(x):
-            """Tries to find the minimum kl divergence between the current and the update distribution, which
-            still satisfies the minimum update constraint and the performance constraint."""
-            distribution = MultivariateNormalWrapper.from_stacked(dim, x)
-            kl_divergence = to.distributions.kl_divergence(distribution.distribution, target_distribution.distribution)
-            grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
-
-            return (
-                kl_divergence.detach().numpy(),
-                np.concatenate([g.detach().numpy() for g in grads]),
-            )
+        proposal_rollouts = self._sample_proposal_rollouts()
+        contexts, contexts_old_log_prob, values = self._extract_particles(proposal_rollouts, previous_distribution)
 
+        # Define the SPRL optimization problem
+        kl_constraint = self._make_kl_constraint(previous_distribution, self._kl_constraints_ub)
+        performance_constraint = self._make_performance_constraint(
+            contexts, contexts_old_log_prob, values, self._performance_lower_bound
+        )
+        constraints = [kl_constraint, performance_constraint]
+        objective_fn = self._make_objective_fn(target_distribution)
 
         x0 = previous_distribution.get_stacked()
+        minimize_args = dict(
+            fun=objective_fn,
+            x0=x0,
+            method="trust-constr",
+            jac=True,
+            constraints=constraints,
+            options={"gtol": 1e-4, "xtol": 1e-6},
+            # bounds=bounds,
+        )
 
         print("Performing SPDR update.")
         try:
-            # noinspection PyTypeChecker
-            result = minimize(
-                objective_fn,
-                x0,
-                method="trust-constr",
-                jac=True,
-                constraints=constraints,
-                options={"gtol": 1e-4, "xtol": 1e-6},
-                # bounds=bounds,
-            )
+            result = minimize(**minimize_args)
             new_x = result.x
+
+            # Reset the parameters if the optimization was not successful
             if not result.success:
                 # If optimization process was not a success
                 old_f = objective_fn(previous_distribution.get_stacked())[0]
@@ -445,18 +360,20 @@ def objective_fn(x):
                 print(f"Update failed with error, keeping old SPDR parameters.", e)
                 new_x = x0
 
-        self._adapt_parameters(dim, new_x)
-        self.logger.add_value("spdr constraint kl", kl_constraint_fn(new_x).item())
-        self.logger.add_value("spdr constraint performance", performance_constraint_fn(new_x).item())
-        self.logger.add_value("spdr objective", objective_fn(new_x)[0].item())
+        self._adapt_parameters(new_x)
+
+        # We cannot use the values stored during the optimization here as new_x might not be result.x
+        self.logger.add_value("spdr constraint kl", kl_constraint.fun(new_x))
+        self.logger.add_value("spdr constraint performance", performance_constraint.fun(new_x))
+        self.logger.add_value("spdr objective", objective_fn(new_x)[0])
 
     def reset(self, seed: int = None):
         # Forward to subroutine
         self._subrtn.reset(seed)
-        self._get_sampler().reset_rollouts()
+        self.subrtn_sampler.reset_rollouts()
 
     def save_snapshot(self, meta_info: dict = None):
-        self._get_sampler().reset_rollouts()
+        self.subrtn_sampler.reset_rollouts()
         super().save_snapshot(meta_info)
 
         if meta_info is None:
@@ -475,17 +392,114 @@ def load_snapshot(self, parsed_args) -> Tuple[Env, Policy, dict]:
 
         return env, policy, extra
 
-    def _compute_expected_performance(
-        self, distribution: MultivariateNormalWrapper, context: to.Tensor, old_log_prop: to.Tensor, values: to.Tensor
-    ) -> to.Tensor:
+    def _make_objective_fn(
+        self, target_distribution: MultivariateNormalWrapper
+    ) -> Callable[[np.ndarray], Tuple[float, np.ndarray]]:
+        def objective_fn(x):
+            """Compute the KL divergence between the proposed and the target distribution, together with its
+            gradient; this objective is minimized subject to the KL trust-region and performance constraints."""
+            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
+            kl_divergence = to.distributions.kl_divergence(distribution.distribution, target_distribution.distribution)
+            grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
+
+            return (
+                kl_divergence.detach().numpy().item(),
+                np.concatenate([g.detach().numpy() for g in grads]),
+            )
+
+        return objective_fn
+
+    def _make_kl_constraint(
+        self, previous_distribution: MultivariateNormalWrapper, kl_constraint_ub: float
+    ) -> NonlinearConstraint:
+        def kl_constraint_fn(x):
+            """Compute the constraint for the KL-divergence between current and proposed distribution."""
+            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
+            kl_divergence = to.distributions.kl_divergence(
+                previous_distribution.distribution, distribution.distribution
+            )
+            return kl_divergence.detach().numpy().item()
+
+        def kl_constraint_fn_prime(x):
+            """Compute the derivative for the KL-constraint (used for scipy optimizer)."""
+            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
+            kl_divergence = to.distributions.kl_divergence(
+                previous_distribution.distribution, distribution.distribution
+            )
+            grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
+            return np.concatenate([g.detach().numpy() for g in grads])
+
+        return NonlinearConstraint(
+            fun=kl_constraint_fn,
+            lb=-np.inf,
+            ub=kl_constraint_ub,
+            jac=kl_constraint_fn_prime,
+        )
+
+    def _make_performance_constraint(
+        self, contexts: to.Tensor, contexts_old_log_prob: to.Tensor, values: to.Tensor, performance_lower_bound: float
+    ) -> NonlinearConstraint:
+        def performance_constraint_fn(x):
+            """Compute the constraint for the expected performance under the proposed distribution."""
+            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
+            performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
+            return performance.detach().numpy().item()
+
+        def performance_constraint_fn_prime(x):
+            """Compute the derivative for the performance-constraint (used for scipy optimizer)."""
+            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
+            performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
+            grads = to.autograd.grad(performance, list(distribution.parameters()))
+            return np.concatenate([g.detach().numpy() for g in grads])
+
+        return NonlinearConstraint(
+            fun=performance_constraint_fn,
+            lb=performance_lower_bound,
+            ub=np.inf,
+            jac=performance_constraint_fn_prime,
+        )
+
+    def _log_context_distribution(self):
+        context_mean = self._spl_parameter.context_mean.double()
+        context_cov = self._spl_parameter.context_cov.double()
+        context_cov_chol = self._spl_parameter.context_cov_chol.double()
+        for param_a_idx, param_a_name in enumerate(self._spl_parameter.name):
+            for param_b_idx, param_b_name in enumerate(self._spl_parameter.name):
+                self.logger.add_value(
+                    f"context cov for {param_a_name}--{param_b_name}",
+                    context_cov[param_a_idx, param_b_idx].item(),
+                )
+                self.logger.add_value(
+                    f"context cov_chol for {param_a_name}--{param_b_name}",
+                    context_cov_chol[param_a_idx, param_b_idx].item(),
+                )
+                if param_a_name == param_b_name:
+                    self.logger.add_value(f"context mean for {param_a_name}", context_mean[param_a_idx].item())
+                    break
+
+    def _sample_proposal_rollouts(self) -> List[List[StepSequence]]:
+        return self.subrtn_sampler.rollouts
+
+    def _extract_particles(
+        self, rollouts_all: List[List[StepSequence]], distribution: MultivariateNormalWrapper
+    ) -> Tuple[to.Tensor, to.Tensor, to.Tensor]:
+        def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
+            domain_param_dict = ro.rollout_info["domain_param"]
+            untransformed_param_name = param_name + DomainParamTransform.UNTRANSFORMED_DOMAIN_PARAMETER_SUFFIX
+            if untransformed_param_name in domain_param_dict:
+                return domain_param_dict[untransformed_param_name]
+            return domain_param_dict[param_name]
+
+        contexts = to.tensor(
+            [
+                [to.from_numpy(get_domain_param_value(ro, name)) for rollouts in rollouts_all for ro in rollouts]
+                for name in self._spl_parameter.name
+            ],
+            requires_grad=True,
+        ).T
+        self.logger.add_value("sprl number of particles", contexts.shape[0])
+        contexts_log_prob = distribution.distribution.log_prob(contexts.double())
+        values = to.tensor([ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts])
+        return contexts, contexts_log_prob, values
+
+    # noinspection PyMethodMayBeStatic
+    def _compute_expected_performance(
+        self, distribution: MultivariateNormalWrapper, context: to.Tensor, old_log_prop: to.Tensor, values: to.Tensor
+    ) -> to.Tensor:
         """Calculate the expected performance after an update step."""
         context_ratio = to.exp(distribution.distribution.log_prob(context) - old_log_prop)
         return to.mean(context_ratio * values)
 
-    def _adapt_parameters(self, dim: int, result: np.ndarray) -> None:
+    def _adapt_parameters(self, result: np.ndarray) -> None:
         """Update the parameters of the distribution based on the result of the optimization step and the general
         algorithm settings."""
-        context_distr = MultivariateNormalWrapper.from_stacked(dim, result)
+        context_distr = MultivariateNormalWrapper.from_stacked(self.dim, result)
         self._spl_parameter.adapt("context_mean", context_distr.mean)
         self._spl_parameter.adapt("context_cov_chol", context_distr.cov_chol)
@@ -504,10 +518,5 @@ def _train_subroutine_and_evaluate_perf(
         self._subrtn.reset()
         self._subrtn.train(snapshot_mode, None, meta_info)
 
-        rollouts_all = self._get_sampler().rollouts
+        rollouts_all = self.subrtn_sampler.rollouts
         return np.median([[ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts]]).item()
-
-    def _get_sampler(self) -> RolloutSavingWrapper:
-        # It is checked in the constructor that the sampler is a RolloutSavingWrapper.
-        # noinspection PyTypeChecker
-        return self._subrtn.sampler
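
The helpers factored out above all follow one pattern: the context distribution's mean and Cholesky factor are
flattened into a single vector (via MultivariateNormalWrapper.get_stacked/from_stacked), torch computes the KL
divergences and their gradients, and SciPy's trust-constr consumes both through fun/jac pairs. The sketch below
reproduces that pattern outside of Pyrado. It is a hypothetical, self-contained toy, not Pyrado code: DIM, unstack,
the 2-D distributions, and the synthetic particles are illustrative assumptions, and the importance-weighted
performance estimate stands in for _compute_expected_performance on real rollout returns.

    import numpy as np
    import torch as to
    from scipy.optimize import NonlinearConstraint, minimize

    DIM = 2  # illustrative context dimension

    # "Previous" and "target" context distributions of the toy problem.
    prev = to.distributions.MultivariateNormal(to.zeros(DIM, dtype=to.double), scale_tril=to.eye(DIM, dtype=to.double))
    target = to.distributions.MultivariateNormal(
        to.tensor([2.0, -1.0], dtype=to.double), scale_tril=0.5 * to.eye(DIM, dtype=to.double)
    )

    def unstack(x: np.ndarray):
        """Rebuild a differentiable Gaussian from a stacked [mean, lower-triangular entries] vector."""
        mean = to.tensor(x[:DIM], requires_grad=True)
        tril_entries = to.tensor(x[DIM:], requires_grad=True)
        tril = to.zeros((DIM, DIM), dtype=to.double)
        rows, cols = to.tril_indices(DIM, DIM)
        tril[rows, cols] = tril_entries
        # validate_args=False: the sketch assumes the KL trust region keeps the diagonal positive.
        dist = to.distributions.MultivariateNormal(mean, scale_tril=tril, validate_args=False)
        return mean, tril_entries, dist

    def objective_fn(x):
        """KL(proposal || target) and its gradient, as the (fun, jac) pair that jac=True expects."""
        mean, tril_entries, dist = unstack(x)
        kl = to.distributions.kl_divergence(dist, target)
        grads = to.autograd.grad(kl, [mean, tril_entries])
        return kl.item(), np.concatenate([g.numpy() for g in grads])

    def kl_constraint_fn(x):
        """Trust region: KL(previous || proposal) must stay below an upper bound."""
        return to.distributions.kl_divergence(prev, unstack(x)[2]).item()

    def kl_constraint_fn_prime(x):
        mean, tril_entries, dist = unstack(x)
        kl = to.distributions.kl_divergence(prev, dist)
        grads = to.autograd.grad(kl, [mean, tril_entries])
        return np.concatenate([g.numpy() for g in grads])

    # Synthetic particles: contexts sampled from the previous distribution, with a
    # made-up return that is higher for contexts closer to the target mean.
    contexts = prev.sample((100,))
    values = -to.linalg.norm(contexts - target.mean, dim=1)
    contexts_old_log_prob = prev.log_prob(contexts)

    def expected_performance(dist):
        """Importance-weighted estimate of the return under the proposal distribution."""
        ratio = to.exp(dist.log_prob(contexts) - contexts_old_log_prob)
        return to.mean(ratio * values)

    def performance_constraint_fn(x):
        return expected_performance(unstack(x)[2]).item()

    def performance_constraint_fn_prime(x):
        mean, tril_entries, dist = unstack(x)
        performance = expected_performance(dist)
        grads = to.autograd.grad(performance, [mean, tril_entries])
        return np.concatenate([g.numpy() for g in grads])

    # Stacked previous distribution: zero mean, identity Cholesky factor.
    x0 = np.concatenate([np.zeros(DIM), np.eye(DIM)[np.tril_indices(DIM)]])
    constraints = [
        NonlinearConstraint(kl_constraint_fn, lb=-np.inf, ub=0.1, jac=kl_constraint_fn_prime),
        NonlinearConstraint(
            performance_constraint_fn, lb=performance_constraint_fn(x0), ub=np.inf, jac=performance_constraint_fn_prime
        ),
    ]
    result = minimize(
        fun=objective_fn,
        x0=x0,
        method="trust-constr",
        jac=True,
        constraints=constraints,
        options={"gtol": 1e-4, "xtol": 1e-6},
    )
    print(result.x)  # stacked mean and Cholesky entries of the updated context distribution

Because jac=True, the objective returns the value and its gradient from a single call, which is why
_make_objective_fn returns a (float, np.ndarray) pair, while the constraints pass their gradients separately
through NonlinearConstraint's jac argument.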