diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index 13ba55cb..8d159b0c 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -77,12 +77,12 @@ def step(self, action): def seed(self, seed: int = 0): self._key = jax.random.PRNGKey(seed) - def render(self, mode='human'): + def render(self, mode='human', width=256, height=256): if mode == 'rgb_array': sys, state = self._env.sys, self._state if state is None: raise RuntimeError('must call reset or step before rendering') - return image.render_array(sys, state.pipeline_state, 256, 256) + return image.render_array(sys, state.pipeline_state, width=width, height=height) else: return super().render(mode=mode) # just raise an exception @@ -144,11 +144,12 @@ def step(self, action): def seed(self, seed: int = 0): self._key = jax.random.PRNGKey(seed) - def render(self, mode='human'): + def render(self, mode='human', width=256, height=256): if mode == 'rgb_array': sys, state = self._env.sys, self._state if state is None: raise RuntimeError('must call reset or step before rendering') - return image.render_array(sys, state.pipeline_state.take(0), 256, 256) + state_list = [state.take(i).pipeline_state for i in range(self.num_envs)] + return np.stack(image.render_array(sys, state_list, width=width, height=height)) else: return super().render(mode=mode) # just raise an exception diff --git a/brax/envs/wrappers/gym_test.py b/brax/envs/wrappers/gym_test.py index 6e668847..f7e0cb9f 100644 --- a/brax/envs/wrappers/gym_test.py +++ b/brax/envs/wrappers/gym_test.py @@ -43,7 +43,22 @@ def test_vector_action_space(self): np.testing.assert_array_equal( env.action_space.high, np.tile(base_env.sys.actuator.ctrl_range[:, 1], [256, 1])) - + + def test_render(self): + """Tests rendering in the GymWrapper.""" + base_env = envs.create('pusher') + env = gym.GymWrapper(base_env) + _ = env.reset() + img = env.render(mode='rgb_array', width=250, height=236) + self.assertEqual(img.shape, (236, 250, 3)) + + def test_vector_render(self): + """Tests rendering in the VectorGymWrapper.""" + base_env = envs.create('pusher') + env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=2)) + _ = env.reset() + img = env.render(mode='rgb_array', width=128, height=128) + self.assertEqual(img.shape, (2, 128, 128, 3)) if __name__ == '__main__': absltest.main()