From 93055bbf69e1dfc4aeb40e983b709b1ef7ed0170 Mon Sep 17 00:00:00 2001 From: Fabian Damken Date: Mon, 9 Oct 2023 02:48:11 -0400 Subject: [PATCH] Fixes according to black and isort (#99) --- .github/workflows/linting.yml | 2 +- Pyrado/pyrado/algorithms/episodic/predefined_lqr.py | 1 - Pyrado/pyrado/algorithms/meta/adr.py | 2 -- Pyrado/pyrado/algorithms/regression/timeseries_prediction.py | 1 - Pyrado/pyrado/algorithms/step_based/dql.py | 1 - Pyrado/pyrado/algorithms/step_based/gae.py | 1 - Pyrado/pyrado/algorithms/step_based/ppo.py | 2 -- Pyrado/pyrado/environments/rcspysim/planar_insert.py | 1 - Pyrado/pyrado/policies/feed_forward/poly_time.py | 4 +--- Pyrado/pyrado/sampling/parallel_rollout_sampler.py | 1 - Pyrado/pyrado/sampling/sequences.py | 5 +++++ .../evaluation/paper_specific/sob_illustrative_example.py | 1 - Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py | 1 - setup_deps.py | 2 -- 14 files changed, 7 insertions(+), 18 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 5f7851140ad..abb95a7d05d 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -12,7 +12,7 @@ jobs: - name: Run black uses: psf/black@stable with: - args: "--check Pyrado setup_deps.py RcsPySim/setup.py" + src: "Pyrado setup_deps.py RcsPySim/setup.py" isort: name: Checking Import Order with isort runs-on: ubuntu-latest diff --git a/Pyrado/pyrado/algorithms/episodic/predefined_lqr.py b/Pyrado/pyrado/algorithms/episodic/predefined_lqr.py index 48d8b5a8cca..3c773f90cdf 100644 --- a/Pyrado/pyrado/algorithms/episodic/predefined_lqr.py +++ b/Pyrado/pyrado/algorithms/episodic/predefined_lqr.py @@ -102,7 +102,6 @@ def sampler(self, sampler: ParallelRolloutSampler): self._sampler = sampler def step(self, snapshot_mode: str, meta_info: dict = None): - if isinstance(inner_env(self._env), BallOnPlate5DSim): ctrl_gains = to.tensor( [ diff --git a/Pyrado/pyrado/algorithms/meta/adr.py b/Pyrado/pyrado/algorithms/meta/adr.py index 14d34f4a22c..35eb386cb88 100644 --- a/Pyrado/pyrado/algorithms/meta/adr.py +++ b/Pyrado/pyrado/algorithms/meta/adr.py @@ -431,7 +431,6 @@ def __init__( logger: StepLogger = None, device: str = "cuda" if to.cuda.is_available() else "cpu", ): - """ Constructor @@ -472,7 +471,6 @@ def get_reward(self, traj: StepSequence) -> to.Tensor: def train( self, reference_trajectory: StepSequence, randomized_trajectory: StepSequence, num_epoch: int ) -> to.Tensor: - reference_batch_generator = reference_trajectory.iterate_rollouts() random_batch_generator = randomized_trajectory.iterate_rollouts() diff --git a/Pyrado/pyrado/algorithms/regression/timeseries_prediction.py b/Pyrado/pyrado/algorithms/regression/timeseries_prediction.py index 11326aee1f5..d33ea819c9a 100644 --- a/Pyrado/pyrado/algorithms/regression/timeseries_prediction.py +++ b/Pyrado/pyrado/algorithms/regression/timeseries_prediction.py @@ -104,7 +104,6 @@ def __init__( self._lr_scheduler = lr_scheduler(self.optim, **lr_scheduler_hparam) def step(self, snapshot_mode: str, meta_info: dict = None): - # Feed one epoch of the training set to the policy if self.windowed: # Predict diff --git a/Pyrado/pyrado/algorithms/step_based/dql.py b/Pyrado/pyrado/algorithms/step_based/dql.py index ff38c66ab8b..36ca2e6c6b5 100644 --- a/Pyrado/pyrado/algorithms/step_based/dql.py +++ b/Pyrado/pyrado/algorithms/step_based/dql.py @@ -192,7 +192,6 @@ def update(self): file=sys.stdout, leave=False, ): - # Sample steps and the associated next step from the replay memory steps, next_steps = self._memory.sample(self.batch_size) steps.torch(data_type=to.get_default_dtype()) diff --git a/Pyrado/pyrado/algorithms/step_based/gae.py b/Pyrado/pyrado/algorithms/step_based/gae.py index 0f80a42f85c..ac125f51d21 100644 --- a/Pyrado/pyrado/algorithms/step_based/gae.py +++ b/Pyrado/pyrado/algorithms/step_based/gae.py @@ -239,7 +239,6 @@ def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool = # Iterate over all gathered samples num_epoch times for e in range(self.num_epoch): - for batch in tqdm( concat_ros.split_shuffled_batches( self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy) diff --git a/Pyrado/pyrado/algorithms/step_based/ppo.py b/Pyrado/pyrado/algorithms/step_based/ppo.py index 2c93f0ccd5f..7f8d094318e 100644 --- a/Pyrado/pyrado/algorithms/step_based/ppo.py +++ b/Pyrado/pyrado/algorithms/step_based/ppo.py @@ -176,7 +176,6 @@ def update(self, rollouts: Sequence[StepSequence]): # Iterations over the whole data set for e in range(self.num_epoch): - for batch in tqdm( concat_ros.split_shuffled_batches(self.batch_size, complete_rollouts=self._policy.is_recurrent), total=num_iter_from_rollouts(None, concat_ros, self.batch_size), @@ -412,7 +411,6 @@ def update(self, rollouts: Sequence[StepSequence]): # Iterations over the whole data set for e in range(self.num_epoch): - for batch in tqdm( concat_ros.split_shuffled_batches( self.batch_size, diff --git a/Pyrado/pyrado/environments/rcspysim/planar_insert.py b/Pyrado/pyrado/environments/rcspysim/planar_insert.py index 5dcedde15c9..82e68b54313 100644 --- a/Pyrado/pyrado/environments/rcspysim/planar_insert.py +++ b/Pyrado/pyrado/environments/rcspysim/planar_insert.py @@ -75,7 +75,6 @@ class PlanarInsertSim(RcsSim, Serializable): """ def __init__(self, task_args: dict, collision_config: dict = None, max_dist_force: float = None, **kwargs): - """ Constructor diff --git a/Pyrado/pyrado/policies/feed_forward/poly_time.py b/Pyrado/pyrado/policies/feed_forward/poly_time.py index cf5eb2c9dbf..c01702774a1 100644 --- a/Pyrado/pyrado/policies/feed_forward/poly_time.py +++ b/Pyrado/pyrado/policies/feed_forward/poly_time.py @@ -251,9 +251,7 @@ class TraceablePolySplineTimePolicy(nn.Module): t_init: float t_curr: float overtime_behavior: str - act_space_shape: Tuple[ - int, - ] + act_space_shape: Tuple[int,] act_space_flat_dim: int def __init__( diff --git a/Pyrado/pyrado/sampling/parallel_rollout_sampler.py b/Pyrado/pyrado/sampling/parallel_rollout_sampler.py index 59006730c45..dbb90558624 100644 --- a/Pyrado/pyrado/sampling/parallel_rollout_sampler.py +++ b/Pyrado/pyrado/sampling/parallel_rollout_sampler.py @@ -276,7 +276,6 @@ def sample( disable=(not self.show_progress_bar), unit="steps" if self.min_steps is not None else "rollouts", ) as pb: - if self.min_steps is None: if init_states is None and domain_params is None: # Simply run min_rollouts times diff --git a/Pyrado/pyrado/sampling/sequences.py b/Pyrado/pyrado/sampling/sequences.py index bfa5d2be8f0..de132f422ad 100644 --- a/Pyrado/pyrado/sampling/sequences.py +++ b/Pyrado/pyrado/sampling/sequences.py @@ -99,6 +99,7 @@ def sequence_add_init(x_init, iter, dtype=int): :param dtype: data type to cast to (either int of float) :return: element at the given iteration and array of the whole sequence """ + # non-exponential growth def iter_function(x_seq, i, x_init): return x_seq[0, :] * (i + 1) @@ -115,6 +116,7 @@ def sequence_rec_double(x_init, iter, dtype=int): :param dtype: data type to cast to (either int of float) :return: element at the given iteration and array of the whole sequence """ + # exponential growth def iter_function(x_seq, i, x_init): return x_seq[i - 1, :] * 2.0 @@ -131,6 +133,7 @@ def sequence_sqrt(x_init, iter, dtype=int): :param dtype: data type to cast to (either int of float) :return: element at the given iteration and array of the whole sequence """ + # non-exponential growth def iter_function(x_seq, i, x_init): return x_seq[0, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1 @@ -147,6 +150,7 @@ def sequence_rec_sqrt(x_init, iter, dtype=int): :param dtype: data type to cast to (either int of float) :return: element at the given iteration and array of the whole sequence """ + # exponential growth def iter_function(x_seq, i, x_init): return x_seq[i - 1, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1 @@ -163,6 +167,7 @@ def sequence_nlog2(x_init, iter, dtype=int): :param dtype: data type to cast to (either int of float) :return: element at the given iteration and array of the whole sequence """ + # non-exponential growth def iter_function(x_seq, i, x_init): return x_seq[0, :] * i * np.log2(i + 2) # i+2 because log2(1) = 0 and log2(2) = 1 diff --git a/Pyrado/scripts/evaluation/paper_specific/sob_illustrative_example.py b/Pyrado/scripts/evaluation/paper_specific/sob_illustrative_example.py index e9ecfda3638..c035c4671d2 100755 --- a/Pyrado/scripts/evaluation/paper_specific/sob_illustrative_example.py +++ b/Pyrado/scripts/evaluation/paper_specific/sob_illustrative_example.py @@ -139,7 +139,6 @@ def check_E_n_Jhat(th_n_opt, n): b_Jhat_n_hist = np.empty((num_samples, num_iter)) for s in range(num_samples): - for n in range(1, num_iter + 1): n_V = np.random.binomial(n, psi) # perform n Bernoulli trials n_M = n - n_V diff --git a/Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py b/Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py index 6991db4a80b..2caccaf77d8 100755 --- a/Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py +++ b/Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py @@ -138,7 +138,6 @@ def train_and_eval(trial: optuna.Trial, study_dir: str, seed: int): if __name__ == "__main__": - # Parse command line arguments args = get_argparser().parse_args() diff --git a/setup_deps.py b/setup_deps.py index 6386e9bd3ec..6fe7b8333ba 100644 --- a/setup_deps.py +++ b/setup_deps.py @@ -315,7 +315,6 @@ def members(ml): with tarfile.open(tf.name) as tar: def is_within_directory(directory, target): - abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) @@ -324,7 +323,6 @@ def is_within_directory(directory, target): return prefix == abs_directory def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - for member in tar.getmembers(): member_path = os.path.join(path, member.name) if not is_within_directory(path, member_path):