Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reformatting (apply black and isort) #97

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/episodic/predefined_lqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def sampler(self, sampler: ParallelRolloutSampler):
self._sampler = sampler

def step(self, snapshot_mode: str, meta_info: dict = None):

if isinstance(inner_env(self._env), BallOnPlate5DSim):
ctrl_gains = to.tensor(
[
Expand Down
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/meta/adr.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ def __init__(
logger: StepLogger = None,
device: str = "cuda" if to.cuda.is_available() else "cpu",
):

"""
Constructor

Expand Down Expand Up @@ -472,7 +471,6 @@ def get_reward(self, traj: StepSequence) -> to.Tensor:
def train(
self, reference_trajectory: StepSequence, randomized_trajectory: StepSequence, num_epoch: int
) -> to.Tensor:

reference_batch_generator = reference_trajectory.iterate_rollouts()
random_batch_generator = randomized_trajectory.iterate_rollouts()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def __init__(
self._lr_scheduler = lr_scheduler(self.optim, **lr_scheduler_hparam)

def step(self, snapshot_mode: str, meta_info: dict = None):

# Feed one epoch of the training set to the policy
if self.windowed:
# Predict
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/dql.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def update(self):
file=sys.stdout,
leave=False,
):

# Sample steps and the associated next step from the replay memory
steps, next_steps = self._memory.sample(self.batch_size)
steps.torch(data_type=to.get_default_dtype())
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/algorithms/step_based/gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool =

# Iterate over all gathered samples num_epoch times
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy)
Expand Down
2 changes: 0 additions & 2 deletions Pyrado/pyrado/algorithms/step_based/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(self.batch_size, complete_rollouts=self._policy.is_recurrent),
total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
Expand Down Expand Up @@ -412,7 +411,6 @@ def update(self, rollouts: Sequence[StepSequence]):

# Iterations over the whole data set
for e in range(self.num_epoch):

for batch in tqdm(
concat_ros.split_shuffled_batches(
self.batch_size,
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/environments/rcspysim/planar_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class PlanarInsertSim(RcsSim, Serializable):
"""

def __init__(self, task_args: dict, collision_config: dict = None, max_dist_force: float = None, **kwargs):

"""
Constructor

Expand Down
4 changes: 1 addition & 3 deletions Pyrado/pyrado/policies/feed_forward/poly_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,7 @@ class TraceablePolySplineTimePolicy(nn.Module):
t_init: float
t_curr: float
overtime_behavior: str
act_space_shape: Tuple[
int,
]
act_space_shape: Tuple[int,]
act_space_flat_dim: int

def __init__(
Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/sampling/parallel_rollout_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def sample(
disable=(not self.show_progress_bar),
unit="steps" if self.min_steps is not None else "rollouts",
) as pb:

if self.min_steps is None:
if init_states is None and domain_params is None:
# Simply run min_rollouts times
Expand Down
5 changes: 5 additions & 0 deletions Pyrado/pyrado/sampling/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def sequence_add_init(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * (i + 1)
Expand All @@ -115,6 +116,7 @@ def sequence_rec_double(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * 2.0
Expand All @@ -131,6 +133,7 @@ def sequence_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
Expand All @@ -147,6 +150,7 @@ def sequence_rec_sqrt(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[i - 1, :] * np.sqrt(i + 1) # i+1 because sqrt(1) = 1
Expand All @@ -163,6 +167,7 @@ def sequence_nlog2(x_init, iter, dtype=int):
:param dtype: data type to cast to (either int of float)
:return: element at the given iteration and array of the whole sequence
"""

# non-exponential growth
def iter_function(x_seq, i, x_init):
return x_seq[0, :] * i * np.log2(i + 2) # i+2 because log2(1) = 0 and log2(2) = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def check_E_n_Jhat(th_n_opt, n):
b_Jhat_n_hist = np.empty((num_samples, num_iter))

for s in range(num_samples):

for n in range(1, num_iter + 1):
n_V = np.random.binomial(n, psi) # perform n Bernoulli trials
n_M = n - n_V
Expand Down
1 change: 0 additions & 1 deletion Pyrado/scripts/hyperparam_optimization/hopt_qq-su_ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def train_and_eval(trial: optuna.Trial, study_dir: str, seed: int):


if __name__ == "__main__":

# Parse command line arguments
args = get_argparser().parse_args()

Expand Down
2 changes: 0 additions & 2 deletions setup_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def members(ml):
with tarfile.open(tf.name) as tar:

def is_within_directory(directory, target):

abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

Expand All @@ -323,7 +322,6 @@ def is_within_directory(directory, target):
return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):

for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
Expand Down
Loading