Skip to content

Commit

Permalink
[Feat] +automatic multistart
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Nov 30, 2023
1 parent 0aaed80 commit ba7b8c4
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion rl4co/models/zoo/active_search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def training_step(self, batch, batch_idx):
td_init = self.env.reset(batch)
n_aug, n_start, n_runs = (
self.augmentation.num_augment,
get_num_starts(td_init),
get_num_starts(td_init, self.env.name),
self.hparams.num_parallel_runs,
)
td_init = self.augmentation(td_init)
Expand Down
2 changes: 1 addition & 1 deletion rl4co/models/zoo/eas/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def training_step(self, batch, batch_idx):
td_init = self.env.reset(batch)
n_aug, n_start, n_runs = (
self.augmentation.num_augment,
get_num_starts(td_init),
get_num_starts(td_init, self.env.name),
self.hparams.num_parallel_runs,
)
td_init = self.augmentation(td_init)
Expand Down
2 changes: 1 addition & 1 deletion rl4co/models/zoo/pomo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def shared_step(
):
td = self.env.reset(batch)
n_aug, n_start = self.num_augment, self.num_starts
n_start = get_num_starts(td) if n_start is None else n_start
n_start = get_num_starts(td, self.env.name) if n_start is None else n_start

# During training, we do not augment the data
if phase == "train":
Expand Down
2 changes: 1 addition & 1 deletion rl4co/models/zoo/symnco/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def shared_step(
):
td = self.env.reset(batch)
n_aug, n_start = self.num_augment, self.num_starts
n_start = get_num_starts(td) if n_start is None else n_start
n_start = get_num_starts(td, self.env.name) if n_start is None else n_start

# Symmetric augmentation
if n_aug > 1:
Expand Down

0 comments on commit ba7b8c4

Please sign in to comment.