diff --git a/rl4co/models/rl/reinforce/baselines.py b/rl4co/models/rl/reinforce/baselines.py index 8117edd1..467b9a62 100644 --- a/rl4co/models/rl/reinforce/baselines.py +++ b/rl4co/models/rl/reinforce/baselines.py @@ -53,8 +53,8 @@ def eval(self, td, reward, env=None): class SharedBaseline(REINFORCEBaseline): """Shared baseline: return mean of reward as baseline""" - def eval(self, td, reward, env=None, on_dim=1): # e.g. [batch, pomo, ...] - return reward.mean(dim=on_dim, keepdims=True), 0 + def eval(self, td, reward, env=None): # e.g. [batch, pomo, ...] + return reward.mean(), 0 class ExponentialBaseline(REINFORCEBaseline):