Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576391312
Change-Id: Ic485e50f983edf6096e0f753b42c0eff84bb1408
  • Loading branch information
Brax Team authored and btaba committed Oct 25, 2023
1 parent 175cf3e commit 1630403
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
33 changes: 33 additions & 0 deletions brax/generalized/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import parameterized
from brax import test_utils
from brax.generalized import pipeline
from brax.io import mjcf
import jax
from jax import numpy as jp
import numpy as np
Expand Down Expand Up @@ -46,5 +47,37 @@ def test_forward(self, xml_file):
np.testing.assert_allclose(state.qd, mj_next.qvel, atol=0.5)


class GradientTest(absltest.TestCase):
"""Tests that gradients are not NaN."""

def test_grad(self):
"""Tests that gradients are not NaN."""
xml = """
<mujoco>
<worldbody>
<body>
<joint type="slide" axis="1 0 0" damping="1"/>
<joint type="slide" axis="0 1 0" damping="1"/>
<geom size="0.1" mass="1"/>
</body>
</worldbody>
</mujoco>
"""
sys = mjcf.loads(xml)
init_state = jax.jit(pipeline.init)(
sys, sys.init_q, jp.zeros(sys.qd_size())
)

def fn(xd):
qd = jp.zeros(sys.qd_size()).at[0].set(xd)
state = init_state.replace(qd=qd)
for _ in range(10):
state = jax.jit(pipeline.step)(sys, state, None)
return state.qd[0]

grad = jax.grad(fn)(-1.0)
self.assertFalse(np.isnan(grad))


if __name__ == '__main__':
absltest.main()
10 changes: 5 additions & 5 deletions brax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,13 @@ def body_fn(carry, _):
a_inv, r, err = carry
a_inv_next = a_inv @ (np.eye(a.shape[0]) + r)
r_next = np.eye(a.shape[0]) - a @ a_inv_next
err_next = jp.linalg.norm(r_next)
err_next = safe_norm(r_next)
a_inv_next = jp.where(err_next < err, a_inv_next, a_inv)
return (a_inv_next, r_next, err_next), None

# ensure ||I - X0 @ A|| < 1, in order to guarantee convergence
r0 = jp.eye(a.shape[0]) - a @ a_inv
a_inv = jp.where(jp.linalg.norm(r0) > 1, 0.5 * a.T / jp.trace(a @ a.T), a_inv)
a_inv = jp.where(safe_norm(r0) > 1, 0.5 * a.T / jp.trace(a @ a.T), a_inv)
(a_inv, _, _), _ = jax.lax.scan(body_fn, (a_inv, r0, 1.0), None, num_iter)

return a_inv
Expand All @@ -332,9 +332,9 @@ def safe_norm(

is_zero = jp.allclose(x, 0.0)
# temporarily swap x with ones if is_zero, then swap back
x = jp.where(is_zero, jp.ones_like(x), x)
n = jp.linalg.norm(x, axis=axis)
n = jp.where(is_zero, 0.0, n)
x = x + is_zero * 1.0
n = jp.linalg.norm(x) * (1.0 - is_zero)

return n


Expand Down
3 changes: 3 additions & 0 deletions docs/release-notes/v0.9.4.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Brax v0.9.4 Pre-Release Notes

* Fixes gradients for generalized by changing a jp.linalg.norm to safe_norm.

0 comments on commit 1630403

Please sign in to comment.