From 8497d530304c1e9365f42f5e016641a5a227c592 Mon Sep 17 00:00:00 2001 From: Rohan Patil <31570118+bridgesign@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:05:47 -0700 Subject: [PATCH 1/5] VectorEnv gym visulization correction This is a solution in context of https://github.com/google/brax/issues/535 --- brax/envs/wrappers/gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index 13ba55cb..b129af0b 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -149,6 +149,6 @@ def render(self, mode='human'): sys, state = self._env.sys, self._state if state is None: raise RuntimeError('must call reset or step before rendering') - return image.render_array(sys, state.pipeline_state.take(0), 256, 256) + return np.stack([image.render_array(sys, state.take(i).pipeline_state, 256, 256) for i in range(self.num_envs)]) else: return super().render(mode=mode) # just raise an exception From 18a83a7e6d49fe3faf47e34ec2243e4359e6410a Mon Sep 17 00:00:00 2001 From: bridgesign Date: Sat, 2 Nov 2024 07:32:45 +0000 Subject: [PATCH 2/5] Added tests. Allow width-height configuration --- brax/envs/wrappers/gym.py | 8 ++++---- brax/envs/wrappers/gym_test.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index b129af0b..fadef799 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -77,12 +77,12 @@ def step(self, action): def seed(self, seed: int = 0): self._key = jax.random.PRNGKey(seed) - def render(self, mode='human'): + def render(self, mode='human', width=256, height=256): if mode == 'rgb_array': sys, state = self._env.sys, self._state if state is None: raise RuntimeError('must call reset or step before rendering') - return image.render_array(sys, state.pipeline_state, 256, 256) + return image.render_array(sys, state.pipeline_state, width=width, height=height) else: return super().render(mode=mode) # just raise an exception @@ -144,11 +144,11 @@ def step(self, action): def seed(self, seed: int = 0): self._key = jax.random.PRNGKey(seed) - def render(self, mode='human'): + def render(self, mode='human', width=256, height=256): if mode == 'rgb_array': sys, state = self._env.sys, self._state if state is None: raise RuntimeError('must call reset or step before rendering') - return np.stack([image.render_array(sys, state.take(i).pipeline_state, 256, 256) for i in range(self.num_envs)]) + return np.stack([image.render_array(sys, state.take(i).pipeline_state, width=width, height=height) for i in range(self.num_envs)]) else: return super().render(mode=mode) # just raise an exception diff --git a/brax/envs/wrappers/gym_test.py b/brax/envs/wrappers/gym_test.py index 6e668847..21c06f43 100644 --- a/brax/envs/wrappers/gym_test.py +++ b/brax/envs/wrappers/gym_test.py @@ -43,7 +43,22 @@ def test_vector_action_space(self): np.testing.assert_array_equal( env.action_space.high, np.tile(base_env.sys.actuator.ctrl_range[:, 1], [256, 1])) - + + def test_render(self): + """Tests rendering in the GymWrapper.""" + base_env = envs.create('pusher') + env = gym.GymWrapper(base_env) + env.reset() + img = env.render(mode='rgb_array', width=250, height=236) + self.assertEqual(img.shape, (236, 250, 3)) + + def test_vector_render(self): + """Tests rendering in the VectorGymWrapper.""" + base_env = envs.create('pusher') + env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=256)) + env.reset() + img = env.render(mode='rgb_array', width=250, height=236) + self.assertEqual(img.shape, (256, 236, 250, 3)) if __name__ == '__main__': absltest.main() From 2928965c5e58293ad75e04868c125f0106623a2b Mon Sep 17 00:00:00 2001 From: bridgesign Date: Tue, 26 Nov 2024 19:51:40 -0800 Subject: [PATCH 3/5] Reduce batch size --- brax/envs/wrappers/gym_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brax/envs/wrappers/gym_test.py b/brax/envs/wrappers/gym_test.py index 21c06f43..2a1e01ff 100644 --- a/brax/envs/wrappers/gym_test.py +++ b/brax/envs/wrappers/gym_test.py @@ -55,10 +55,10 @@ def test_render(self): def test_vector_render(self): """Tests rendering in the VectorGymWrapper.""" base_env = envs.create('pusher') - env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=256)) + env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=16)) env.reset() img = env.render(mode='rgb_array', width=250, height=236) - self.assertEqual(img.shape, (256, 236, 250, 3)) + self.assertEqual(img.shape, (16, 236, 250, 3)) if __name__ == '__main__': absltest.main() From 9cd206a6875d08333de6333bfa480339a4646208 Mon Sep 17 00:00:00 2001 From: bridgesign Date: Wed, 27 Nov 2024 16:04:15 -0800 Subject: [PATCH 4/5] Reduce size of test --- brax/envs/wrappers/gym_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/brax/envs/wrappers/gym_test.py b/brax/envs/wrappers/gym_test.py index 2a1e01ff..29d286a0 100644 --- a/brax/envs/wrappers/gym_test.py +++ b/brax/envs/wrappers/gym_test.py @@ -55,10 +55,10 @@ def test_render(self): def test_vector_render(self): """Tests rendering in the VectorGymWrapper.""" base_env = envs.create('pusher') - env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=16)) + env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=2)) env.reset() - img = env.render(mode='rgb_array', width=250, height=236) - self.assertEqual(img.shape, (16, 236, 250, 3)) + img = env.render(mode='rgb_array', width=128, height=128) + self.assertEqual(img.shape, (2, 128, 128, 3)) if __name__ == '__main__': absltest.main() From bb1fcf79471e2d334411e0fb6816f36672db1bec Mon Sep 17 00:00:00 2001 From: bridgesign Date: Sun, 1 Dec 2024 10:34:01 -0800 Subject: [PATCH 5/5] Change render to list pipeline states before conversion --- brax/envs/wrappers/gym.py | 3 ++- brax/envs/wrappers/gym_test.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index fadef799..8d159b0c 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -149,6 +149,7 @@ def render(self, mode='human', width=256, height=256): sys, state = self._env.sys, self._state if state is None: raise RuntimeError('must call reset or step before rendering') - return np.stack([image.render_array(sys, state.take(i).pipeline_state, width=width, height=height) for i in range(self.num_envs)]) + state_list = [state.take(i).pipeline_state for i in range(self.num_envs)] + return np.stack(image.render_array(sys, state_list, width=width, height=height)) else: return super().render(mode=mode) # just raise an exception diff --git a/brax/envs/wrappers/gym_test.py b/brax/envs/wrappers/gym_test.py index 29d286a0..f7e0cb9f 100644 --- a/brax/envs/wrappers/gym_test.py +++ b/brax/envs/wrappers/gym_test.py @@ -48,7 +48,7 @@ def test_render(self): """Tests rendering in the GymWrapper.""" base_env = envs.create('pusher') env = gym.GymWrapper(base_env) - env.reset() + _ = env.reset() img = env.render(mode='rgb_array', width=250, height=236) self.assertEqual(img.shape, (236, 250, 3)) @@ -56,7 +56,7 @@ def test_vector_render(self): """Tests rendering in the VectorGymWrapper.""" base_env = envs.create('pusher') env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=2)) - env.reset() + _ = env.reset() img = env.render(mode='rgb_array', width=128, height=128) self.assertEqual(img.shape, (2, 128, 128, 3))