Skip to content

Commit

Permalink
rearrange networks
Browse files Browse the repository at this point in the history
  • Loading branch information
epignatelli committed Jan 12, 2023
1 parent 99bc8e1 commit cb413a5
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 20 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import cast

import gymnasium
import bsuite
import gym
import optax
from absl import app, flags, logging

Expand All @@ -24,7 +23,7 @@ def main(argv):
logging.info("Starting")

# environment
env = bsuite.load_from_id("catch/0")
env = gym.make("procgen:procgen-coinrun-v0", max_episode_steps=100)
env = helx.environment.make_from(env)

# optimiser
Expand Down Expand Up @@ -59,7 +58,7 @@ def main(argv):
network=network, optimiser=optimiser, hparams=hparams, seed=0
)

helx.experiment.run(agent, env, 100)
helx.experiment.run(agent, env, 2)


if __name__ == "__main__":
Expand Down
24 changes: 12 additions & 12 deletions helx/networks/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ def __call__(self, x: Array, *args, **kwargs) -> Array:
return x


class Temperature(nn.Module):
    """Flax module that holds a single scalar temperature parameter.

    The parameter is registered under the name ``log_temperature`` with
    shape ``()`` and starts at ``log(initial_temperature)``. Calling the
    module returns the exponential of that parameter.
    """

    # Value whose logarithm is used as the parameter's initial value.
    initial_temperature: float = 1.0

    @nn.compact
    def __call__(self, observation: Array, action: Array) -> Array:
        """Return ``exp(log_temperature)``.

        Args:
            observation: Not read by this module.
            action: Not read by this module.

        Returns:
            The exponential of the ``log_temperature`` parameter.
        """

        def _initial_log_temperature(key):
            # ``key`` is whatever ``self.param`` hands to the initialiser;
            # the starting value is a constant, so it goes unused.
            start = jnp.log(self.initial_temperature)
            return jnp.full((), start)

        log_value = self.param(
            "log_temperature",
            init_fn=_initial_log_temperature,
        )
        return jnp.exp(log_value)


class AgentNetwork(nn.Module):
"""Defines the network architecture of an agent, and can be used as it is.
Args:
Expand Down Expand Up @@ -291,15 +303,3 @@ def extra(
action,
)
)


class Temperature(nn.Module):
    """Flax module that holds a single scalar temperature parameter.

    The parameter is registered under the name ``log_temperature`` with
    shape ``()`` and starts at ``log(initial_temperature)``. Calling the
    module returns the exponential of that parameter.
    """

    # Value whose logarithm is used as the parameter's initial value.
    initial_temperature: float = 1.0

    @nn.compact
    def __call__(self, observation: Array, action: Array) -> Array:
        """Return ``exp(log_temperature)``.

        Args:
            observation: Not read by this module.
            action: Not read by this module.

        Returns:
            The exponential of the ``log_temperature`` parameter.
        """

        def _initial_log_temperature(key):
            # ``key`` is whatever ``self.param`` hands to the initialiser;
            # the starting value is a constant, so it goes unused.
            start = jnp.log(self.initial_temperature)
            return jnp.full((), start)

        log_value = self.param(
            "log_temperature",
            init_fn=_initial_log_temperature,
        )
        return jnp.exp(log_value)
9 changes: 5 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
black
flake8
typing-extensions
pylint
pytest
absl-py
jax
jax_dataclasses
chex
optax
rlax
flax
dm_env
wandb
absl-py
gymnasium[all]>=0.26 # 0.26 introduced the termination/truncation API
gym[all]>=0.26 # 0.26 introduced the termination/truncation API
mujoco
procgen
bsuite
typing-extensions
pylint
wandb

0 comments on commit cb413a5

Please sign in to comment.