Skip to content

Commit

Permalink
[BugFix] fix reward dim problem of shared baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
cbhua committed Oct 18, 2023
1 parent e5f9df1 commit b952d0f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rl4co/models/rl/reinforce/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def eval(self, td, reward, env=None):
class SharedBaseline(REINFORCEBaseline):
"""Shared baseline: return mean of reward as baseline"""

def eval(self, td, reward, env=None, on_dim=1): # e.g. [batch, pomo, ...]
return reward.mean(dim=on_dim, keepdims=True), 0
def eval(self, td, reward, env=None): # e.g. [batch, pomo, ...]
return reward.mean(), 0

Check warning on line 57 in rl4co/models/rl/reinforce/baselines.py

View check run for this annotation

Codecov / codecov/patch

rl4co/models/rl/reinforce/baselines.py#L57

Added line #L57 was not covered by tests


class ExponentialBaseline(REINFORCEBaseline):
Expand Down

0 comments on commit b952d0f

Please sign in to comment.