From 88fc99cba0cf12efe5904c48573e29b834ec2f80 Mon Sep 17 00:00:00 2001 From: epignatelli Date: Sat, 13 Jan 2024 07:29:43 +0000 Subject: [PATCH 1/2] fix: `Environment.wraps` signature now has correct arguments --- helx/envs/environment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helx/envs/environment.py b/helx/envs/environment.py index 3e8f6bb..2ec9e47 100644 --- a/helx/envs/environment.py +++ b/helx/envs/environment.py @@ -57,8 +57,8 @@ def step( class EnvironmentWrapper(Environment): env: Any = struct.field(pytree_node=False) - @abc.abstractmethod - def wraps(self, env: Any) -> Timestep: + @abc.abstractclassmethod + def wraps(cls, env: Any) -> EnvironmentWrapper: raise NotImplementedError() def unwrapped(self) -> Any: From fab10cc761be308e70aa367db5063c16b7e5aa2d Mon Sep 17 00:00:00 2001 From: epignatelli Date: Fri, 12 Jan 2024 16:25:36 +0000 Subject: [PATCH 2/2] feat: Timestep now supports indexing --- helx/_version.py | 2 +- helx/base/mdp.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/helx/_version.py b/helx/_version.py index 6c5cb36..309a43e 100644 --- a/helx/_version.py +++ b/helx/_version.py @@ -1,4 +1,4 @@ # file generated by setuptools_scm # don't change, don't track in version control -__version__ = "1.1.4" +__version__ = "1.1.5" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/helx/base/mdp.py b/helx/base/mdp.py index 52c6d90..0f9fa6c 100644 --- a/helx/base/mdp.py +++ b/helx/base/mdp.py @@ -19,6 +19,7 @@ from jax import Array import jax.numpy as jnp +import jax.tree_util as jtu from flax import struct @@ -42,3 +43,9 @@ class Timestep(struct.PyTreeNode): """The true state of the MDP, $s_t$ before taking action `action`""" info: Dict[str, Any] = struct.field(default_factory=dict) """Additional information about the environment. Useful for accumulations (e.g. returns)""" + + def __getitem__(self, key: Any) -> Timestep: + return jtu.tree_map(lambda x: x[key], self) + + def __setitem__(self, key: Any, value: Any) -> Timestep: + return jtu.tree_map(lambda x: x.at[key].set(value), self)