From 15137bd2bcd897ee2ca6f8a831ec9d136ff7b3fb Mon Sep 17 00:00:00 2001 From: epignatelli Date: Fri, 12 Jan 2024 16:25:36 +0000 Subject: [PATCH] feat: Timestep now supports indexing --- helx/_version.py | 2 +- helx/base/mdp.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/helx/_version.py b/helx/_version.py index 6c5cb36..309a43e 100644 --- a/helx/_version.py +++ b/helx/_version.py @@ -1,4 +1,4 @@ # file generated by setuptools_scm # don't change, don't track in version control -__version__ = "1.1.4" +__version__ = "1.1.5" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/helx/base/mdp.py b/helx/base/mdp.py index 52c6d90..0f9fa6c 100644 --- a/helx/base/mdp.py +++ b/helx/base/mdp.py @@ -19,6 +19,7 @@ from jax import Array import jax.numpy as jnp +import jax.tree_util as jtu from flax import struct @@ -42,3 +43,9 @@ class Timestep(struct.PyTreeNode): """The true state of the MDP, $s_t$ before taking action `action`""" info: Dict[str, Any] = struct.field(default_factory=dict) """Additional information about the environment. Useful for accumulations (e.g. returns)""" + + def __getitem__(self, key: Any) -> Timestep: + return jtu.tree_map(lambda x: x[key], self) + + def __setitem__(self, key: Any, value: Any) -> Timestep: + return jtu.tree_map(lambda x: x.at[key].set(value), self)