Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix black and isort linting #99

Merged
merged 1 commit into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Run black
uses: psf/black@stable
with:
args: "--check Pyrado setup_deps.py RcsPySim/setup.py"
src: "Pyrado setup_deps.py RcsPySim/setup.py"
isort:
name: Checking Import Order with isort
runs-on: ubuntu-latest
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/episodic/predefined_lqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def sampler(self, sampler: ParallelRolloutSampler):
self._sampler = sampler

def step(self, snapshot_mode: str, meta_info: dict = None):

if isinstance(inner_env(self._env), BallOnPlate5DSim):
ctrl_gains = to.tensor(
[
Expand Down
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/meta/adr.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ def __init__(
logger: StepLogger = None,
device: str = "cuda" if to.cuda.is_available() else "cpu",
):

"""
Constructor
Expand Down Expand Up @@ -472,7 +471,6 @@ def get_reward(self, traj: StepSequence) -> to.Tensor:
def train(
self, reference_trajectory: StepSequence, randomized_trajectory: StepSequence, num_epoch: int
) -> to.Tensor:

reference_batch_generator = reference_trajectory.iterate_rollouts()
random_batch_generator = randomized_trajectory.iterate_rollouts()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def __init__(
self._lr_scheduler = lr_scheduler(self.optim, **lr_scheduler_hparam)

def step(self, snapshot_mode: str, meta_info: dict = None):

# Feed one epoch of the training set to the policy
if self.windowed:
# Predict
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/dql.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def update(self):
file=sys.stdout,
leave=False,
):

# Sample steps and the associated next step from the replay memory
steps, next_steps = self._memory.sample(self.batch_size)
steps.torch(data_type=to.get_default_dtype())
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool =

# Iterate over all gathered samples num_epoch times
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy)
Expand Down
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/step_based/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(self.batch_size, complete_rollouts=self._policy.is_recurrent),
total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
Expand Down Expand Up @@ -412,7 +411,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size,
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/environments/rcspysim/planar_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class PlanarInsertSim(RcsSim, Serializable):
"""

def __init__(self, task_args: dict, collision_config: dict = None, max_dist_force: float = None, **kwargs):

"""
Constructor
Expand Down
4 changes: 1 addition & 3 deletions Pyrado/pyrado/policies/feed_forward/poly_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,7 @@ class TraceablePolySplineTimePolicy(nn.Module):
t_init: float
t_curr: float
overtime_behavior: str
act_space_shape: Tuple[
int,
]
act_space_shape: Tuple[int,]
act_space_flat_dim: int

def __init__(
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/sampling/parallel_rollout_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def sample(
disable=(not self.show_progress_bar),
unit="steps" if self.min_steps is not None else "rollouts",
) as pb:

if self.min_steps is None:
if init_states is None and domain_params is None:
# Simply run min_rollouts times
Expand Down
5 changes: 5 additions & 0 deletions Pyrado/pyrado/sampling/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def sequence_add_init(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * (i + 1)
Expand All @@ -115,6 +116,7 @@ def sequence_rec_double(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * 2.0
Expand All @@ -131,6 +133,7 @@ def sequence_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
Expand All @@ -147,6 +150,7 @@ def sequence_rec_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
Expand All @@ -163,6 +167,7 @@ def sequence_nlog2(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * i * np.log2(i + 2) # i+2 because log2(1) = 0 and log2(2) = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def check_E_n_Jhat(th_n_opt, n):
b_Jhat_n_hist = np.empty((num_samples, num_iter))

for s in range(num_samples):

for n in range(1, num_iter + 1):
n_V = np.random.binomial(n, psi) # perform n Bernoulli trials
n_M = n - n_V
Expand Down
1 change: 0 additions & 1 deletion Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def train_and_eval(trial: optuna.Trial, study_dir: str, seed: int):


if __name__ == "__main__":

# Parse command line arguments
args = get_argparser().parse_args()

Expand Down
2 changes: 0 additions & 2 deletions setup_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def members(ml):
with tarfile.open(tf.name) as tar:

def is_within_directory(directory, target):

abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

Expand All @@ -323,7 +322,6 @@ def is_within_directory(directory, target):
return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):

for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
Expand Down
Loading