Skip to content

Commit

Permalink
polish
Browse files Browse the repository at this point in the history
  • Loading branch information
rongkunxue committed Aug 15, 2024
1 parent 0893f0f commit 35afe58
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions grl/algorithms/gmpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,11 +701,22 @@ def evaluate(model, train_epoch, repeat=1):
evaluation_results = dict()

def policy(obs: np.ndarray) -> np.ndarray:
obs = torch.tensor(
obs,
dtype=torch.float32,
device=config.model.GPPolicy.device,
).unsqueeze(0)
if isinstance(obs, torch.Tensor):
obs = torch.tensor(
obs,
dtype=torch.float32,
device=config.model.GPPolicy.device,
).unsqueeze(0)
elif isinstance(obs, dict):
for key in obs:
obs[key] = torch.tensor(
obs[key],
dtype=torch.float32,
device=config.model.GPPolicy.device
).unsqueeze(0)
if obs[key].dim() == 1 and obs[key].shape[0] == 1:
obs[key] = obs[key].unsqueeze(1)
obs = TensorDict(obs, batch_size=[1])
action = (
model.sample(
condition=obs,
Expand Down

0 comments on commit 35afe58

Please sign in to comment.