diff --git a/helx/environment/gym.py b/helx/environment/gym.py
index df32f69..34f6e62 100644
--- a/helx/environment/gym.py
+++ b/helx/environment/gym.py
@@ -52,7 +52,12 @@ def state(self) -> Array:
         return self._current_observation
 
     def reset(self, seed: int | None = None) -> Timestep:
-        obs, _ = self._env.reset(seed=seed)
+        try:
+            obs, _ = self._env.reset(seed=seed)
+        except TypeError:
+            # TODO(epignatelli): remove try/except when gym3 is updated.
+            # see: https://github.com/openai/gym3/issues/8
+            obs, _ = self._env.reset()
         self._current_observation = jnp.asarray(obs)
         return Timestep(obs, None, StepType.TRANSITION)
 
@@ -60,6 +65,7 @@ def step(self, action: Action) -> Timestep:
         action_ = np.asarray(action)
         next_step = self._env.step(action_)
         self._current_observation = jnp.asarray(next_step[0])
+        gym.core.ObsType
         return Timestep.from_gym(next_step)
 
     def seed(self, seed: int) -> None: