From 3b72c952f7c990e8e9e355f736e77f10f408bc14 Mon Sep 17 00:00:00 2001 From: Brax Team Date: Fri, 2 Jun 2023 14:48:20 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 537414728 Change-Id: I8f620bc9a9a5b424b5906357f8e1c1c5d1d4c6f7 --- README.md | 2 +- brax/__init__.py | 2 +- brax/actuator.py | 19 ++++-- brax/actuator_test.py | 26 +++++++- brax/base.py | 9 ++- brax/fluid.py | 15 ++--- brax/fluid_test.py | 38 ++++++++++++ brax/generalized/base.py | 18 +++--- brax/generalized/constraint.py | 13 ++-- brax/generalized/dynamics.py | 40 +++++++------ brax/generalized/dynamics_test.py | 4 +- brax/io/mjcf.py | 11 +++- brax/positional/base.py | 3 + brax/positional/joints.py | 15 +---- brax/positional/pipeline.py | 19 ++++-- brax/positional/pipeline_test.py | 3 +- brax/spring/joints.py | 13 +--- brax/spring/pipeline.py | 13 +++- brax/spring/pipeline_test.py | 3 +- brax/test_data/fluid_sphere.xml | 15 +++++ brax/test_data/fluid_two_spheres.xml | 22 +++++++ brax/test_data/single_pendulum_position.xml | 2 +- .../single_pendulum_position_frclimit.xml | 20 +++++++ brax/test_data/single_pendulum_velocity.xml | 2 +- brax/training/agents/apg/train.py | 1 + brax/training/agents/ppo/train.py | 1 + brax/training/agents/sac/train.py | 1 + brax/training/replay_buffers.py | 49 +++------------ brax/training/replay_buffers_test.py | 59 +------------------ docs/release-notes/v0.9.1.md | 8 +++ setup.py | 2 +- 31 files changed, 261 insertions(+), 187 deletions(-) create mode 100644 brax/test_data/fluid_sphere.xml create mode 100644 brax/test_data/fluid_two_spheres.xml create mode 100644 brax/test_data/single_pendulum_position_frclimit.xml create mode 100644 docs/release-notes/v0.9.1.md diff --git a/README.md b/README.md index 6f0d94a9..633aa75a 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ If you would like to reference Brax in a publication, please use: author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem}, title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation}, url = {http://github.com/google/brax}, - version = {0.9.0}, + version = {0.9.1}, year = {2021}, } ``` diff --git a/brax/__init__.py b/brax/__init__.py index 98e974b3..a873cdca 100644 --- a/brax/__init__.py +++ b/brax/__init__.py @@ -14,7 +14,7 @@ """Import top-level classes and functions here for encapsulation/clarity.""" -__version__ = '0.9.0' +__version__ = '0.9.1' from brax.base import Motion from brax.base import State diff --git a/brax/actuator.py b/brax/actuator.py index 39291c97..310a8f34 100644 --- a/brax/actuator.py +++ b/brax/actuator.py @@ -36,12 +36,21 @@ def to_tau( if sys.act_size() == 0: return jp.zeros(sys.qd_size()) - q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id] ctrl_range = sys.actuator.ctrl_range + force_range = sys.actuator.force_range + + q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id] act = jp.clip(act, ctrl_range[:, 0], ctrl_range[:, 1]) - # TODO: incorporate gain - act = act + q * sys.actuator.bias_q + qd * sys.actuator.bias_qd - act_force = sys.actuator.gear * act - tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(act_force) + # See https://github.com/deepmind/mujoco/discussions/754 for why gear is + # used for the bias term. + bias = sys.actuator.gear * ( + q * sys.actuator.bias_q + qd * sys.actuator.bias_qd + ) + + force = sys.actuator.gain * act + bias + force = jp.clip(force, force_range[:, 0], force_range[:, 1]) + + force *= sys.actuator.gear + tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(force) return tau diff --git a/brax/actuator_test.py b/brax/actuator_test.py index a5089512..4026c5ee 100644 --- a/brax/actuator_test.py +++ b/brax/actuator_test.py @@ -101,6 +101,25 @@ def test_velocity(self): _, qd = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.001, n=200) np.testing.assert_array_almost_equal(qd, jp.array([1]), 3) + def test_force_limitted(self): + """Tests that forcerange limits work on actuators.""" + sys = test_utils.load_fixture('single_pendulum_position_frclimit.xml') + mj_model = test_utils.load_fixture_mujoco( + 'single_pendulum_position_frclimit.xml' + ) + mj_data = mujoco.MjData(mj_model) + q, qd = jp.array(mj_data.qpos), jp.array(mj_data.qvel) + + for act, frclimit in [(1000, 3.1), (-1000, -2.5)]: + act = jp.array([act]) + tau = actuator.to_tau(sys, act, q, qd) + # test that tau matches frclimit * 10, since gear=10 + self.assertEqual(tau[0], frclimit * 10) + # test that tau matches MJ qfrc_actuator + mj_data.ctrl = act + mujoco.mj_step(mj_model, mj_data) + self.assertEqual(tau[0], mj_data.qfrc_actuator) + @parameterized.parameters((g_pipeline,), (s_pipeline,), (p_pipeline,)) def test_three_link_pendulum(self, pipeline): """Tests a three link pendulum with a motor actuator.""" @@ -148,14 +167,17 @@ def test_spherical_pendulum_mj_generalized(self, config): np.testing.assert_array_almost_equal(gq, mq, 3) np.testing.assert_array_almost_equal(gqd, mqd, 3) - # TODO: test spherical pendulum once it's implemented in positional @parameterized.parameters( - ('single_pendulum_position.xml',), 'single_pendulum_motor.xml' + 'single_pendulum_position.xml', + 'single_pendulum_motor.xml', + 'single_spherical_pendulum_position.xml', ) def test_single_pendulum_spring_positional(self, config): sys = test_utils.load_fixture(config) act = jp.array([0.05, 0.1, 0.15])[: sys.act_size()] + q, qd = sys.init_q, jp.zeros(sys.qd_size()) + sq, sqd = _actuator_step(s_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500) pq, pqd = _actuator_step(p_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500) np.testing.assert_array_almost_equal(sq, pq, 2) diff --git a/brax/base.py b/brax/base.py index f3c78c02..43269693 100644 --- a/brax/base.py +++ b/brax/base.py @@ -453,7 +453,9 @@ class Actuator(Base): q_id: (num_actuators,) q index associated with an actuator qd_id: (num_actuators,) qd index associated with an actuator ctrl_range: (num_actuators, 2) actuator control range - gear: (num_actuators,) scaling factor for each actuator torque output + force_range: (num_actuators, 2) actuator force range + gain: (num_actuators,) scaling factor for each actuator control input + gear: (num_actuators,) scaling factor for each actuator force output bias_q: (num_actuators,) bias applied by q (e.g. position actuators) bias_qd: (num_actuators,) bias applied by qd (e.g. velocity actuators) """ @@ -461,6 +463,8 @@ class Actuator(Base): q_id: jp.ndarray qd_id: jp.ndarray ctrl_range: jp.ndarray + force_range: jp.ndarray + gain: jp.ndarray gear: jp.ndarray bias_q: jp.ndarray bias_qd: jp.ndarray @@ -508,6 +512,8 @@ class System: joint_scale_ang: scale for position-based joint rotation update joint_scale_pos: scale for position-based joint position update collide_scale: fraction of position based collide update to apply + enable_fluid: (1,) enables or disables fluid forces based on the + default viscosity and density parameters provided in the XML geom_masks: 64-bit mask determines whether two geoms will be contact tested. lower 32 bits are type, upper 32 bits are affinity. two geoms a, b will be contact tested if a.type & b.affinity != 0 @@ -545,6 +551,7 @@ class System: joint_scale_pos: jp.float32 collide_scale: jp.float32 # non-pytree nodes + enable_fluid: bool = struct.field(pytree_node=False) geom_masks: List[int] = struct.field(pytree_node=False) link_names: List[str] = struct.field(pytree_node=False) link_types: str = struct.field(pytree_node=False) diff --git a/brax/fluid.py b/brax/fluid.py index b05e4e88..88d76250 100644 --- a/brax/fluid.py +++ b/brax/fluid.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint:disable=g-multiple-import +# pylint:disable=g-multiple-import, g-importing-member """Functions for forces/torques through fluids.""" -from typing import Tuple, Union +from typing import Union from brax.base import Force, Motion, System, Transform import jax import jax.numpy as jp @@ -58,13 +58,13 @@ def force( xd: Motion, mass: jp.ndarray, inertia: jp.ndarray, - subtree_com: Union[jp.ndarray, None] = None, -) -> Tuple[Force, jp.ndarray]: + root_com: Union[jp.ndarray, None] = None, +) -> Force: """Returns force due to motion through a fluid.""" # get the velocity at the com position/orientation x_i = x.vmap().do(sys.link.inertia.transform) - # TODO: remove subtree_com when xd is fixed for stacked joints - offset = x_i.pos - x.pos if subtree_com is None else x_i.pos - subtree_com + # TODO: remove root_com when xd is fixed for stacked joints + offset = x_i.pos - x.pos if root_com is None else x_i.pos - root_com xd_i = x_i.replace(pos=offset).vmap().do(xd) # TODO: add ellipsoid fluid model from mujoco @@ -80,4 +80,5 @@ def force( # rotate back to the world orientation frc = Transform.create(rot=x_i.rot).vmap().do(frc) - return frc, x_i.pos + + return frc diff --git a/brax/fluid_test.py b/brax/fluid_test.py index 5a9e54df..92b0a72f 100644 --- a/brax/fluid_test.py +++ b/brax/fluid_test.py @@ -20,6 +20,8 @@ from brax import test_utils from brax.generalized import dynamics from brax.generalized import pipeline as g_pipeline +from brax.positional import pipeline as p_pipeline +from brax.spring import pipeline as s_pipeline import jax from jax import numpy as jp import mujoco @@ -77,6 +79,42 @@ def test_fluid_mj_generalized(self, config, density, viscosity): np.testing.assert_array_almost_equal(gq, mq, 2) np.testing.assert_array_almost_equal(gqd, mqd, 2) + @parameterized.parameters( + ('fluid_sphere.xml', p_pipeline, 0.0, 0.0), + ('fluid_sphere.xml', p_pipeline, 0.0, 1.3), + ('fluid_sphere.xml', p_pipeline, 2.0, 0.0), + ('fluid_sphere.xml', p_pipeline, 2.0, 1.3), + ('fluid_sphere.xml', s_pipeline, 0.0, 0.0), + ('fluid_sphere.xml', s_pipeline, 0.0, 1.3), + ('fluid_sphere.xml', s_pipeline, 2.0, 0.0), + ('fluid_sphere.xml', s_pipeline, 2.0, 1.3), + ('fluid_two_spheres.xml', p_pipeline, 2.0, 1.3), + ('fluid_two_spheres.xml', s_pipeline, 2.0, 1.3), + ) + def test_fluid_positional_spring(self, config, pipeline, density, viscosity): + """Tests fluid interactions for pbd/spring compared to generalized.""" + sys = test_utils.load_fixture(config) + sys = sys.replace(density=density, viscosity=viscosity) + + q, qd = sys.init_q, jp.zeros(sys.qd_size()) + qd = qd.at[:3].set(jp.array([1, 2, 4])) + qd = qd.at[3].set(jp.pi) + + state_g = jax.jit(g_pipeline.init)(sys, q, qd) + for _ in range(500): + state_g = jax.jit(g_pipeline.step)( + sys, state_g, jp.zeros((sys.act_size(),)) + ) + gq, gqd = state_g.q, state_g.qd + + state = jax.jit(pipeline.init)(sys, q, qd) + for _ in range(500): + state = jax.jit(pipeline.step)(sys, state, jp.zeros((sys.act_size(),))) + tq, tqd = state.q, state.qd + + np.testing.assert_array_almost_equal(tq, gq, 3) + np.testing.assert_array_almost_equal(tqd, gqd, 2) + if __name__ == '__main__': absltest.main() diff --git a/brax/generalized/base.py b/brax/generalized/base.py index cce99fc8..83eb4a53 100644 --- a/brax/generalized/base.py +++ b/brax/generalized/base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint:disable=g-multiple-import +# pylint:disable=g-multiple-import, g-importing-member """Base types for generalized pipeline.""" from brax import base @@ -26,24 +26,24 @@ class State(base.State): """Dynamic state that changes after every step. Attributes: - com: center of mass position of the kinematic tree containing the link - cinr: inertia in com frame - cd: body velocities in com frame - cdof: dofs in com frame - cdofd: cdof velocity + root_com: (num_links,) center of mass position of link root kinematic tree + cinr: (num_links,) inertia in com frame + cd: (num_links,) link velocities in com frame + cdof: (qd_size,) dofs in com frame + cdofd: (qd_size,) cdof velocity mass_mx: (qd_size, qd_size) mass matrix mass_mx_inv: (qd_size, qd_size) inverse mass matrix contact: calculated contacts con_jac: constraint jacobian con_diag: constraint A diagonal con_aref: constraint reference acceleration - qf_smooth: smooth dynamics force + qf_smooth: (qd_size,) smooth dynamics force qf_constraint: (qd_size,) force from constraints (collision etc) qdd: (qd_size,) joint acceleration vector """ # position/velocity based terms are updated at the end of each step: - com: jp.ndarray + root_com: jp.ndarray cinr: Inertia cd: Motion cdof: Motion @@ -71,7 +71,7 @@ def init( x=x, xd=xd, contact=None, - com=jp.zeros(3), + root_com=jp.zeros(3), cinr=Inertia( Transform.zero((num_links,)), jp.zeros((num_links, 3, 3)), diff --git a/brax/generalized/constraint.py b/brax/generalized/constraint.py index be59d59a..50794c63 100644 --- a/brax/generalized/constraint.py +++ b/brax/generalized/constraint.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint:disable=g-multiple-import +# pylint:disable=g-multiple-import, g-importing-member """Functions for constraint satisfaction.""" from typing import Tuple @@ -57,9 +57,8 @@ def _imp_aref( # See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters stiffness, damping = params[:2] - use_v2_params = (stiffness < 0) & (damping < 0) - b = jp.where(use_v2_params, -damping / dmax, b) - k = jp.where(use_v2_params, -stiffness / (dmax * dmax), k) + b = jp.where(damping <= 0, -damping / dmax, b) + k = jp.where(stiffness <= 0, -stiffness / (dmax * dmax), k) aref = -b * vel - k * imp * pos @@ -153,9 +152,9 @@ def jac_contact( def row_fn(contact): link_a, link_b = contact.link_idx - a = point_jacobian(sys, state.com, state.cdof, contact.pos, link_a).vel - b = point_jacobian(sys, state.com, state.cdof, contact.pos, link_b).vel - diff = b - a + a = point_jacobian(sys, state.root_com, state.cdof, contact.pos, link_a) + b = point_jacobian(sys, state.root_com, state.cdof, contact.pos, link_b) + diff = b.vel - a.vel # 4 pyramidal friction directions jac = [] diff --git a/brax/generalized/dynamics.py b/brax/generalized/dynamics.py index 781f58af..12c33f8e 100644 --- a/brax/generalized/dynamics.py +++ b/brax/generalized/dynamics.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint:disable=g-multiple-import +# pylint:disable=g-multiple-import, g-importing-member """Functions for smooth forward and inverse dynamics.""" -import functools - from brax import fluid from brax import math from brax import scan @@ -42,8 +40,8 @@ def transform_com(sys: System, state: State) -> State: mass_xi = jax.vmap(jp.multiply)(sys.link.inertia.mass, x_i.pos) mass_xi_sum = jax.ops.segment_sum(mass_xi, root, sys.num_links()) mass_sum = jax.ops.segment_sum(sys.link.inertia.mass, root, sys.num_links()) - com = jax.vmap(jp.divide)(mass_xi_sum[root], mass_sum[root]) - cinr = x_i.replace(pos=x_i.pos - com).vmap().do(sys.link.inertia) + root_com = jax.vmap(jp.divide)(mass_xi_sum[root], mass_sum[root]) + cinr = x_i.replace(pos=x_i.pos - root_com).vmap().do(sys.link.inertia) # motion dofs to global frame centered at subtree-CoM parent_idx = jp.array( @@ -89,7 +87,8 @@ def cdof_fn(typ, q, motion): cdof = scan.link_types(sys, cdof_fn, 'qd', 'd', state.q, sys.dof.motion) ang = jax.vmap(math.rotate)(cdof.ang, j.take(sys.dof_link()).rot) cdof = cdof.replace(ang=ang) - cdof = Transform.create(pos=com - j.pos).take(sys.dof_link()).vmap().do(cdof) + off = Transform.create(pos=root_com - j.pos) + cdof = off.take(sys.dof_link()).vmap().do(cdof) cdof_qd = jax.vmap(lambda x, y: x * y)(cdof, state.qd) # forward scan down tree: accumulate link center of mass velocity @@ -132,7 +131,9 @@ def cdofd_fn(typ, cd, cdof, cdof_qd): cd_p = cd.concatenate(Motion.zero(shape=(1,))).take(parent_idx) cdofd = scan.link_types(sys, cdofd_fn, 'ldd', 'd', cd_p, cdof, cdof_qd) - return state.replace(com=com, cinr=cinr, cd=cd, cdof=cdof, cdofd=cdofd) + return state.replace( + root_com=root_com, cinr=cinr, cd=cd, cdof=cdof, cdofd=cdofd + ) def inverse(sys: System, state: State) -> jp.ndarray: @@ -194,17 +195,20 @@ def stiffness_fn(typ, q, dof): frc = scan.link_types(sys, stiffness_fn, 'qd', 'd', state.q, sys.dof) frc -= sys.dof.damping * state.qd - fluid_frc, pos = fluid.force( - sys, - state.x, - state.cd, - sys.link.inertia.mass, - sys.link.inertia.i, - state.com, - ) - cdof_fn = functools.partial(point_jacobian, sys, state.com, state.cdof) - jac = jax.vmap(cdof_fn)(pos, jp.arange(sys.num_links())) - frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0) + if sys.enable_fluid: + fluid_frc = fluid.force( + sys, + state.x, + state.cd, + sys.link.inertia.mass, + sys.link.inertia.i, + state.root_com, + ) + link_idx = jp.arange(sys.num_links()) + x_i = state.x.vmap().do(sys.link.inertia.transform) + jac_fn = jax.vmap(point_jacobian, in_axes=(None, None, None, 0, 0)) + jac = jac_fn(sys, state.root_com, state.cdof, x_i.pos, link_idx) + frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0) return frc diff --git a/brax/generalized/dynamics_test.py b/brax/generalized/dynamics_test.py index fb0487cc..d6956f6b 100644 --- a/brax/generalized/dynamics_test.py +++ b/brax/generalized/dynamics_test.py @@ -35,7 +35,9 @@ def test_transform_com(self, xml_file): for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file): state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel) - np.testing.assert_almost_equal(state.com[0], mj_next.subtree_com[0], 5) + np.testing.assert_almost_equal( + state.root_com[0], mj_next.subtree_com[0], 5 + ) mj_cinr_i = np.zeros((state.cinr.i.shape[0], 3, 3)) mj_cinr_i[:, [0, 1, 2], [0, 1, 2]] = mj_next.cinert[1:, 0:3] # diagonal mj_cinr_i[:, [0, 0, 1], [1, 2, 2]] = mj_next.cinert[1:, 3:6] # upper tri diff --git a/brax/io/mjcf.py b/brax/io/mjcf.py index c1f086c4..2eb378dd 100644 --- a/brax/io/mjcf.py +++ b/brax/io/mjcf.py @@ -237,6 +237,8 @@ def load_model(mj: mujoco.MjModel) -> System: # do some validation up front if any(i not in [0, 1] for i in mj.actuator_biastype): raise NotImplementedError('Only actuator_biastype in [0, 1] are supported.') + if any(i != 0 for i in mj.actuator_gaintype): + raise NotImplementedError('Only actuator_gaintype in [0] is supported.') if mj.opt.integrator != 0: raise NotImplementedError('Only euler integration is supported.') if mj.opt.cone != 0: @@ -406,17 +408,21 @@ def load_model(mj: mujoco.MjModel) -> System: # create actuators ctrl_range = mj.actuator_ctrlrange ctrl_range[~(mj.actuator_ctrllimited == 1), :] = np.array([-np.inf, np.inf]) + force_range = mj.actuator_forcerange + force_range[~(mj.actuator_forcelimited == 1), :] = np.array([-np.inf, np.inf]) q_id = np.array([mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]]) qd_id = np.array([mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]]) - bias_q = mj.actuator_biasprm[:, 1] - bias_qd = mj.actuator_biasprm[:, 2] + bias_q = mj.actuator_biasprm[:, 1] * (mj.actuator_biastype != 0) + bias_qd = mj.actuator_biasprm[:, 2] * (mj.actuator_biastype != 0) # TODO: might be nice to add actuator names for debugging actuator = Actuator( # pytype: disable=wrong-arg-types q_id=q_id, qd_id=qd_id, + gain=mj.actuator_gainprm[:, 0], gear=mj.actuator_gear[:, 0], ctrl_range=ctrl_range, + force_range=force_range, bias_q=bias_q, bias_qd=bias_qd, ) @@ -465,6 +471,7 @@ def load_model(mj: mujoco.MjModel) -> System: joint_scale_ang=custom['joint_scale_ang'], joint_scale_pos=custom['joint_scale_pos'], collide_scale=custom['collide_scale'], + enable_fluid=(mj.opt.viscosity > 0) | (mj.opt.density > 0), geom_masks=geom_masks, link_names=link_names, link_types=link_types, diff --git a/brax/positional/base.py b/brax/positional/base.py index 6e237d7e..956503be 100644 --- a/brax/positional/base.py +++ b/brax/positional/base.py @@ -18,6 +18,7 @@ from brax import base from brax.base import Motion, Transform from flax import struct +import jax.numpy as jp @struct.dataclass @@ -31,6 +32,7 @@ class State(base.State): jd: link motion in joint frame a_p: joint parent anchor in world frame a_c: joint child anchor in world frame + mass: link mass """ x_i: Transform @@ -39,3 +41,4 @@ class State(base.State): jd: Motion a_p: Transform a_c: Transform + mass: jp.ndarray diff --git a/brax/positional/joints.py b/brax/positional/joints.py index cf7a490f..f3146e41 100644 --- a/brax/positional/joints.py +++ b/brax/positional/joints.py @@ -27,7 +27,7 @@ from jax.ops import segment_sum -def acceleration_update(sys: System, state: State, tau: jp.ndarray) -> Motion: +def acceleration_update(sys: System, state: State, tau: jp.ndarray) -> Force: """Calculates forces to apply to links resulting from joint constraints. Args: @@ -36,7 +36,7 @@ def acceleration_update(sys: System, state: State, tau: jp.ndarray) -> Motion: tau: joint force vector Returns: - xdd_i: acceleration to apply to link center of mass in world frame + xf_i: force to apply to link center of mass in world frame """ def _free_joint(*_) -> Force: @@ -77,16 +77,7 @@ def j_fn(typ, link, jd, dof, tau): fp = Transform.create(pos=state.a_p.pos - x_i_parent.pos).vmap().do(xf) fp = jax.tree_map(lambda x: segment_sum(x, parent_idx, sys.num_links()), fp) xf_i = fc - fp - - # convert to acceleration - inv_mass = 1 / (sys.link.inertia.mass ** (1 - sys.spring_mass_scale)) - inv_inertia = com.inv_inertia(sys, state.x) - xdd_i = Motion( - ang=jax.vmap(lambda x, y: x @ y)(inv_inertia, xf_i.ang), - vel=jax.vmap(lambda x, y: x * y)(inv_mass, xf_i.vel), - ) - - return xdd_i + return xf_i def position_update(sys: System, state: State) -> Transform: diff --git a/brax/positional/pipeline.py b/brax/positional/pipeline.py index eb2fcdce..729c00ec 100644 --- a/brax/positional/pipeline.py +++ b/brax/positional/pipeline.py @@ -16,6 +16,7 @@ # pylint:disable=g-multiple-import from brax import actuator from brax import com +from brax import fluid from brax import geometry from brax import kinematics from brax.base import Motion, System @@ -23,6 +24,7 @@ from brax.positional import integrator from brax.positional import joints from brax.positional.base import State +import jax from jax import numpy as jp @@ -45,8 +47,8 @@ def init( j, jd, a_p, a_c = kinematics.world_to_joint(sys, x, xd) x_i, xd_i = com.from_world(sys, x, xd) contact = geometry.contact(sys, x) if debug else None - - return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c) + mass = sys.link.inertia.mass ** (1 - sys.spring_mass_scale) + return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c, mass) def step( @@ -72,7 +74,15 @@ def step( # calculate acceleration level updates tau = actuator.to_tau(sys, act, state.q, state.qd) xdd_i = Motion.create(vel=sys.gravity) - xdd_i += joints.acceleration_update(sys, state, tau) + # get joint constraint forces + xf_i = joints.acceleration_update(sys, state, tau) + if sys.enable_fluid: + inertia = sys.link.inertia.i ** (1 - sys.spring_inertia_scale) + xf_i += fluid.force(sys, state.x, state.xd, state.mass, inertia) + xdd_i += Motion( + ang=jax.vmap(lambda x, y: x @ y)(com.inv_inertia(sys, state.x), xf_i.ang), + vel=jax.vmap(lambda x, y: x * y)(1 / state.mass, xf_i.vel), + ) # semi-implicit euler: apply acceleration update before resolving collisions x_i, xd_i = integrator.integrate_xdd(sys, state.x_i, state.xd_i, xdd_i) @@ -102,5 +112,4 @@ def step( q, qd = kinematics.inverse(sys, j, jd) contact = geometry.contact(sys, x) if debug else None - return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c) - + return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c, state.mass) diff --git a/brax/positional/pipeline_test.py b/brax/positional/pipeline_test.py index 82ed5d6c..50b66d51 100644 --- a/brax/positional/pipeline_test.py +++ b/brax/positional/pipeline_test.py @@ -72,7 +72,8 @@ def test_spherical_pendulum(self): # from generalized and plug it back into pbd # TODO: remove this xd override once kinematics.forward is fixed state_g = g_pipeline.init(sys, init_q, init_qd) - xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd) + off = state_g.x.pos - state_g.root_com + xd = Transform.create(pos=off).vmap().do(state_g.cd) state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1]) j_pos_step = jax.jit(pipeline.step) diff --git a/brax/spring/joints.py b/brax/spring/joints.py index 75e88273..423d3191 100644 --- a/brax/spring/joints.py +++ b/brax/spring/joints.py @@ -309,7 +309,7 @@ def _three_dof( return Force(ang=ang, vel=vel) -def resolve(sys: System, state: State, tau: jp.ndarray) -> Motion: +def resolve(sys: System, state: State, tau: jp.ndarray) -> Force: """Calculates forces to apply to links resulting from joint constraints. Args: @@ -318,7 +318,7 @@ def resolve(sys: System, state: State, tau: jp.ndarray) -> Motion: tau: joint force vector Returns: - xdd_i: acceleration to apply to link center of mass in world frame + xf_i: force to apply to link center of mass in world frame """ def j_fn(typ, link, j, jd, dof, tau): @@ -346,11 +346,4 @@ def j_fn(typ, link, j, jd, dof, tau): fp = Transform.create(pos=state.a_p.pos - x_i_parent.pos).vmap().do(xf) fp = jax.tree_map(lambda x: segment_sum(x, parent_idx, sys.num_links()), fp) xf_i = fc - fp - - # convert to acceleration - xdd_i = Motion( - ang=jax.vmap(lambda x, y: x @ y)(state.i_inv, xf_i.ang), - vel=jax.vmap(lambda x, y: x / y)(xf_i.vel, state.mass), - ) - - return xdd_i + return xf_i diff --git a/brax/spring/pipeline.py b/brax/spring/pipeline.py index 748f76ca..b96ff775 100644 --- a/brax/spring/pipeline.py +++ b/brax/spring/pipeline.py @@ -17,6 +17,7 @@ from brax import actuator from brax import com +from brax import fluid from brax import geometry from brax import kinematics from brax.base import Motion, System @@ -24,6 +25,7 @@ from brax.spring import integrator from brax.spring import joints from brax.spring.base import State +import jax from jax import numpy as jp @@ -88,7 +90,16 @@ def step( # calculate acceleration and delta-velocity terms tau = actuator.to_tau(sys, act, state.q, state.qd) - xdd_i = joints.resolve(sys, state, tau) + Motion.create(vel=sys.gravity) + xdd_i = Motion.create(vel=sys.gravity) + xf_i = joints.resolve(sys, state, tau) + if sys.enable_fluid: + inertia = sys.link.inertia.i ** (1 - sys.spring_inertia_scale) + xf_i += fluid.force(sys, state.x, state.xd, state.mass, inertia) + xdd_i += Motion( + ang=jax.vmap(lambda x, y: x @ y)(state.i_inv, xf_i.ang), + vel=jax.vmap(lambda x, y: x / y)(xf_i.vel, state.mass), + ) + # semi-implicit euler: apply acceleration update before resolving collisions state = state.replace(xd_i=state.xd_i + xdd_i * sys.dt) xdv_i = collisions.resolve(sys, state) diff --git a/brax/spring/pipeline_test.py b/brax/spring/pipeline_test.py index 698107eb..72786d1f 100644 --- a/brax/spring/pipeline_test.py +++ b/brax/spring/pipeline_test.py @@ -108,7 +108,8 @@ def test_spherical_pendulum(self): # from generalized and plug it back into pbd # TODO: remove this xd override once kinematics.forward is fixed state_g = g_pipeline.init(sys, init_q, init_qd) - xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd) + off = state_g.x.pos - state_g.root_com + xd = Transform.create(pos=off).vmap().do(state_g.cd) state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1]) j_spring_step = jax.jit(pipeline.step) diff --git a/brax/test_data/fluid_sphere.xml b/brax/test_data/fluid_sphere.xml new file mode 100644 index 00000000..13d461e3 --- /dev/null +++ b/brax/test_data/fluid_sphere.xml @@ -0,0 +1,15 @@ + + + + + + \ No newline at end of file diff --git a/brax/test_data/fluid_two_spheres.xml b/brax/test_data/fluid_two_spheres.xml new file mode 100644 index 00000000..4e67a745 --- /dev/null +++ b/brax/test_data/fluid_two_spheres.xml @@ -0,0 +1,22 @@ + + + + + + \ No newline at end of file diff --git a/brax/test_data/single_pendulum_position.xml b/brax/test_data/single_pendulum_position.xml index e469969c..533266ae 100644 --- a/brax/test_data/single_pendulum_position.xml +++ b/brax/test_data/single_pendulum_position.xml @@ -15,6 +15,6 @@ - + diff --git a/brax/test_data/single_pendulum_position_frclimit.xml b/brax/test_data/single_pendulum_position_frclimit.xml new file mode 100644 index 00000000..7a38e619 --- /dev/null +++ b/brax/test_data/single_pendulum_position_frclimit.xml @@ -0,0 +1,20 @@ + + + diff --git a/brax/test_data/single_pendulum_velocity.xml b/brax/test_data/single_pendulum_velocity.xml index 5eafb9ec..16692bec 100644 --- a/brax/test_data/single_pendulum_velocity.xml +++ b/brax/test_data/single_pendulum_velocity.xml @@ -15,6 +15,6 @@ - + diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py index 196e8752..8a96c8ba 100644 --- a/brax/training/agents/apg/train.py +++ b/brax/training/agents/apg/train.py @@ -224,6 +224,7 @@ def training_epoch_with_timing(training_state: TrainingState, key=eval_key) # Run initial eval + metrics = {} if process_id == 0 and num_evals > 1: metrics = evaluator.run_evaluation( _unpmap( diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 0284f2cc..b51414dd 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -312,6 +312,7 @@ def training_epoch_with_timing( key=eval_key) # Run initial eval + metrics = {} if process_id == 0 and num_evals > 1: metrics = evaluator.run_evaluation( _unpmap( diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index b4b4d9f6..23c7df0f 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -424,6 +424,7 @@ def training_epoch_with_timing( key=eval_key) # Run initial eval + metrics = {} if process_id == 0 and num_evals > 1: metrics = evaluator.run_evaluation( _unpmap( diff --git a/brax/training/replay_buffers.py b/brax/training/replay_buffers.py index 1fe18b6d..0223b9fe 100644 --- a/brax/training/replay_buffers.py +++ b/brax/training/replay_buffers.py @@ -15,7 +15,8 @@ """Replay buffers for Brax.""" import abc -from typing import Generic, Optional, Tuple, TypeVar +import math +from typing import Generic, Optional, Sequence, Tuple, TypeVar from brax.training.types import PRNGKey import flax @@ -362,21 +363,19 @@ def __init__( self, buffer: ReplayBuffer[State, Sample], mesh: jax.sharding.Mesh, - axis_name: str, - batch_partition_spec: Optional[jax.sharding.PartitionSpec] = None, + axis_names: Sequence[str], ): """Constructor. Args: buffer: The buffer to replicate. mesh: Device mesh for pjitting context. - axis_name: The axis along which the replay buffer data should be + axis_names: The axes along which the replay buffer data should be partitionned. - batch_partition_spec: PartitionSpec of the inserted/sampled batch. """ self._buffer = buffer self._mesh = mesh - self._num_devices = mesh.shape[axis_name] + self._num_devices = math.prod(mesh.shape[name] for name in axis_names) def init(key: PRNGKey) -> State: keys = jax.random.split(key, self._num_devices) @@ -405,7 +404,7 @@ def sample(buffer_state: State) -> Tuple[State, Sample]: def size(buffer_state: State) -> int: return jnp.sum(jax.vmap(self._buffer.size)(buffer_state)) - partition_spec = jax.sharding.PartitionSpec((axis_name,)) + partition_spec = jax.sharding.PartitionSpec((axis_names),) self._partitioned_init = pjit.pjit(init, out_shardings=partition_spec) self._partitioned_insert = pjit.pjit( insert, @@ -413,7 +412,7 @@ def size(buffer_state: State) -> int: ) self._partitioned_sample = pjit.pjit( sample, - out_shardings=(partition_spec, batch_partition_spec), + out_shardings=partition_spec, ) # This will return the TOTAL size across all devices. self._partitioned_size = pjit.pjit(size, out_shardings=None) @@ -454,37 +453,3 @@ class PrimitiveReplayBufferState(Generic[Sample]): """The state of the primitive replay buffer.""" samples: Optional[Sample] = None - - -class PrimitiveReplayBuffer( - ReplayBuffer[PrimitiveReplayBufferState[Sample], Sample], Generic[Sample] -): - """A primitive queue that can contain at most one batch of samples.""" - - def init(self, key: PRNGKey) -> PrimitiveReplayBufferState[Sample]: - """Init the replay buffer.""" - return PrimitiveReplayBufferState(samples=None) - - def insert_internal( - self, buffer_state: PrimitiveReplayBufferState[Sample], samples: Sample - ) -> PrimitiveReplayBufferState[Sample]: - """Insert data in the replay buffer.""" - if buffer_state.samples is not None: - raise ValueError('The buffer is full') - return PrimitiveReplayBufferState(samples=samples) - - def sample_internal( - self, buffer_state: PrimitiveReplayBufferState[Sample] - ) -> Tuple[PrimitiveReplayBufferState[Sample], Sample]: - """Sample a batch of data.""" - if buffer_state.samples is None: - raise ValueError('The buffer is empty') - return PrimitiveReplayBufferState(samples=None), buffer_state.samples - - def size(self, buffer_state: PrimitiveReplayBufferState[Sample]) -> int: - """Return the total amount of elements that are sampleable.""" - return ( - jax.tree_flatten(buffer_state.samples)[0][0].shape[0] - if buffer_state.samples is not None - else 0 - ) diff --git a/brax/training/replay_buffers_test.py b/brax/training/replay_buffers_test.py index 9aa11cc0..526a40a6 100644 --- a/brax/training/replay_buffers_test.py +++ b/brax/training/replay_buffers_test.py @@ -63,7 +63,7 @@ def pjit_wrap(buffer): return replay_buffers.PjitWrapper( buffer, mesh=get_mesh(), - axis_name=AXIS_NAME, + axis_names=(AXIS_NAME,), ) @@ -643,62 +643,5 @@ def testQueueSampleFromEmpty(self, wrapper) -> None: buffer_state, samples = replay_buffer.sample(buffer_state) -class PrimitiveReplayBufferTest(parameterized.TestCase): - - @parameterized.parameters(WRAPPERS) - def testInsert(self, wrapper): - replay_buffer = wrapper(replay_buffers.PrimitiveReplayBuffer()) - rng = jax.random.PRNGKey(0) - buffer_state = replay_buffer.init(rng) - if wrapper not in [pjit_wrap, pmap_wrap]: - assert_equal(self, replay_buffer.size(buffer_state), 0) - buffer_state = replay_buffer.insert(buffer_state, get_dummy_batch(8)) - assert_equal(self, replay_buffer.size(buffer_state), 8) - - @parameterized.parameters(WRAPPERS) - def testInsertWhenFull(self, wrapper): - replay_buffer = wrapper(replay_buffers.PrimitiveReplayBuffer()) - rng = jax.random.PRNGKey(0) - buffer_state = replay_buffer.init(rng) - if wrapper not in [pjit_wrap, pmap_wrap]: - assert_equal(self, replay_buffer.size(buffer_state), 0) - - buffer_state = replay_buffer.insert(buffer_state, get_dummy_batch(8)) - with self.assertRaises(ValueError): - buffer_state = replay_buffer.insert(buffer_state, get_dummy_batch(8)) - assert_equal(self, replay_buffer.size(buffer_state), 8) - - @parameterized.parameters(WRAPPERS) - def testSample(self, wrapper): - replay_buffer = wrapper(replay_buffers.PrimitiveReplayBuffer()) - rng = jax.random.PRNGKey(0) - buffer_state = replay_buffer.init(rng) - if wrapper not in [pjit_wrap, pmap_wrap]: - assert_equal(self, replay_buffer.size(buffer_state), 0) - - buffer_state = replay_buffer.insert(buffer_state, get_dummy_batch(8)) - assert_equal(self, replay_buffer.size(buffer_state), 8) - buffer_state, samples = replay_buffer.sample(buffer_state) - if wrapper not in [pjit_wrap, pmap_wrap]: - assert_equal(self, replay_buffer.size(buffer_state), 0) - assert_equal(self, samples['a'].shape, (8,)) - assert_equal(self, samples['b'].shape, (8, 5, 5)) - for sample in samples['b']: - assert_equal( - self, jnp.reshape(sample - sample[0, 0], (-1,)), range(5 * 5) - ) - - @parameterized.parameters(WRAPPERS) - def testSampleWhenEmpty(self, wrapper): - replay_buffer = wrapper(replay_buffers.PrimitiveReplayBuffer()) - rng = jax.random.PRNGKey(0) - buffer_state = replay_buffer.init(rng) - if wrapper not in [pjit_wrap, pmap_wrap]: - assert_equal(self, replay_buffer.size(buffer_state), 0) - - with self.assertRaises(ValueError): - _, _ = replay_buffer.sample(buffer_state) - - if __name__ == '__main__': absltest.main() diff --git a/docs/release-notes/v0.9.1.md b/docs/release-notes/v0.9.1.md new file mode 100644 index 00000000..54c7bc4c --- /dev/null +++ b/docs/release-notes/v0.9.1.md @@ -0,0 +1,8 @@ +# Brax v0.9.1 Release Notes + +This patch release includes: +* Add support for positional actuators. +* Add fluid viscosity + density via box model. +* Adds cylinder collider (but only for wafer-thin cylinders) +* Bring back dm_env and torch env wrappers. +* Bring back image rendering via pytinyrenderer. diff --git a/setup.py b/setup.py index 4f6f0fec..a8a60b89 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ setup( name="brax", - version="0.9.0", + version="0.9.1", description=("A differentiable physics engine written in JAX."), author="Brax Authors", author_email="no-reply@google.com",