From 88857bc35f546c41e5164577c3d4e85b5a334949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CEduardo?= Date: Sun, 13 Aug 2023 09:55:03 +0100 Subject: [PATCH] feat: add navix wrapper --- helx/environment/interop.py | 13 +-- helx/environment/navix.py | 185 ++++++++++++++++++++---------------- 2 files changed, 109 insertions(+), 89 deletions(-) diff --git a/helx/environment/interop.py b/helx/environment/interop.py index 7d46a18..aa8a974 100644 --- a/helx/environment/interop.py +++ b/helx/environment/interop.py @@ -22,6 +22,7 @@ import gymnasium.core from gymnax.environments.environment import Environment as GymnaxEnvironment, EnvParams import brax.envs +import navix as nx from .environment import EnvironmentWrapper from .bsuite import BsuiteWrapper @@ -30,7 +31,7 @@ from .gymnasium import GymnasiumWrapper from .gymnax import GymnaxWrapper from .brax import BraxWrapper -# from .navix import NavixWrapper +from .navix import NavixWrapper @overload @@ -68,9 +69,9 @@ def to_helx(env: brax.envs.Env) -> BraxWrapper: ... -# @overload -# def to_helx(env: nx.environments.Environment) -> NavixWrapper: -# ... +@overload +def to_helx(env: nx.environments.Environment) -> NavixWrapper: + ... def to_helx(env: Any) -> EnvironmentWrapper: @@ -99,8 +100,8 @@ def to_helx(env: Any) -> EnvironmentWrapper: return GymnaxWrapper.wraps(env) elif isinstance(env_for_type, brax.envs.Env): return BraxWrapper.wraps(env) - # elif isinstance(env_for_type, nx.environments.Environment): - # return NavixWrapper.wraps(env) + elif isinstance(env_for_type, nx.environments.Environment): + return NavixWrapper.wraps(env) else: raise TypeError( f"Environment type {type(env)} is not supported. " diff --git a/helx/environment/navix.py b/helx/environment/navix.py index d8b7363..7bbf46f 100644 --- a/helx/environment/navix.py +++ b/helx/environment/navix.py @@ -1,83 +1,102 @@ -# # Copyright 2023 The Helx Authors. 
-# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -# from __future__ import annotations -# from typing import overload - -# import jax -# import jax.numpy as jnp -# import navix as nx -# from jax.random import KeyArray - -# from ..spaces import Continuous, Discrete, Space -# from ..mdp import StepType, Timestep -# from .environment import EnvironmentWrapper - - -# @overload -# def to_helx(space: nx.spaces.Discrete) -> Discrete: -# ... - - -# @overload -# def to_helx(space: nx.spaces.Continuous) -> Continuous: -# ... - - -# @overload -# def to_helx(space: nx.spaces.Space) -> Space: -# ... 
- - -# def to_helx(space: nx.spaces.Space) -> Space: -# if isinstance(space, nx.spaces.Discrete): -# return Discrete(space.maximum.item(), shape=space.shape, dtype=space.dtype) -# elif isinstance(space, nx.spaces.Continuous): -# return Continuous( -# shape=space.shape, -# minimum=space.minimum.item(), -# maximum=space.maximum.item(), -# ) -# else: -# raise NotImplementedError( -# "Cannot convert dm_env space of type {}".format(type(space)) -# ) - - -# class NavixWrapper(EnvironmentWrapper): -# """Static class to convert between Gymnax environments and helx environments.""" -# env: nx.environments.Environment - -# @classmethod -# def wraps(cls, env: nx.environments.Environment) -> NavixWrapper: -# return cls( -# env=env, -# observation_space=to_helx(env.observation_space), -# action_space=to_helx(env.action_space), -# reward_space=Continuous(), -# ) - -# def reset(self, key: KeyArray) -> Timestep: -# timestep = self.env.reset(key) -# return Timestep( -# t=jnp.asarray(0), -# observation=timestep.observation, -# reward=timestep.reward, -# step_type=StepType.TRANSITION, -# action=jnp.asarray(-1), -# state=timestep.state, -# info=timestep.info -# ) - -# def _step(self, key: KeyArray, timestep: Timestep, action: jax.Array) -> Timestep: -# self.env.step(timestep) +# Copyright 2023 The Helx Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations
+from typing import overload
+
+import jax
+import jax.numpy as jnp
+import navix as nx
+from jax.random import KeyArray
+
+from ..spaces import Continuous, Discrete, Space
+from ..mdp import TRANSITION, Timestep
+from .environment import EnvironmentWrapper
+
+
+@overload
+def to_helx(space: nx.spaces.Discrete) -> Discrete:
+    ...
+
+
+@overload
+def to_helx(space: nx.spaces.Continuous) -> Continuous:
+    ...
+
+
+@overload
+def to_helx(space: nx.spaces.Space) -> Space:
+    ...
+
+
+def to_helx(space: nx.spaces.Space) -> Space:
+    """Convert a navix space into the corresponding helx space."""
+    if isinstance(space, nx.spaces.Discrete):
+        return Discrete(space.maximum.item(), shape=space.shape, dtype=space.dtype)
+    elif isinstance(space, nx.spaces.Continuous):
+        return Continuous(
+            shape=space.shape,
+            minimum=space.minimum.item(),
+            maximum=space.maximum.item(),
+        )
+    else:
+        raise NotImplementedError(f"Cannot convert navix space of type {type(space)}")
+
+
+class NavixWrapper(EnvironmentWrapper):
+    """Wrapper exposing a navix environment through the helx environment API."""
+    env: nx.environments.Environment
+
+    @classmethod
+    def wraps(cls, env: nx.environments.Environment) -> NavixWrapper:
+        return cls(
+            env=env,
+            observation_space=to_helx(env.observation_space),
+            action_space=to_helx(env.action_space),
+            reward_space=Continuous(),  # default-constructed; not read from env
+        )
+
+    def reset(self, key: KeyArray) -> Timestep:
+        timestep = self.env.reset(key)
+        return Timestep(
+            t=jnp.asarray(0),
+            observation=timestep.observation,
+            reward=timestep.reward,
+            step_type=TRANSITION,  # fixed value; navix's step_type is not forwarded
+            action=jnp.asarray(-1),  # presumably a "no action yet" sentinel -- confirm
+            state=timestep.state,
+            info=timestep.info
+        )
+
+    def _step(self, key: KeyArray, timestep: Timestep, action: jax.Array) -> Timestep:
+        """Step the wrapped navix env (``key`` is unused); return a helx Timestep."""
+        current_step = nx.environments.environment.Timestep(
+            t=timestep.t,
+            observation=timestep.observation,
+            reward=timestep.reward,
+            step_type=timestep.step_type,
+            action=action,
+            state=timestep.state,
+            info=timestep.info
+        )
+        next_step = self.env.step(current_step, action)
+        return Timestep(
+            t=next_step.t,
+            observation=next_step.observation,
+            reward=next_step.reward,
+            step_type=next_step.step_type,
+            action=action,
+            state=next_step.state,
+            info=next_step.info
+        )
+