diff --git a/README.md b/README.md index fbe28d0e..035bcc49 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ If you would like to reference Brax in a publication, please use: author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem}, title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation}, url = {http://github.com/google/brax}, - version = {0.9.4}, + version = {0.10.0}, year = {2021}, } ``` diff --git a/brax/__init__.py b/brax/__init__.py index facac60c..e2084154 100644 --- a/brax/__init__.py +++ b/brax/__init__.py @@ -14,7 +14,7 @@ """Import top-level classes and functions here for encapsulation/clarity.""" -__version__ = '0.9.4' +__version__ = '0.10.0' from brax.base import Motion from brax.base import State diff --git a/brax/contact.py b/brax/contact.py index 757824b1..6afe2362 100644 --- a/brax/contact.py +++ b/brax/contact.py @@ -35,8 +35,7 @@ def get(sys: System, x: Transform) -> Optional[Contact]: Returns: Contact pytree """ - # TODO: use mjx.ncon. - ncon = mjx._src.collision_driver.ncon(sys) + ncon = mjx.ncon(sys) if not ncon: return None diff --git a/brax/io/json.py b/brax/io/json.py index 2b90c5f8..45598479 100644 --- a/brax/io/json.py +++ b/brax/io/json.py @@ -156,7 +156,7 @@ def dumps(sys: System, states: List[State]) -> Text: for id_ in range(sys.ngeom): link_idx = sys.geom_bodyid[id_] - 1 - rgba = sys.mj_model.geom_rgba[id_] + rgba = sys.geom_rgba[id_] if (rgba == [0.5, 0.5, 0.5, 1.0]).all(): # convert the default mjcf color to brax default color rgba = np.array([0.4, 0.33, 0.26, 1.0]) @@ -171,8 +171,7 @@ def dumps(sys: System, states: List[State]) -> Text: } if geom['name'] in ('Mesh', 'Box'): - # TODO: use sys.geom_dataid. - vert, face = _get_mesh(sys.mj_model, sys.mj_model.geom_dataid[id_]) + vert, face = _get_mesh(sys.mj_model, sys.geom_dataid[id_]) geom['vert'] = vert geom['face'] = face diff --git a/brax/io/mjcf.py b/brax/io/mjcf.py index fdfe16df..68512f99 100644 --- a/brax/io/mjcf.py +++ b/brax/io/mjcf.py @@ -385,25 +385,32 @@ def load_model(mj: mujoco.MjModel) -> System: ) # create actuators + # TODO: swap brax actuation for mjx actuation model. ctrl_range = mj.actuator_ctrlrange ctrl_range[~(mj.actuator_ctrllimited == 1), :] = np.array([-np.inf, np.inf]) force_range = mj.actuator_forcerange force_range[~(mj.actuator_forcelimited == 1), :] = np.array([-np.inf, np.inf]) - q_id = np.array([mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]]) - qd_id = np.array([mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]]) bias_q = mj.actuator_biasprm[:, 1] * (mj.actuator_biastype != 0) bias_qd = mj.actuator_biasprm[:, 2] * (mj.actuator_biastype != 0) + # mask actuators since brax only supports joint transmission types + act_mask = mj.actuator_trntype == mujoco.mjtTrn.mjTRN_JOINT + trnid = mj.actuator_trnid[act_mask, 0].astype(np.uint32) + q_id = mj.jnt_qposadr[trnid] + qd_id = mj.jnt_dofadr[trnid] + act_kwargs = { + 'gain': mj.actuator_gainprm[:, 0], + 'gear': mj.actuator_gear[:, 0], + 'ctrl_range': ctrl_range, + 'force_range': force_range, + 'bias_q': bias_q, + 'bias_qd': bias_qd, + } + act_kwargs = jax.tree_map(lambda x: x[act_mask], act_kwargs) - # TODO: remove brax actuators actuator = Actuator( # pytype: disable=wrong-arg-types q_id=q_id, qd_id=qd_id, - gain=mj.actuator_gainprm[:, 0], - gear=mj.actuator_gear[:, 0], - ctrl_range=ctrl_range, - force_range=force_range, - bias_q=bias_q, - bias_qd=bias_qd, + **act_kwargs ) # create non-pytree params. these do not live on device directly, and they diff --git a/brax/io/mjcf_test.py b/brax/io/mjcf_test.py index b4f0f181..a70caa4c 100644 --- a/brax/io/mjcf_test.py +++ b/brax/io/mjcf_test.py @@ -18,6 +18,7 @@ from absl.testing import absltest from brax import test_utils from brax.io import mjcf +import mujoco import numpy as np assert_almost_equal = np.testing.assert_array_almost_equal @@ -131,6 +132,14 @@ def test_world_fromto(self): sys = test_utils.load_fixture('world_fromto.xml') mjcf.validate_model(sys.mj_model) + def test_loads_different_transmission(self): + """Tests that the brax model loads with different transmission types.""" + mj = test_utils.load_fixture_mujoco('ant.xml') + mj.actuator_trntype[0] = mujoco.mjtTrn.mjTRN_SITE + mjcf.load_model(mj) # loads without raising an error + + with self.assertRaisesRegex(NotImplementedError, 'transmission types'): + mjcf.validate_model(mj) # raises an error if __name__ == '__main__': absltest.main() diff --git a/docs/release-notes/next-release.md b/docs/release-notes/next-release.md index 14069ed7..ca4ecb57 100644 --- a/docs/release-notes/next-release.md +++ b/docs/release-notes/next-release.md @@ -1,5 +1 @@ # Brax Release Notes - -* Rebase brax System and State onto mjx.Model and mjx.Data. -* Use the MuJoCo renderer instead of pytinyrenderer for brax.io.image. -* Separate validation logic from the model loader in brax.io.mjcf. diff --git a/docs/release-notes/v0.10.0.md b/docs/release-notes/v0.10.0.md new file mode 100644 index 00000000..971caa4a --- /dev/null +++ b/docs/release-notes/v0.10.0.md @@ -0,0 +1,9 @@ +# Brax v0.10.0 Release Notes + +This minor release makes several changes to the brax API, such that [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) data structures are the core data structures used in brax. This allows for more seamless model loading from `MuJoCo` XMLs, and allows for running `MJX` physics more seamlessly in brax. + +* Rebase brax `System` and `State` onto `mjx.Model` and `mjx.Data`. +* Separate validation logic from the model loading logic in `brax.io.mjcf`. This allows users to load an [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) model in brax, without hitting validation errors for other physics backends like `positional` and `spring`. +* Remove `System.geoms`, since `brax.System` inherits from `mjx.Model` and all geom information is available in `mjx.Model`. We also update the brax viewer to work with this new schema. +* Delete the brax contact library and use the contact library from `MJX`. +* Use the MuJoCo renderer instead of pytinyrenderer for `brax.io.image`. diff --git a/setup.py b/setup.py index ba333181..2bdd1613 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ setup( name="brax", - version="0.9.4", + version="0.10.0", description="A differentiable physics engine written in JAX.", author="Brax Authors", author_email="no-reply@google.com",