From b952d0f569794a325460245aff4edba5941d4cb7 Mon Sep 17 00:00:00 2001 From: Chuanbo Hua Date: Thu, 19 Oct 2023 01:27:48 +0900 Subject: [PATCH] [BugFix] fix reward dim problem of shared baseline --- rl4co/models/rl/reinforce/baselines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rl4co/models/rl/reinforce/baselines.py b/rl4co/models/rl/reinforce/baselines.py index 8117edd1..467b9a62 100644 --- a/rl4co/models/rl/reinforce/baselines.py +++ b/rl4co/models/rl/reinforce/baselines.py @@ -53,8 +53,8 @@ def eval(self, td, reward, env=None): class SharedBaseline(REINFORCEBaseline): """Shared baseline: return mean of reward as baseline""" - def eval(self, td, reward, env=None, on_dim=1): # e.g. [batch, pomo, ...] - return reward.mean(dim=on_dim, keepdims=True), 0 + def eval(self, td, reward, env=None): # e.g. [batch, pomo, ...] + return reward.mean(), 0 class ExponentialBaseline(REINFORCEBaseline):