Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring is all you need #98

Merged
merged 27 commits
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
25345fc
Set Python version to 3.7 using PyEnv
fdamken Aug 15, 2023
5013bcf
Fix parallel evaluation
fdamken Aug 15, 2023
4b56030
Remove non-existing Pandas import
fdamken Aug 15, 2023
3c5d16c
Add SPDR parameter factory
fdamken Aug 15, 2023
dfd13f1
Add logging of number of particles
fdamken Aug 15, 2023
2a3451c
Remove commented code
fdamken Aug 15, 2023
a4dc9d0
Refactor SPDR for readability
fdamken Aug 15, 2023
1645b0b
virtualenv is all you need
fdamken Oct 3, 2023
fa00845
Merge branch 'update_setup_deps'
fdamken Oct 3, 2023
c526870
apply black and isort
fdamken Oct 8, 2023
33cfb08
Merge branch 'reformat'
fdamken Oct 8, 2023
54e7eb0
formatting
fdamken Oct 8, 2023
98dfea1
fix black and isort linting
fdamken Oct 8, 2023
8881a0e
Merge branch 'master' into cleanup_spdr
fdamken Oct 8, 2023
2279994
Merge remote-tracking branch 'upstream/master' into cleanup_spdr
fdamken Oct 9, 2023
3836256
remove empty docstring
fdamken Oct 15, 2023
7af231c
address review issues
fdamken Oct 22, 2023
034c453
Ignore local changes/symlinked pre-commit
miterion Oct 9, 2023
386d9f7
Install torch-based deps with dependencies
miterion Oct 15, 2023
81b36c7
Add mujoco as a python dependency
miterion Oct 15, 2023
bcedd70
Remove mujoco-py installation from install script
miterion Oct 15, 2023
00cd6d0
Switch to new mujoco package
miterion Oct 22, 2023
9b48aa8
Remove mujoco submodule
miterion Oct 22, 2023
4fcfaa7
Update pre-commit hooks to latest version
miterion Oct 22, 2023
f5022d4
Apply isort to half_cheetah
miterion Oct 22, 2023
08ef8c8
Remove unnecessary subfolder length call
miterion Oct 27, 2023
1e19a6a
Merge remote-tracking branch 'upstream/mujoco-update' into cleanup_spdr
fdamken Oct 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Run black
uses: psf/black@stable
with:
args: "--check Pyrado setup_deps.py RcsPySim/setup.py"
src: "Pyrado setup_deps.py RcsPySim/setup.py"
isort:
name: Checking Import Order with isort
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.7
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/episodic/predefined_lqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def sampler(self, sampler: ParallelRolloutSampler):
self._sampler = sampler

def step(self, snapshot_mode: str, meta_info: dict = None):

