From ba7b8c4255c6938ffac84483b8df1c0a4984d54e Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Thu, 30 Nov 2023 21:20:31 +0900 Subject: [PATCH] [Feat] +automatic multistart --- rl4co/models/zoo/active_search/search.py | 2 +- rl4co/models/zoo/eas/search.py | 2 +- rl4co/models/zoo/pomo/model.py | 2 +- rl4co/models/zoo/symnco/model.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rl4co/models/zoo/active_search/search.py b/rl4co/models/zoo/active_search/search.py index 1d13ef70..f5cc4ef3 100644 --- a/rl4co/models/zoo/active_search/search.py +++ b/rl4co/models/zoo/active_search/search.py @@ -112,7 +112,7 @@ def training_step(self, batch, batch_idx): td_init = self.env.reset(batch) n_aug, n_start, n_runs = ( self.augmentation.num_augment, - get_num_starts(td_init), + get_num_starts(td_init, self.env.name), self.hparams.num_parallel_runs, ) td_init = self.augmentation(td_init) diff --git a/rl4co/models/zoo/eas/search.py b/rl4co/models/zoo/eas/search.py index 1aeab867..7fd33ccb 100644 --- a/rl4co/models/zoo/eas/search.py +++ b/rl4co/models/zoo/eas/search.py @@ -142,7 +142,7 @@ def training_step(self, batch, batch_idx): td_init = self.env.reset(batch) n_aug, n_start, n_runs = ( self.augmentation.num_augment, - get_num_starts(td_init), + get_num_starts(td_init, self.env.name), self.hparams.num_parallel_runs, ) td_init = self.augmentation(td_init) diff --git a/rl4co/models/zoo/pomo/model.py b/rl4co/models/zoo/pomo/model.py index 1f2383d7..f55f5c43 100644 --- a/rl4co/models/zoo/pomo/model.py +++ b/rl4co/models/zoo/pomo/model.py @@ -81,7 +81,7 @@ def shared_step( ): td = self.env.reset(batch) n_aug, n_start = self.num_augment, self.num_starts - n_start = get_num_starts(td) if n_start is None else n_start + n_start = get_num_starts(td, self.env.name) if n_start is None else n_start # During training, we do not augment the data if phase == "train": diff --git a/rl4co/models/zoo/symnco/model.py b/rl4co/models/zoo/symnco/model.py index d0639552..340122e7 100644 --- a/rl4co/models/zoo/symnco/model.py +++ b/rl4co/models/zoo/symnco/model.py @@ -71,7 +71,7 @@ def shared_step( ): td = self.env.reset(batch) n_aug, n_start = self.num_augment, self.num_starts - n_start = get_num_starts(td) if n_start is None else n_start + n_start = get_num_starts(td, self.env.name) if n_start is None else n_start # Symmetric augmentation if n_aug > 1: