Commit 54e7eb0

formatting

fdamken committed Oct 8, 2023
1 parent 33cfb08 commit 54e7eb0

Showing 3 changed files with 76 additions and 38 deletions.
42 changes: 31 additions & 11 deletions Pyrado/pyrado/algorithms/meta/spdr.py
@@ -28,7 +28,7 @@

import os.path
from csv import DictWriter
-from typing import Iterator, Optional, Tuple, List, Callable
+from typing import Callable, Iterator, List, Optional, Tuple

import numpy as np
import torch as to
@@ -316,15 +316,21 @@ def step(self, snapshot_mode: str, meta_info: dict = None):
self._train_subroutine_and_evaluate_perf
)(snapshot_mode, meta_info, reset_policy)

-previous_distribution = MultivariateNormalWrapper(self._spl_parameter.context_mean.double(), self._spl_parameter.context_cov_chol.double())
-target_distribution = MultivariateNormalWrapper(self._spl_parameter.target_mean.double(), self._spl_parameter.target_cov_chol.double())
+previous_distribution = MultivariateNormalWrapper(
+    self._spl_parameter.context_mean.double(), self._spl_parameter.context_cov_chol.double()
+)
+target_distribution = MultivariateNormalWrapper(
+    self._spl_parameter.target_mean.double(), self._spl_parameter.target_cov_chol.double()
+)

proposal_rollouts = self._sample_proposal_rollouts()
contexts, contexts_old_log_prob, values = self._extract_particles(proposal_rollouts, previous_distribution)

# Define the SPRL optimization problem
kl_constraint = self._make_kl_constraint(previous_distribution, self._kl_constraints_ub)
-performance_constraint = self._make_performance_constraint(contexts, contexts_old_log_prob, values, self._performance_lower_bound)
+performance_constraint = self._make_performance_constraint(
+    contexts, contexts_old_log_prob, values, self._performance_lower_bound
+)
constraints = [kl_constraint, performance_constraint]
objective_fn = self._make_objective_fn(target_distribution)
x0 = previous_distribution.get_stacked()
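The pieces assembled here (objective, KL and performance constraints, and the stacked start vector `x0`) feed a standard constrained optimizer. A minimal sketch of that solve, assuming SciPy's `trust-constr` method and the `(value, gradient)` return convention annotated on `_make_objective_fn` below:

```python
# Sketch only: how step()'s objective, constraints, and x0 would typically be
# handed to SciPy. Method and option choices are assumptions, not code from this commit.
from scipy.optimize import minimize


def solve_sprl_step(objective_fn, constraints, x0):
    # jac=True tells SciPy that objective_fn returns a (value, gradient) pair,
    # and trust-constr accepts a list of NonlinearConstraint objects directly.
    result = minimize(objective_fn, x0, method="trust-constr", jac=True, constraints=constraints)
    # result.x is the stacked parameter vector (mean and Cholesky factor of the
    # covariance) of the updated context distribution.
    return result.x
```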
@@ -392,7 +398,9 @@ def load_snapshot(self, parsed_args) -> Tuple[Env, Policy, dict]:

return env, policy, extra

-def _make_objective_fn(self, target_distribution: MultivariateNormalWrapper) -> Callable[[np.ndarray], Tuple[float, np.ndarray]]:
+def _make_objective_fn(
+    self, target_distribution: MultivariateNormalWrapper
+) -> Callable[[np.ndarray], Tuple[float, np.ndarray]]:
def objective_fn(x):
"""Tries to find the minimum kl divergence between the current and the update distribution, which
still satisfies the minimum update constraint and the performance constraint."""
@@ -407,17 +415,23 @@ def objective_fn(x):

return objective_fn

-def _make_kl_constraint(self, previous_distribution: MultivariateNormalWrapper, kl_constraint_ub: float) -> NonlinearConstraint:
+def _make_kl_constraint(
+    self, previous_distribution: MultivariateNormalWrapper, kl_constraint_ub: float
+) -> NonlinearConstraint:
def kl_constraint_fn(x):
"""Compute the constraint for the KL-divergence between current and proposed distribution."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
-kl_divergence = to.distributions.kl_divergence(previous_distribution.distribution, distribution.distribution)
+kl_divergence = to.distributions.kl_divergence(
+    previous_distribution.distribution, distribution.distribution
+)
return kl_divergence.detach().numpy().item()

def kl_constraint_fn_prime(x):
"""Compute the derivative for the KL-constraint (used for scipy optimizer)."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
-kl_divergence = to.distributions.kl_divergence(previous_distribution.distribution, distribution.distribution)
+kl_divergence = to.distributions.kl_divergence(
+    previous_distribution.distribution, distribution.distribution
+)
grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
return np.concatenate([g.detach().numpy() for g in grads])
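The two closures above share one pattern: rebuild the proposal from the stacked vector, evaluate the KL divergence with `torch.distributions`, then either return its value or differentiate it with autograd. A self-contained sketch of that pattern, with plain univariate `Normal` distributions standing in for `MultivariateNormalWrapper`:

```python
# Illustration of the value/gradient pair behind the KL constraint;
# univariate Normals stand in for MultivariateNormalWrapper.
import numpy as np
import torch as to

previous = to.distributions.Normal(to.tensor(0.0), to.tensor(1.0))

mean = to.tensor(0.5, requires_grad=True)
std = to.tensor(2.0, requires_grad=True)
proposed = to.distributions.Normal(mean, std)

kl = to.distributions.kl_divergence(previous, proposed)
value = kl.detach().numpy().item()         # what kl_constraint_fn returns
grads = to.autograd.grad(kl, [mean, std])  # what kl_constraint_fn_prime stacks
jac = np.concatenate([g.reshape(-1).detach().numpy() for g in grads])
print(value, jac)
```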

@@ -428,7 +442,9 @@ def kl_constraint_fn_prime(x):
jac=kl_constraint_fn_prime,
)

-def _make_performance_constraint(self, contexts: to.Tensor, contexts_old_log_prob: to.Tensor, values: to.Tensor, performance_lower_bound: float) -> NonlinearConstraint:
+def _make_performance_constraint(
+    self, contexts: to.Tensor, contexts_old_log_prob: to.Tensor, values: to.Tensor, performance_lower_bound: float
+) -> NonlinearConstraint:
def performance_constraint_fn(x):
"""Compute the constraint for the expected performance under the proposed distribution."""
distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
@@ -470,7 +486,9 @@ def _log_context_distribution(self):
def _sample_proposal_rollouts(self) -> List[List[StepSequence]]:
return self.subrtn_sampler.rollouts

-def _extract_particles(self, rollouts_all: List[List[StepSequence]], distribution: MultivariateNormalWrapper) -> Tuple[to.Tensor, to.Tensor, to.Tensor]:
+def _extract_particles(
+    self, rollouts_all: List[List[StepSequence]], distribution: MultivariateNormalWrapper
+) -> Tuple[to.Tensor, to.Tensor, to.Tensor]:
def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
domain_param_dict = ro.rollout_info["domain_param"]
untransformed_param_name = param_name + DomainParamTransform.UNTRANSFORMED_DOMAIN_PARAMETER_SUFFIX
@@ -491,7 +509,9 @@ def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
return contexts, contexts_log_prob, values

# noinspection PyMethodMayBeStatic
-def _compute_expected_performance(self, distribution: MultivariateNormalWrapper, context: to.Tensor, old_log_prop: to.Tensor, values: to.Tensor) -> to.Tensor:
+def _compute_expected_performance(
+    self, distribution: MultivariateNormalWrapper, context: to.Tensor, old_log_prop: to.Tensor, values: to.Tensor
+) -> to.Tensor:
"""Calculate the expected performance after an update step."""
context_ratio = to.exp(distribution.distribution.log_prob(context) - old_log_prop)
return to.mean(context_ratio * values)
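`_compute_expected_performance` is an importance-sampling estimate: returns collected under the previous context distribution are reweighted by the likelihood ratio of the proposed distribution. A small numeric sketch (contexts and values invented for illustration):

```python
# Numeric sketch of the importance-weighted performance estimate;
# all numbers are made up.
import torch as to

contexts = to.tensor([[0.1], [0.3], [-0.2]])  # sampled domain parameters
values = to.tensor([1.0, 2.0, 0.5])  # returns observed under the old distribution

old = to.distributions.MultivariateNormal(to.zeros(1), to.eye(1))
proposed = to.distributions.MultivariateNormal(0.1 * to.ones(1), 0.5 * to.eye(1))

context_ratio = to.exp(proposed.log_prob(contexts) - old.log_prob(contexts))
print(to.mean(context_ratio * values).item())
```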
68 changes: 42 additions & 26 deletions Pyrado/pyrado/domain_randomization/domain_parameter.py
@@ -44,11 +44,11 @@ class DomainParam:
"""Class to store and manage (probably multiple) domain parameter a.k.a. physics parameter a.k.a. simulator parameter"""

def __init__(
-self,
-name: Union[str, List[str]],
-clip_lo: Optional[Union[int, float]] = -pyrado.inf,
-clip_up: Optional[Union[int, float]] = pyrado.inf,
-roundint: bool = False,
+self,
+name: Union[str, List[str]],
+clip_lo: Optional[Union[int, float]] = -pyrado.inf,
+clip_up: Optional[Union[int, float]] = pyrado.inf,
+roundint: bool = False,
):
"""
Constructor, also see the constructor of DomainRandomizer.
@@ -97,7 +97,7 @@ def adapt(self, domain_distr_param: str, domain_distr_param_value: Union[float,
if domain_distr_param not in self.get_field_names():
raise pyrado.KeyErr(
msg=f"The domain parameter {self.name} does not have a domain distribution parameter "
f"called {domain_distr_param}!"
f"called {domain_distr_param}!"
)
setattr(self, domain_distr_param, domain_distr_param_value)
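For context, `adapt` only accepts names reported by the parameter's `get_field_names`; anything else triggers the `KeyErr` above. A hypothetical call (parameter name and values invented, and `"mean"`/`"std"` being the fields of `NormalDomainParam` is an assumption):

```python
# Hypothetical usage of DomainParam.adapt; names and values are invented.
from pyrado.domain_randomization.domain_parameter import NormalDomainParam

param = NormalDomainParam(name="mass", mean=1.0, std=0.1)
param.adapt("mean", 1.5)   # accepted: "mean" is a distribution parameter field
param.adapt("scale", 2.0)  # raises pyrado.KeyErr: unknown distribution parameter
```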

@@ -314,14 +314,14 @@ def sample(self, num_samples: int = 1) -> List[to.Tensor]:

class SelfPacedDomainParam(DomainParam):
def __init__(
-self,
-name: List[str],
-target_mean: to.Tensor,
-target_cov_flat: to.Tensor,
-init_mean: to.Tensor,
-init_cov_flat: to.Tensor,
-clip_lo: float,
-clip_up: float,
+self,
+name: List[str],
+target_mean: to.Tensor,
+target_cov_flat: to.Tensor,
+init_mean: to.Tensor,
+init_cov_flat: to.Tensor,
+clip_lo: float,
+clip_up: float,
):
"""
Constructor
@@ -362,12 +362,12 @@ def get_field_names(self) -> Sequence[str]:

@staticmethod
def make_broadening(
-name: List[str],
-mean: List[float],
-init_cov_portion: float = 0.001,
-target_cov_portion: float = 0.1,
-clip_lo: float = -pyrado.inf,
-clip_up: float = pyrado.inf,
+name: List[str],
+mean: List[float],
+init_cov_portion: float = 0.001,
+target_cov_portion: float = 0.1,
+clip_lo: float = -pyrado.inf,
+clip_up: float = pyrado.inf,
) -> "SelfPacedDomainParam":
"""
Creates a self-paced domain parameter having the same initial and target mean, but a larger variance on the
@@ -394,7 +394,7 @@
)

@staticmethod
-def from_domain_randomizer(domain_randomizer, *, target_cov_factor=1., init_cov_factor=1 / 100):
+def from_domain_randomizer(domain_randomizer, *, target_cov_factor=1.0, init_cov_factor=1 / 100):
"""
Creates a self-paced domain parameter having the same initial and target mean and target variance given by the domain randomizer's variance (scaled by `target_cov_factor`). The initial variance is also given by the domain randomizer's variance (scaled by `init_cov_factor`).
@@ -403,15 +403,31 @@ def from_domain_randomizer(domain_randomizer, *, target_cov_factor=1., init_cov_
:param init_cov_factor: scaling of the randomizer's variance to get the init variance; defaults to `1/100`
:return: the self-paced domain parameter
"""
-name, target_mean, target_cov_flat, init_mean, init_cov_flat, = [], [], [], [], []
+(
+    name,
+    target_mean,
+    target_cov_flat,
+    init_mean,
+    init_cov_flat,
+) = (
+    [],
+    [],
+    [],
+    [],
+    [],
+)
for domain_param in domain_randomizer.domain_params:
if not isinstance(domain_param, NormalDomainParam):
-raise pyrado.TypeErr(given=domain_param, expected_type=NormalDomainParam, msg="each domain_param must be a NormalDomainParam")
+raise pyrado.TypeErr(
+    given=domain_param,
+    expected_type=NormalDomainParam,
+    msg="each domain_param must be a NormalDomainParam",
+)
name.append(domain_param.name)
target_mean.append(domain_param.mean)
-target_cov_flat.append(target_cov_factor * domain_param.std ** 2)
+target_cov_flat.append(target_cov_factor * domain_param.std**2)
init_mean.append(domain_param.mean)
-init_cov_flat.append(init_cov_factor * domain_param.std ** 2)
+init_cov_flat.append(init_cov_factor * domain_param.std**2)
return SelfPacedDomainParam(
name=name,
target_mean=to.tensor(target_mean),
@@ -443,7 +459,7 @@ def context_cov(self) -> to.Tensor:
return self.context_cov_chol @ self.context_cov_chol.T

def info(self) -> dict:
""
""""""
return {
"name": self.name,
"target_mean": self.target_mean,
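A hypothetical construction through the new factory method, assuming a randomizer holding only `NormalDomainParam`s (names, numbers, and import paths are inferred from this commit's file layout):

```python
# Hypothetical usage of SelfPacedDomainParam.from_domain_randomizer;
# parameter names and values are invented.
from pyrado.domain_randomization.domain_parameter import NormalDomainParam, SelfPacedDomainParam
from pyrado.domain_randomization.domain_randomizer import DomainRandomizer

randomizer = DomainRandomizer(
    NormalDomainParam(name="mass", mean=1.0, std=0.1),
    NormalDomainParam(name="length", mean=0.5, std=0.05),
)
# Target covariance becomes 1.0 * std**2, initial covariance std**2 / 100.
param = SelfPacedDomainParam.from_domain_randomizer(
    randomizer, target_cov_factor=1.0, init_cov_factor=1 / 100
)
```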
4 changes: 3 additions & 1 deletion Pyrado/pyrado/sampling/parallel_evaluation.py
@@ -106,7 +106,9 @@ def eval_domain_params(
# Run with progress bar
with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as pb:
# we set the sub seed to zero since every run will have its personal sub sub seed
-return pool.run_map(functools.partial(_ps_run_one_domain_param, eval=True, seed=seed, sub_seed=0), list(enumerate(params)), pb)
+return pool.run_map(
+    functools.partial(_ps_run_one_domain_param, eval=True, seed=seed, sub_seed=0), list(enumerate(params)), pb
+)


def eval_nominal_domain(
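The call above freezes `eval`, `seed`, and `sub_seed` with `functools.partial` and maps over index/parameter pairs, so each run can derive its own sub-sub-seed from its index. A rough standard-library sketch of the same pattern (Pyrado's sampler pool and the worker's real signature are assumptions here):

```python
# Stand-in for the pool.run_map pattern using multiprocessing;
# run_one_domain_param only mimics the shape of _ps_run_one_domain_param.
import functools
from multiprocessing import Pool


def run_one_domain_param(indexed_param, eval, seed, sub_seed):
    idx, domain_param = indexed_param
    # Every run derives a personal sub-sub-seed from its index.
    sub_sub_seed = hash((seed, sub_seed, idx)) % 2**32
    return idx, sub_sub_seed, domain_param


if __name__ == "__main__":
    params = [{"mass": 1.0}, {"mass": 1.2}, {"mass": 1.4}]
    worker = functools.partial(run_one_domain_param, eval=True, seed=42, sub_seed=0)
    with Pool(processes=2) as pool:
        print(pool.map(worker, list(enumerate(params))))
```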
