diff --git a/examples/ppo_procgen_coinrun.py b/examples/ppo_minihack_keyroom.py similarity index 100% rename from examples/ppo_procgen_coinrun.py rename to examples/ppo_minihack_keyroom.py diff --git a/examples/sacd_gym_cartpole.py b/examples/sacd_procgen_coinrun.py similarity index 92% rename from examples/sacd_gym_cartpole.py rename to examples/sacd_procgen_coinrun.py index eb547f0..76dd93e 100644 --- a/examples/sacd_gym_cartpole.py +++ b/examples/sacd_procgen_coinrun.py @@ -1,7 +1,6 @@ from typing import cast -import gymnasium -import bsuite +import gym import optax from absl import app, flags, logging @@ -24,7 +23,7 @@ def main(argv): logging.info("Starting") # environment - env = bsuite.load_from_id("catch/0") + env = gym.make("procgen:procgen-coinrun-v0", max_episode_steps=100) env = helx.environment.make_from(env) # optimiser @@ -59,7 +58,7 @@ def main(argv): network=network, optimiser=optimiser, hparams=hparams, seed=0 ) - helx.experiment.run(agent, env, 100) + helx.experiment.run(agent, env, 2) if __name__ == "__main__": diff --git a/helx/networks/modules.py b/helx/networks/modules.py index 7f33cc7..5bea33c 100644 --- a/helx/networks/modules.py +++ b/helx/networks/modules.py @@ -138,6 +138,18 @@ def __call__(self, x: Array, *args, **kwargs) -> Array: return x +class Temperature(nn.Module): + initial_temperature: float = 1.0 + + @nn.compact + def __call__(self, observation: Array, action: Array) -> Array: + log_temperature = self.param( + "log_temperature", + init_fn=lambda key: jnp.full((), jnp.log(self.initial_temperature)), + ) + return jnp.exp(log_temperature) + + class AgentNetwork(nn.Module): """Defines the network architecture of an agent, and can be used as it is. Args: @@ -291,15 +303,3 @@ def extra( action, ) ) - - -class Temperature(nn.Module): - initial_temperature: float = 1.0 - - @nn.compact - def __call__(self, observation: Array, action: Array) -> Array: - log_temperature = self.param( - "log_temperature", - init_fn=lambda key: jnp.full((), jnp.log(self.initial_temperature)), - ) - return jnp.exp(log_temperature) diff --git a/requirements.txt b/requirements.txt index 537cc25..f386c92 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,9 @@ black flake8 +typing-extensions +pylint pytest +absl-py jax jax_dataclasses chex @@ -8,11 +11,9 @@ optax rlax flax dm_env -wandb -absl-py gymnasium[all]>=0.26 # 0.26 introduced the termination/truncation API gym[all]>=0.26 # 0.26 introduced the termination/truncation API mujoco +procgen bsuite -typing-extensions -pylint \ No newline at end of file +wandb \ No newline at end of file