
Commit

Fixes according to black and isort (#99)
fdamken authored Oct 9, 2023
1 parent 94ab4e8 commit 93055bb
Showing 14 changed files with 7 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
@@ -12,7 +12,7 @@ jobs:
- name: Run black
uses: psf/black@stable
with:
args: "--check Pyrado setup_deps.py RcsPySim/setup.py"
src: "Pyrado setup_deps.py RcsPySim/setup.py"
isort:
name: Checking Import Order with isort
runs-on: ubuntu-latest
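For context on the change above: newer releases of the psf/black GitHub Action take the paths to check through a `src` input instead of the older `args` input, with extra flags going into a separate `options` input (which, if I recall the action's defaults correctly, already runs black in check-only mode). The same three targets, Pyrado, setup_deps.py and RcsPySim/setup.py, are therefore still being checked.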
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/episodic/predefined_lqr.py
@@ -102,7 +102,6 @@ def sampler(self, sampler: ParallelRolloutSampler):
self._sampler = sampler

def step(self, snapshot_mode: str, meta_info: dict = None):

if isinstance(inner_env(self._env), BallOnPlate5DSim):
ctrl_gains = to.tensor(
[
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/meta/adr.py
@@ -431,7 +431,6 @@ def __init__(
logger: StepLogger = None,
device: str = "cuda" if to.cuda.is_available() else "cpu",
):

"""
Constructor
@@ -472,7 +471,6 @@ def get_reward(self, traj: StepSequence) -> to.Tensor:
def train(
self, reference_trajectory: StepSequence, randomized_trajectory: StepSequence, num_epoch: int
) -> to.Tensor:

reference_batch_generator = reference_trajectory.iterate_rollouts()
random_batch_generator = randomized_trajectory.iterate_rollouts()

1 change: 0 additions & 1 deletion
@@ -104,7 +104,6 @@ def __init__(
self._lr_scheduler = lr_scheduler(self.optim, **lr_scheduler_hparam)

def step(self, snapshot_mode: str, meta_info: dict = None):

# Feed one epoch of the training set to the policy
if self.windowed:
# Predict
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/dql.py
@@ -192,7 +192,6 @@ def update(self):
file=sys.stdout,
leave=False,
):

# Sample steps and the associated next step from the replay memory
steps, next_steps = self._memory.sample(self.batch_size)
steps.torch(data_type=to.get_default_dtype())
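The loop shown above draws mini-batches of transitions from the replay memory and converts them to torch tensors before the Q-learning update. A generic sketch of that sampling step, with an invented buffer layout (Pyrado's replay-memory API is not reproduced here):

```python
import random
from collections import deque

import torch as to

# Hypothetical replay buffer of (obs, act, rew, next_obs) transitions
memory = deque(maxlen=10_000)
for _ in range(1_000):
    memory.append((to.randn(4), to.randint(0, 2, (1,)), to.rand(1), to.randn(4)))

batch_size = 32
batch = random.sample(memory, batch_size)  # uniform sampling without replacement
obs, act, rew, next_obs = (to.stack(field) for field in zip(*batch))
# obs/next_obs: [32, 4], act/rew: [32, 1] -- ready for computing the TD targets
```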
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/gae.py
@@ -239,7 +239,6 @@ def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool =

# Iterate over all gathered samples num_epoch times
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy)
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/step_based/ppo.py
@@ -176,7 +176,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(self.batch_size, complete_rollouts=self._policy.is_recurrent),
total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
@@ -412,7 +411,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size,
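Both PPO update methods touched above follow the same pattern: iterate `num_epoch` times over the concatenated rollouts and draw freshly shuffled mini-batches in every epoch. A minimal sketch of that pattern in plain PyTorch, with made-up tensors and no surrogate loss (Pyrado's `split_shuffled_batches` helper is not reproduced here):

```python
import torch as to
from tqdm import tqdm

# Hypothetical stand-ins for the data collected in the concatenated rollouts
observations = to.randn(1024, 4)  # 1024 steps with 4-dim observations
advantages = to.randn(1024)

num_epoch, batch_size = 5, 64

for e in range(num_epoch):
    # Re-shuffle the step indices every epoch, then iterate over mini-batches
    perm = to.randperm(observations.shape[0])
    for idx in tqdm(perm.split(batch_size), leave=False):
        obs_batch, adv_batch = observations[idx], advantages[idx]
        # ... evaluate the clipped surrogate loss on this mini-batch and take
        # one optimizer step here
```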
1 change: 0 additions & 1 deletion Pyrado/pyrado/environments/rcspysim/planar_insert.py
@@ -75,7 +75,6 @@ class PlanarInsertSim(RcsSim, Serializable):
"""

def __init__(self, task_args: dict, collision_config: dict = None, max_dist_force: float = None, **kwargs):

"""
Constructor
4 changes: 1 addition & 3 deletions Pyrado/pyrado/policies/feed_forward/poly_time.py
@@ -251,9 +251,7 @@ class TraceablePolySplineTimePolicy(nn.Module):
t_init: float
t_curr: float
overtime_behavior: str
act_space_shape: Tuple[
int,
]
act_space_shape: Tuple[int,]
act_space_flat_dim: int

def __init__(
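A side note on the reformatted annotation: `Tuple[int,]` denotes the same type as `Tuple[int]` (a tuple holding exactly one `int`), so collapsing the three-line subscript onto a single line leaves the traced policy's attribute annotation semantically unchanged.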
1 change: 0 additions & 1 deletion Pyrado/pyrado/sampling/parallel_rollout_sampler.py
@@ -276,7 +276,6 @@ def sample(
disable=(not self.show_progress_bar),
unit="steps" if self.min_steps is not None else "rollouts",
) as pb:

if self.min_steps is None:
if init_states is None and domain_params is None:
# Simply run min_rollouts times
5 changes: 5 additions & 0 deletions Pyrado/pyrado/sampling/sequences.py
@@ -99,6 +99,7 @@ def sequence_add_init(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int or float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * (i + 1)
@@ -115,6 +116,7 @@ def sequence_rec_double(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int or float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * 2.0
@@ -131,6 +133,7 @@ def sequence_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int or float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
@@ -147,6 +150,7 @@ def sequence_rec_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int or float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
@@ -163,6 +167,7 @@ def sequence_nlog2(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int or float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * i * np.log2(i + 2) # i+2 because log2(1) = 0 and log2(2) = 1
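All of the functions touched in `sequences.py` share the same shape: an inner `iter_function(x_seq, i, x_init)` defines how the i-th element is produced from the initial value or from the previous element. The driver that consumes these inner functions is not part of this diff, so the sketch below re-implements the pattern with a guessed minimal driver, just to illustrate what the recursive-doubling variant computes:

```python
import numpy as np


def make_sequence(x_init, num_iter, iter_function, dtype=int):
    # Build the whole sequence up to num_iter by repeatedly applying iter_function
    x_init = np.atleast_1d(np.asarray(x_init))
    x_seq = np.zeros((num_iter + 1, x_init.size))
    x_seq[0, :] = x_init
    for i in range(1, num_iter + 1):
        x_seq[i, :] = iter_function(x_seq, i, x_init)
    return x_seq[num_iter, :].astype(dtype), x_seq.astype(dtype)


# Recursive doubling as in sequence_rec_double: each element is twice its predecessor
last, seq = make_sequence(1, 5, lambda x_seq, i, x_init: x_seq[i - 1, :] * 2.0)
print(seq.ravel())  # [ 1  2  4  8 16 32]
```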
1 change: 0 additions & 1 deletion
@@ -139,7 +139,6 @@ def check_E_n_Jhat(th_n_opt, n):
b_Jhat_n_hist = np.empty((num_samples, num_iter))

for s in range(num_samples):

for n in range(1, num_iter + 1):
n_V = np.random.binomial(n, psi) # perform n Bernoulli trials
n_M = n - n_V
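For readers of the snippet above: `np.random.binomial(n, psi)` draws the number of successes in `n` independent Bernoulli(`psi`) trials, so `n_V` and `n_M = n - n_V` split the `n` samples into two complementary groups.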
1 change: 0 additions & 1 deletion Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py
@@ -138,7 +138,6 @@ def train_and_eval(trial: optuna.Trial, study_dir: str, seed: int):


if __name__ == "__main__":

# Parse command line arguments
args = get_argparser().parse_args()

2 changes: 0 additions & 2 deletions setup_deps.py
@@ -315,7 +315,6 @@ def members(ml):
with tarfile.open(tf.name) as tar:

def is_within_directory(directory, target):

abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

@@ -324,7 +323,6 @@ def is_within_directory(directory, target):
return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):

for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
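The two helpers touched above are the usual guard against tar path traversal (CVE-2007-4559): every archive member is resolved against the extraction directory before anything is unpacked. A self-contained sketch of how they fit together; the final `extractall` call and the archive/target paths are assumptions, since the surrounding code is cut off in this diff:

```python
import os
import tarfile


def is_within_directory(directory, target):
    # A member is safe only if its resolved path stays inside the extraction directory
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    prefix = os.path.commonprefix([abs_directory, abs_target])
    return prefix == abs_directory


def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    # Reject archives containing entries like "../../etc/passwd" before extracting
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted path traversal in tar file")
    tar.extractall(path, members, numeric_owner=numeric_owner)


# Hypothetical usage; the archive name and target directory are placeholders
with tarfile.open("downloaded_dependency.tar.gz") as tar:
    safe_extract(tar, path="thirdParty/some_dep")
```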