if isinstance(inner_env(self._env), BallOnPlate5DSim):
ctrl_gains = to.tensor(
[
Expand Down
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/meta/adr.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ def __init__(
logger: StepLogger = None,
device: str = "cuda" if to.cuda.is_available() else "cpu",
):

"""
Constructor
Expand Down Expand Up @@ -472,7 +471,6 @@ def get_reward(self, traj: StepSequence) -> to.Tensor:
def train(
self, reference_trajectory: StepSequence, randomized_trajectory: StepSequence, num_epoch: int
) -> to.Tensor:

reference_batch_generator = reference_trajectory.iterate_rollouts()
random_batch_generator = randomized_trajectory.iterate_rollouts()

Expand Down
329 changes: 171 additions & 158 deletions Pyrado/pyrado/algorithms/meta/spdr.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def __init__(
self._lr_scheduler = lr_scheduler(self.optim, **lr_scheduler_hparam)

def step(self, snapshot_mode: str, meta_info: dict = None):

# Feed one epoch of the training set to the policy
if self.windowed:
# Predict
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/dql.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def update(self):
file=sys.stdout,
leave=False,
):

# Sample steps and the associated next step from the replay memory
steps, next_steps = self._memory.sample(self.batch_size)
steps.torch(data_type=to.get_default_dtype())
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool =

# Iterate over all gathered samples num_epoch times
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy)
Expand Down
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/step_based/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(self.batch_size, complete_rollouts=self._policy.is_recurrent),
total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
Expand Down Expand Up @@ -412,7 +411,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size,
Expand Down
57 changes: 57 additions & 0 deletions Pyrado/pyrado/domain_randomization/domain_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,51 @@ def make_broadening(
clip_up=clip_up,
)

@staticmethod
def from_domain_randomizer(domain_randomizer, *, target_cov_factor=1.0, init_cov_factor=1 / 100):
    """
    Create a self-paced domain parameter from a domain randomizer.

    The target mean of each parameter equals the randomizer's mean, and the target variance is
    the randomizer's variance scaled by `target_cov_factor`. The initial mean also equals the
    randomizer's mean, and the initial variance is the randomizer's variance scaled by
    `init_cov_factor`.

    :param domain_randomizer: randomizer to grab the data from; every entry of its
                              `domain_params` must be a `NormalDomainParam`
    :param target_cov_factor: scaling of the randomizer's variance to get the target variance;
                              defaults to `1.0`
    :param init_cov_factor: scaling of the randomizer's variance to get the initial variance;
                            defaults to `1 / 100`
    :raises pyrado.TypeErr: if any of the randomizer's domain parameters is not a
                            `NormalDomainParam`
    :return: the self-paced domain parameter
    """
    # Collect the per-parameter quantities; variances are derived from the std deviations.
    name = []
    target_mean = []
    target_cov_flat = []
    init_mean = []
    init_cov_flat = []
    for domain_param in domain_randomizer.domain_params:
        if not isinstance(domain_param, NormalDomainParam):
            raise pyrado.TypeErr(
                given=domain_param,
                expected_type=NormalDomainParam,
                msg="each domain_param must be a NormalDomainParam",
            )
        name.append(domain_param.name)
        target_mean.append(domain_param.mean)
        target_cov_flat.append(target_cov_factor * domain_param.std**2)
        init_mean.append(domain_param.mean)
        init_cov_flat.append(init_cov_factor * domain_param.std**2)
    return SelfPacedDomainParam(
        name=name,
        target_mean=to.tensor(target_mean),
        target_cov_flat=to.tensor(target_cov_flat),
        init_mean=to.tensor(init_mean),
        init_cov_flat=to.tensor(init_cov_flat),
        clip_lo=-pyrado.inf,  # no clipping of the sampled values
        clip_up=+pyrado.inf,
    )

fdamken marked this conversation as resolved.
Show resolved Hide resolved
@property
def target_distr(self) -> MultivariateNormal:
"""Get the target distribution."""
Expand All @@ -413,6 +458,18 @@ def context_cov(self) -> to.Tensor:
"""Get the current covariance matrix."""
return self.context_cov_chol @ self.context_cov_chol.T

def info(self) -> dict:
    """
    Get the quantities defining this self-paced domain parameter, e.g. for logging or for
    reconstructing the parameter later.

    :return: dict mapping field names (name, target/init mean and covariance Cholesky factors,
             and clipping bounds) to the corresponding attribute values
    """
    return {
        "name": self.name,
        "target_mean": self.target_mean,
        "target_cov_chol": self.target_cov_chol,
        "init_mean": self.init_mean,
        "init_cov_chol": self.init_cov_chol,
        "clip_lo": self.clip_lo,
        "clip_up": self.clip_up,
    }

def adapt(self, domain_distr_param: str, domain_distr_param_value: to.Tensor):
"""
Update this domain parameter.
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/environments/rcspysim/planar_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class PlanarInsertSim(RcsSim, Serializable):
"""

def __init__(self, task_args: dict, collision_config: dict = None, max_dist_force: float = None, **kwargs):

"""
Constructor
Expand Down
5 changes: 0 additions & 5 deletions Pyrado/pyrado/plotting/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from matplotlib import colors
from matplotlib import pyplot as plt
from matplotlib import ticker
from pandas.core.indexes.numeric import NumericIndex

import pyrado
from pyrado.plotting.utils import draw_sep_cbar
Expand Down Expand Up @@ -191,10 +190,6 @@ def draw_heatmap(
:return: handles to the heat map and the color bar figures (`None` if not existent)
"""
if isinstance(data, pd.DataFrame):
if not isinstance(data.index, NumericIndex):
fdamken marked this conversation as resolved.
Show resolved Hide resolved
raise pyrado.TypeErr(given=data.index, expected_type=NumericIndex)
if not isinstance(data.columns, NumericIndex):
raise pyrado.TypeErr(given=data.columns, expected_type=NumericIndex)
# Extract the data
x = data.columns
y = data.index
Expand Down
4 changes: 1 addition & 3 deletions Pyrado/pyrado/policies/feed_forward/poly_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,7 @@ class TraceablePolySplineTimePolicy(nn.Module):
t_init: float
t_curr: float
overtime_behavior: str
act_space_shape: Tuple[
int,
]
act_space_shape: Tuple[int,]
act_space_flat_dim: int

def __init__(
Expand Down
8 changes: 5 additions & 3 deletions Pyrado/pyrado/sampling/parallel_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def eval_domain_params(
# Run with progress bar
with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as pb:
# we set the sub seed to zero since every run will have its personal sub sub seed
return pool.run_map(functools.partial(_ps_run_one_domain_param, eval=True, seed=seed, sub_seed=0), params, pb)
return pool.run_map(
functools.partial(_ps_run_one_domain_param, eval=True, seed=seed, sub_seed=0), list(enumerate(params)), pb
)


def eval_nominal_domain(
Expand All @@ -128,7 +130,7 @@ def eval_nominal_domain(

# Run with progress bar
with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as pb:
return pool.run_map(functools.partial(_ps_run_one_init_state, eval=True), init_states, pb)
return pool.run_map(functools.partial(_ps_run_one_init_state, eval=True), list(enumerate(init_states)), pb)


def eval_randomized_domain(
Expand All @@ -152,7 +154,7 @@ def eval_randomized_domain(

# Run with progress bar
with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as pb:
return pool.run_map(functools.partial(_ps_run_one_init_state, eval=True), init_states, pb)
return pool.run_map(functools.partial(_ps_run_one_init_state, eval=True), list(enumerate(init_states)), pb)


def eval_domain_params_with_segmentwise_reset(
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/sampling/parallel_rollout_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def sample(
disable=(not self.show_progress_bar),
unit="steps" if self.min_steps is not None else "rollouts",
) as pb:

if self.min_steps is None:
if init_states is None and domain_params is None:
# Simply run min_rollouts times
Expand Down
5 changes: 5 additions & 0 deletions Pyrado/pyrado/sampling/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def sequence_add_init(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * (i + 1)
Expand All @@ -115,6 +116,7 @@ def sequence_rec_double(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * 2.0
Expand All @@ -131,6 +133,7 @@ def sequence_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
Expand All @@ -147,6 +150,7 @@ def sequence_rec_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
Expand All @@ -163,6 +167,7 @@ def sequence_nlog2(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * i * np.log2(i + 2) # i+2 because log2(1) = 0 and log2(2) = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def check_E_n_Jhat(th_n_opt, n):
b_Jhat_n_hist = np.empty((num_samples, num_iter))

for s in range(num_samples):

for n in range(1, num_iter + 1):
n_V = np.random.binomial(n, psi) # perform n Bernoulli trials
n_M = n - n_V
Expand Down
1 change: 0 additions & 1 deletion Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def train_and_eval(trial: optuna.Trial, study_dir: str, seed: int):


if __name__ == "__main__":

# Parse command line arguments
args = get_argparser().parse_args()

Expand Down
9 changes: 4 additions & 5 deletions setup_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,11 @@
resources_dir = osp.join(dependency_dir, "resources")

# Global cmake prefix path
cmake_prefix_path = [
cmake_prefix_path = []
conda_prefix = os.getenv("CONDA_PREFIX")
if conda_prefix:
# Anaconda env root directory
os.environ["CONDA_PREFIX"]
]
cmake_prefix_path.append(conda_prefix)

# Required packages
required_packages = [
Expand Down Expand Up @@ -314,7 +315,6 @@ def members(ml):
with tarfile.open(tf.name) as tar:

def is_within_directory(directory, target):

abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

Expand All @@ -323,7 +323,6 @@ def is_within_directory(directory, target):
return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):

for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
Expand Down
Loading