diff --git a/assume/strategies/learning_strategies.py b/assume/strategies/learning_strategies.py index 230c7519..a1cc94e4 100644 --- a/assume/strategies/learning_strategies.py +++ b/assume/strategies/learning_strategies.py @@ -424,7 +424,7 @@ def load_actor_params(self, load_path): """ directory = f"{load_path}/actors/actor_{self.unit_id}.pt" - params = th.load(directory) + params = th.load(directory, map_location=self.device) self.actor = Actor(self.obs_dim, self.act_dim, self.float_type) self.actor.load_state_dict(params["actor"])