diff --git a/brax/generalized/pipeline_test.py b/brax/generalized/pipeline_test.py
index cef75c83..7730c55f 100644
--- a/brax/generalized/pipeline_test.py
+++ b/brax/generalized/pipeline_test.py
@@ -19,6 +19,7 @@
 from absl.testing import parameterized
 from brax import test_utils
 from brax.generalized import pipeline
+from brax.io import mjcf
 import jax
 from jax import numpy as jp
 import numpy as np
@@ -46,5 +47,37 @@ def test_forward(self, xml_file):
     np.testing.assert_allclose(state.qd, mj_next.qvel, atol=0.5)


+class GradientTest(absltest.TestCase):
+  """Tests that gradients are not NaN."""
+
+  def test_grad(self):
+    """Tests that gradients are not NaN."""
+    xml = """
+      <mujoco>
+        <!-- minimal placeholder model: a single free-falling sphere -->
+        <worldbody>
+          <body pos="0 0 0.5">
+            <joint type="free"/>
+            <geom size="0.1" type="sphere"/>
+          </body>
+        </worldbody>
+      </mujoco>
+    """
+    sys = mjcf.loads(xml)
+    init_state = jax.jit(pipeline.init)(
+        sys, sys.init_q, jp.zeros(sys.qd_size())
+    )
+
+    def fn(xd):
+      qd = jp.zeros(sys.qd_size()).at[0].set(xd)
+      state = init_state.replace(qd=qd)
+      for _ in range(10):
+        state = jax.jit(pipeline.step)(sys, state, None)
+      return state.qd[0]
+
+    grad = jax.grad(fn)(-1.0)
+    self.assertFalse(np.isnan(grad))
+
+
 if __name__ == '__main__':
   absltest.main()
diff --git a/brax/math.py b/brax/math.py
index 57ad1c06..816414f3 100644
--- a/brax/math.py
+++ b/brax/math.py
@@ -303,13 +303,13 @@ def body_fn(carry, _):
     a_inv, r, err = carry
     a_inv_next = a_inv @ (np.eye(a.shape[0]) + r)
     r_next = np.eye(a.shape[0]) - a @ a_inv_next
-    err_next = jp.linalg.norm(r_next)
+    err_next = safe_norm(r_next)
     a_inv_next = jp.where(err_next < err, a_inv_next, a_inv)
     return (a_inv_next, r_next, err_next), None

   # ensure ||I - X0 @ A|| < 1, in order to guarantee convergence
   r0 = jp.eye(a.shape[0]) - a @ a_inv
-  a_inv = jp.where(jp.linalg.norm(r0) > 1, 0.5 * a.T / jp.trace(a @ a.T), a_inv)
+  a_inv = jp.where(safe_norm(r0) > 1, 0.5 * a.T / jp.trace(a @ a.T), a_inv)
   (a_inv, _, _), _ = jax.lax.scan(body_fn, (a_inv, r0, 1.0), None, num_iter)

   return a_inv
@@ -332,9 +332,9 @@ def safe_norm(

   is_zero = jp.allclose(x, 0.0)
   # temporarily swap x with ones if is_zero, then swap back
-  x = jp.where(is_zero, jp.ones_like(x), x)
-  n = jp.linalg.norm(x, axis=axis)
-  n = jp.where(is_zero, 0.0, n)
+  x = x + is_zero * 1.0
+  n = jp.linalg.norm(x, axis=axis) * (1.0 - is_zero)
+
   return n
diff --git a/docs/release-notes/v0.9.4.md b/docs/release-notes/v0.9.4.md
new file mode 100644
index 00000000..50e3d705
--- /dev/null
+++ b/docs/release-notes/v0.9.4.md
@@ -0,0 +1,18 @@
+# Brax v0.9.4 Pre-Release Notes
+
+* Fixes NaN gradients in the generalized pipeline by replacing `jp.linalg.norm` calls in `inv_approximate` with `safe_norm`.
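+
+  A minimal sketch of the failure mode this addresses (illustrative snippet,
+  not from the codebase):
+
+  ```python
+  import jax
+  from jax import numpy as jp
+  from brax import math
+
+  # The gradient of the plain norm is x / ||x||, which is 0/0 = NaN at x = 0:
+  print(jax.grad(jp.linalg.norm)(jp.zeros(3)))   # [nan nan nan]
+
+  # safe_norm masks the zero input before taking the norm, so grads are zero:
+  print(jax.grad(math.safe_norm)(jp.zeros(3)))   # [0. 0. 0.]
+  ```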