Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605094677
Change-Id: I5983e766a55834b56ef4ab037ec4247398a25851
  • Loading branch information
Brax Team authored and btaba committed Feb 7, 2024
1 parent 532a88a commit f9a4d73
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 21 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ If you would like to reference Brax in a publication, please use:
author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
url = {http://github.com/google/brax},
version = {0.9.4},
version = {0.10.0},
year = {2021},
}
```
Expand Down
2 changes: 1 addition & 1 deletion brax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Import top-level classes and functions here for encapsulation/clarity."""

__version__ = '0.9.4'
__version__ = '0.10.0'

from brax.base import Motion
from brax.base import State
Expand Down
3 changes: 1 addition & 2 deletions brax/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def get(sys: System, x: Transform) -> Optional[Contact]:
Returns:
Contact pytree
"""
# TODO: use mjx.ncon.
ncon = mjx._src.collision_driver.ncon(sys)
ncon = mjx.ncon(sys)
if not ncon:
return None

Expand Down
5 changes: 2 additions & 3 deletions brax/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def dumps(sys: System, states: List[State]) -> Text:
for id_ in range(sys.ngeom):
link_idx = sys.geom_bodyid[id_] - 1

rgba = sys.mj_model.geom_rgba[id_]
rgba = sys.geom_rgba[id_]
if (rgba == [0.5, 0.5, 0.5, 1.0]).all():
# convert the default mjcf color to brax default color
rgba = np.array([0.4, 0.33, 0.26, 1.0])
Expand All @@ -171,8 +171,7 @@ def dumps(sys: System, states: List[State]) -> Text:
}

if geom['name'] in ('Mesh', 'Box'):
# TODO: use sys.geom_dataid.
vert, face = _get_mesh(sys.mj_model, sys.mj_model.geom_dataid[id_])
vert, face = _get_mesh(sys.mj_model, sys.geom_dataid[id_])
geom['vert'] = vert
geom['face'] = face

Expand Down
25 changes: 16 additions & 9 deletions brax/io/mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,25 +385,32 @@ def load_model(mj: mujoco.MjModel) -> System:
)

# create actuators
# TODO: swap brax actuation for mjx actuation model.
ctrl_range = mj.actuator_ctrlrange
ctrl_range[~(mj.actuator_ctrllimited == 1), :] = np.array([-np.inf, np.inf])
force_range = mj.actuator_forcerange
force_range[~(mj.actuator_forcelimited == 1), :] = np.array([-np.inf, np.inf])
q_id = np.array([mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]])
qd_id = np.array([mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]])
bias_q = mj.actuator_biasprm[:, 1] * (mj.actuator_biastype != 0)
bias_qd = mj.actuator_biasprm[:, 2] * (mj.actuator_biastype != 0)
# mask actuators since brax only supports joint transmission types
act_mask = mj.actuator_trntype == mujoco.mjtTrn.mjTRN_JOINT
trnid = mj.actuator_trnid[act_mask, 0].astype(np.uint32)
q_id = mj.jnt_qposadr[trnid]
qd_id = mj.jnt_dofadr[trnid]
act_kwargs = {
'gain': mj.actuator_gainprm[:, 0],
'gear': mj.actuator_gear[:, 0],
'ctrl_range': ctrl_range,
'force_range': force_range,
'bias_q': bias_q,
'bias_qd': bias_qd,
}
act_kwargs = jax.tree_map(lambda x: x[act_mask], act_kwargs)

# TODO: remove brax actuators
actuator = Actuator( # pytype: disable=wrong-arg-types
q_id=q_id,
qd_id=qd_id,
gain=mj.actuator_gainprm[:, 0],
gear=mj.actuator_gear[:, 0],
ctrl_range=ctrl_range,
force_range=force_range,
bias_q=bias_q,
bias_qd=bias_qd,
**act_kwargs
)

# create non-pytree params. these do not live on device directly, and they
Expand Down
9 changes: 9 additions & 0 deletions brax/io/mjcf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from absl.testing import absltest
from brax import test_utils
from brax.io import mjcf
import mujoco
import numpy as np

assert_almost_equal = np.testing.assert_array_almost_equal
Expand Down Expand Up @@ -131,6 +132,14 @@ def test_world_fromto(self):
sys = test_utils.load_fixture('world_fromto.xml')
mjcf.validate_model(sys.mj_model)

def test_loads_different_transmission(self):
"""Tests that the brax model loads with different transmission types."""
mj = test_utils.load_fixture_mujoco('ant.xml')
mj.actuator_trntype[0] = mujoco.mjtTrn.mjTRN_SITE
mjcf.load_model(mj) # loads without raising an error

with self.assertRaisesRegex(NotImplementedError, 'transmission types'):
mjcf.validate_model(mj) # raises an error

if __name__ == '__main__':
absltest.main()
4 changes: 0 additions & 4 deletions docs/release-notes/next-release.md
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
# Brax Release Notes

* Rebase brax System and State onto mjx.Model and mjx.Data.
* Use the MuJoCo renderer instead of pytinyrenderer for brax.io.image.
* Separate validation logic from the model loader in brax.io.mjcf.
9 changes: 9 additions & 0 deletions docs/release-notes/v0.10.0.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Brax v0.10.0 Release Notes

This minor release makes several changes to the brax API, such that [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) data structures are the core data structures used in brax. This allows for more seamless model loading from `MuJoCo` XMLs, and allows for running `MJX` physics more seamlessly in brax.

* Rebase brax `System` and `State` onto `mjx.Model` and `mjx.Data`.
* Separate validation logic from the model loading logic in `brax.io.mjcf`. This allows users to load an [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) model in brax, without hitting validation errors for other physics backends like `positional` and `spring`.
* Remove `System.geoms`, since `brax.System` inherits from `mjx.Model` and all geom information is available in `mjx.Model`. We also update the brax viewer to work with this new schema.
* Delete the brax contact library and use the contact library from `MJX`.
* Use the MuJoCo renderer instead of pytinyrenderer for `brax.io.image`.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

setup(
name="brax",
version="0.9.4",
version="0.10.0",
description="A differentiable physics engine written in JAX.",
author="Brax Authors",
author_email="[email protected]",
Expand Down

0 comments on commit f9a4d73

Please sign in to comment.