diff --git a/README.md b/README.md
index 6f0d94a9..633aa75a 100644
--- a/README.md
+++ b/README.md
@@ -97,7 +97,7 @@ If you would like to reference Brax in a publication, please use:
author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
url = {http://github.com/google/brax},
- version = {0.9.0},
+ version = {0.9.1},
year = {2021},
}
```
diff --git a/brax/__init__.py b/brax/__init__.py
index 98e974b3..a873cdca 100644
--- a/brax/__init__.py
+++ b/brax/__init__.py
@@ -14,7 +14,7 @@
"""Import top-level classes and functions here for encapsulation/clarity."""
-__version__ = '0.9.0'
+__version__ = '0.9.1'
from brax.base import Motion
from brax.base import State
diff --git a/brax/actuator.py b/brax/actuator.py
index 39291c97..310a8f34 100644
--- a/brax/actuator.py
+++ b/brax/actuator.py
@@ -36,12 +36,21 @@ def to_tau(
if sys.act_size() == 0:
return jp.zeros(sys.qd_size())
- q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id]
ctrl_range = sys.actuator.ctrl_range
+ force_range = sys.actuator.force_range
+
+ q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id]
act = jp.clip(act, ctrl_range[:, 0], ctrl_range[:, 1])
- # TODO: incorporate gain
- act = act + q * sys.actuator.bias_q + qd * sys.actuator.bias_qd
- act_force = sys.actuator.gear * act
- tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(act_force)
+ # See https://github.com/deepmind/mujoco/discussions/754 for why gear is
+ # used for the bias term.
+ bias = sys.actuator.gear * (
+ q * sys.actuator.bias_q + qd * sys.actuator.bias_qd
+ )
+
+ force = sys.actuator.gain * act + bias
+ force = jp.clip(force, force_range[:, 0], force_range[:, 1])
+
+ force *= sys.actuator.gear
+ tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(force)
return tau
diff --git a/brax/actuator_test.py b/brax/actuator_test.py
index a5089512..4026c5ee 100644
--- a/brax/actuator_test.py
+++ b/brax/actuator_test.py
@@ -101,6 +101,25 @@ def test_velocity(self):
_, qd = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.001, n=200)
np.testing.assert_array_almost_equal(qd, jp.array([1]), 3)
+ def test_force_limitted(self):
+ """Tests that forcerange limits work on actuators."""
+ sys = test_utils.load_fixture('single_pendulum_position_frclimit.xml')
+ mj_model = test_utils.load_fixture_mujoco(
+ 'single_pendulum_position_frclimit.xml'
+ )
+ mj_data = mujoco.MjData(mj_model)
+ q, qd = jp.array(mj_data.qpos), jp.array(mj_data.qvel)
+
+ for act, frclimit in [(1000, 3.1), (-1000, -2.5)]:
+ act = jp.array([act])
+ tau = actuator.to_tau(sys, act, q, qd)
+ # test that tau matches frclimit * 10, since gear=10
+ self.assertEqual(tau[0], frclimit * 10)
+ # test that tau matches MJ qfrc_actuator
+ mj_data.ctrl = act
+ mujoco.mj_step(mj_model, mj_data)
+ self.assertEqual(tau[0], mj_data.qfrc_actuator)
+
@parameterized.parameters((g_pipeline,), (s_pipeline,), (p_pipeline,))
def test_three_link_pendulum(self, pipeline):
"""Tests a three link pendulum with a motor actuator."""
@@ -148,14 +167,17 @@ def test_spherical_pendulum_mj_generalized(self, config):
np.testing.assert_array_almost_equal(gq, mq, 3)
np.testing.assert_array_almost_equal(gqd, mqd, 3)
- # TODO: test spherical pendulum once it's implemented in positional
@parameterized.parameters(
- ('single_pendulum_position.xml',), 'single_pendulum_motor.xml'
+ 'single_pendulum_position.xml',
+ 'single_pendulum_motor.xml',
+ 'single_spherical_pendulum_position.xml',
)
def test_single_pendulum_spring_positional(self, config):
sys = test_utils.load_fixture(config)
act = jp.array([0.05, 0.1, 0.15])[: sys.act_size()]
+
q, qd = sys.init_q, jp.zeros(sys.qd_size())
+
sq, sqd = _actuator_step(s_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500)
pq, pqd = _actuator_step(p_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500)
np.testing.assert_array_almost_equal(sq, pq, 2)
diff --git a/brax/base.py b/brax/base.py
index f3c78c02..43269693 100644
--- a/brax/base.py
+++ b/brax/base.py
@@ -453,7 +453,9 @@ class Actuator(Base):
q_id: (num_actuators,) q index associated with an actuator
qd_id: (num_actuators,) qd index associated with an actuator
ctrl_range: (num_actuators, 2) actuator control range
- gear: (num_actuators,) scaling factor for each actuator torque output
+ force_range: (num_actuators, 2) actuator force range
+ gain: (num_actuators,) scaling factor for each actuator control input
+ gear: (num_actuators,) scaling factor for each actuator force output
bias_q: (num_actuators,) bias applied by q (e.g. position actuators)
bias_qd: (num_actuators,) bias applied by qd (e.g. velocity actuators)
"""
@@ -461,6 +463,8 @@ class Actuator(Base):
q_id: jp.ndarray
qd_id: jp.ndarray
ctrl_range: jp.ndarray
+ force_range: jp.ndarray
+ gain: jp.ndarray
gear: jp.ndarray
bias_q: jp.ndarray
bias_qd: jp.ndarray
@@ -508,6 +512,8 @@ class System:
joint_scale_ang: scale for position-based joint rotation update
joint_scale_pos: scale for position-based joint position update
collide_scale: fraction of position based collide update to apply
+ enable_fluid: (1,) enables or disables fluid forces based on the
+ default viscosity and density parameters provided in the XML
geom_masks: 64-bit mask determines whether two geoms will be contact tested.
lower 32 bits are type, upper 32 bits are affinity. two geoms
a, b will be contact tested if a.type & b.affinity != 0
@@ -545,6 +551,7 @@ class System:
joint_scale_pos: jp.float32
collide_scale: jp.float32
# non-pytree nodes
+ enable_fluid: bool = struct.field(pytree_node=False)
geom_masks: List[int] = struct.field(pytree_node=False)
link_names: List[str] = struct.field(pytree_node=False)
link_types: str = struct.field(pytree_node=False)
diff --git a/brax/fluid.py b/brax/fluid.py
index b05e4e88..88d76250 100644
--- a/brax/fluid.py
+++ b/brax/fluid.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# pylint:disable=g-multiple-import
+# pylint:disable=g-multiple-import, g-importing-member
"""Functions for forces/torques through fluids."""
-from typing import Tuple, Union
+from typing import Union
from brax.base import Force, Motion, System, Transform
import jax
import jax.numpy as jp
@@ -58,13 +58,13 @@ def force(
xd: Motion,
mass: jp.ndarray,
inertia: jp.ndarray,
- subtree_com: Union[jp.ndarray, None] = None,
-) -> Tuple[Force, jp.ndarray]:
+ root_com: Union[jp.ndarray, None] = None,
+) -> Force:
"""Returns force due to motion through a fluid."""
# get the velocity at the com position/orientation
x_i = x.vmap().do(sys.link.inertia.transform)
- # TODO: remove subtree_com when xd is fixed for stacked joints
- offset = x_i.pos - x.pos if subtree_com is None else x_i.pos - subtree_com
+ # TODO: remove root_com when xd is fixed for stacked joints
+ offset = x_i.pos - x.pos if root_com is None else x_i.pos - root_com
xd_i = x_i.replace(pos=offset).vmap().do(xd)
# TODO: add ellipsoid fluid model from mujoco
@@ -80,4 +80,5 @@ def force(
# rotate back to the world orientation
frc = Transform.create(rot=x_i.rot).vmap().do(frc)
- return frc, x_i.pos
+
+ return frc
diff --git a/brax/fluid_test.py b/brax/fluid_test.py
index 5a9e54df..92b0a72f 100644
--- a/brax/fluid_test.py
+++ b/brax/fluid_test.py
@@ -20,6 +20,8 @@
from brax import test_utils
from brax.generalized import dynamics
from brax.generalized import pipeline as g_pipeline
+from brax.positional import pipeline as p_pipeline
+from brax.spring import pipeline as s_pipeline
import jax
from jax import numpy as jp
import mujoco
@@ -77,6 +79,42 @@ def test_fluid_mj_generalized(self, config, density, viscosity):
np.testing.assert_array_almost_equal(gq, mq, 2)
np.testing.assert_array_almost_equal(gqd, mqd, 2)
+ @parameterized.parameters(
+ ('fluid_sphere.xml', p_pipeline, 0.0, 0.0),
+ ('fluid_sphere.xml', p_pipeline, 0.0, 1.3),
+ ('fluid_sphere.xml', p_pipeline, 2.0, 0.0),
+ ('fluid_sphere.xml', p_pipeline, 2.0, 1.3),
+ ('fluid_sphere.xml', s_pipeline, 0.0, 0.0),
+ ('fluid_sphere.xml', s_pipeline, 0.0, 1.3),
+ ('fluid_sphere.xml', s_pipeline, 2.0, 0.0),
+ ('fluid_sphere.xml', s_pipeline, 2.0, 1.3),
+ ('fluid_two_spheres.xml', p_pipeline, 2.0, 1.3),
+ ('fluid_two_spheres.xml', s_pipeline, 2.0, 1.3),
+ )
+ def test_fluid_positional_spring(self, config, pipeline, density, viscosity):
+ """Tests fluid interactions for pbd/spring compared to generalized."""
+ sys = test_utils.load_fixture(config)
+ sys = sys.replace(density=density, viscosity=viscosity)
+
+ q, qd = sys.init_q, jp.zeros(sys.qd_size())
+ qd = qd.at[:3].set(jp.array([1, 2, 4]))
+ qd = qd.at[3].set(jp.pi)
+
+ state_g = jax.jit(g_pipeline.init)(sys, q, qd)
+ for _ in range(500):
+ state_g = jax.jit(g_pipeline.step)(
+ sys, state_g, jp.zeros((sys.act_size(),))
+ )
+ gq, gqd = state_g.q, state_g.qd
+
+ state = jax.jit(pipeline.init)(sys, q, qd)
+ for _ in range(500):
+ state = jax.jit(pipeline.step)(sys, state, jp.zeros((sys.act_size(),)))
+ tq, tqd = state.q, state.qd
+
+ np.testing.assert_array_almost_equal(tq, gq, 3)
+ np.testing.assert_array_almost_equal(tqd, gqd, 2)
+
if __name__ == '__main__':
absltest.main()
diff --git a/brax/generalized/base.py b/brax/generalized/base.py
index cce99fc8..83eb4a53 100644
--- a/brax/generalized/base.py
+++ b/brax/generalized/base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# pylint:disable=g-multiple-import
+# pylint:disable=g-multiple-import, g-importing-member
"""Base types for generalized pipeline."""
from brax import base
@@ -26,24 +26,24 @@ class State(base.State):
"""Dynamic state that changes after every step.
Attributes:
- com: center of mass position of the kinematic tree containing the link
- cinr: inertia in com frame
- cd: body velocities in com frame
- cdof: dofs in com frame
- cdofd: cdof velocity
+ root_com: (num_links,) center of mass position of link root kinematic tree
+ cinr: (num_links,) inertia in com frame
+ cd: (num_links,) link velocities in com frame
+ cdof: (qd_size,) dofs in com frame
+ cdofd: (qd_size,) cdof velocity
mass_mx: (qd_size, qd_size) mass matrix
mass_mx_inv: (qd_size, qd_size) inverse mass matrix
contact: calculated contacts
con_jac: constraint jacobian
con_diag: constraint A diagonal
con_aref: constraint reference acceleration
- qf_smooth: smooth dynamics force
+ qf_smooth: (qd_size,) smooth dynamics force
qf_constraint: (qd_size,) force from constraints (collision etc)
qdd: (qd_size,) joint acceleration vector
"""
# position/velocity based terms are updated at the end of each step:
- com: jp.ndarray
+ root_com: jp.ndarray
cinr: Inertia
cd: Motion
cdof: Motion
@@ -71,7 +71,7 @@ def init(
x=x,
xd=xd,
contact=None,
- com=jp.zeros(3),
+ root_com=jp.zeros(3),
cinr=Inertia(
Transform.zero((num_links,)),
jp.zeros((num_links, 3, 3)),
diff --git a/brax/generalized/constraint.py b/brax/generalized/constraint.py
index be59d59a..50794c63 100644
--- a/brax/generalized/constraint.py
+++ b/brax/generalized/constraint.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# pylint:disable=g-multiple-import
+# pylint:disable=g-multiple-import, g-importing-member
"""Functions for constraint satisfaction."""
from typing import Tuple
@@ -57,9 +57,8 @@ def _imp_aref(
# See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters
stiffness, damping = params[:2]
- use_v2_params = (stiffness < 0) & (damping < 0)
- b = jp.where(use_v2_params, -damping / dmax, b)
- k = jp.where(use_v2_params, -stiffness / (dmax * dmax), k)
+ b = jp.where(damping <= 0, -damping / dmax, b)
+ k = jp.where(stiffness <= 0, -stiffness / (dmax * dmax), k)
aref = -b * vel - k * imp * pos
@@ -153,9 +152,9 @@ def jac_contact(
def row_fn(contact):
link_a, link_b = contact.link_idx
- a = point_jacobian(sys, state.com, state.cdof, contact.pos, link_a).vel
- b = point_jacobian(sys, state.com, state.cdof, contact.pos, link_b).vel
- diff = b - a
+ a = point_jacobian(sys, state.root_com, state.cdof, contact.pos, link_a)
+ b = point_jacobian(sys, state.root_com, state.cdof, contact.pos, link_b)
+ diff = b.vel - a.vel
# 4 pyramidal friction directions
jac = []
diff --git a/brax/generalized/dynamics.py b/brax/generalized/dynamics.py
index 781f58af..12c33f8e 100644
--- a/brax/generalized/dynamics.py
+++ b/brax/generalized/dynamics.py
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# pylint:disable=g-multiple-import
+# pylint:disable=g-multiple-import, g-importing-member
"""Functions for smooth forward and inverse dynamics."""
-import functools
-
from brax import fluid
from brax import math
from brax import scan
@@ -42,8 +40,8 @@ def transform_com(sys: System, state: State) -> State:
mass_xi = jax.vmap(jp.multiply)(sys.link.inertia.mass, x_i.pos)
mass_xi_sum = jax.ops.segment_sum(mass_xi, root, sys.num_links())
mass_sum = jax.ops.segment_sum(sys.link.inertia.mass, root, sys.num_links())
- com = jax.vmap(jp.divide)(mass_xi_sum[root], mass_sum[root])
- cinr = x_i.replace(pos=x_i.pos - com).vmap().do(sys.link.inertia)
+ root_com = jax.vmap(jp.divide)(mass_xi_sum[root], mass_sum[root])
+ cinr = x_i.replace(pos=x_i.pos - root_com).vmap().do(sys.link.inertia)
# motion dofs to global frame centered at subtree-CoM
parent_idx = jp.array(
@@ -89,7 +87,8 @@ def cdof_fn(typ, q, motion):
cdof = scan.link_types(sys, cdof_fn, 'qd', 'd', state.q, sys.dof.motion)
ang = jax.vmap(math.rotate)(cdof.ang, j.take(sys.dof_link()).rot)
cdof = cdof.replace(ang=ang)
- cdof = Transform.create(pos=com - j.pos).take(sys.dof_link()).vmap().do(cdof)
+ off = Transform.create(pos=root_com - j.pos)
+ cdof = off.take(sys.dof_link()).vmap().do(cdof)
cdof_qd = jax.vmap(lambda x, y: x * y)(cdof, state.qd)
# forward scan down tree: accumulate link center of mass velocity
@@ -132,7 +131,9 @@ def cdofd_fn(typ, cd, cdof, cdof_qd):
cd_p = cd.concatenate(Motion.zero(shape=(1,))).take(parent_idx)
cdofd = scan.link_types(sys, cdofd_fn, 'ldd', 'd', cd_p, cdof, cdof_qd)
- return state.replace(com=com, cinr=cinr, cd=cd, cdof=cdof, cdofd=cdofd)
+ return state.replace(
+ root_com=root_com, cinr=cinr, cd=cd, cdof=cdof, cdofd=cdofd
+ )
def inverse(sys: System, state: State) -> jp.ndarray:
@@ -194,17 +195,20 @@ def stiffness_fn(typ, q, dof):
frc = scan.link_types(sys, stiffness_fn, 'qd', 'd', state.q, sys.dof)
frc -= sys.dof.damping * state.qd
- fluid_frc, pos = fluid.force(
- sys,
- state.x,
- state.cd,
- sys.link.inertia.mass,
- sys.link.inertia.i,
- state.com,
- )
- cdof_fn = functools.partial(point_jacobian, sys, state.com, state.cdof)
- jac = jax.vmap(cdof_fn)(pos, jp.arange(sys.num_links()))
- frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0)
+ if sys.enable_fluid:
+ fluid_frc = fluid.force(
+ sys,
+ state.x,
+ state.cd,
+ sys.link.inertia.mass,
+ sys.link.inertia.i,
+ state.root_com,
+ )
+ link_idx = jp.arange(sys.num_links())
+ x_i = state.x.vmap().do(sys.link.inertia.transform)
+ jac_fn = jax.vmap(point_jacobian, in_axes=(None, None, None, 0, 0))
+ jac = jac_fn(sys, state.root_com, state.cdof, x_i.pos, link_idx)
+ frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0)
return frc
diff --git a/brax/generalized/dynamics_test.py b/brax/generalized/dynamics_test.py
index fb0487cc..d6956f6b 100644
--- a/brax/generalized/dynamics_test.py
+++ b/brax/generalized/dynamics_test.py
@@ -35,7 +35,9 @@ def test_transform_com(self, xml_file):
for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)
- np.testing.assert_almost_equal(state.com[0], mj_next.subtree_com[0], 5)
+ np.testing.assert_almost_equal(
+ state.root_com[0], mj_next.subtree_com[0], 5
+ )
mj_cinr_i = np.zeros((state.cinr.i.shape[0], 3, 3))
mj_cinr_i[:, [0, 1, 2], [0, 1, 2]] = mj_next.cinert[1:, 0:3] # diagonal
mj_cinr_i[:, [0, 0, 1], [1, 2, 2]] = mj_next.cinert[1:, 3:6] # upper tri
diff --git a/brax/io/mjcf.py b/brax/io/mjcf.py
index c1f086c4..2eb378dd 100644
--- a/brax/io/mjcf.py
+++ b/brax/io/mjcf.py
@@ -237,6 +237,8 @@ def load_model(mj: mujoco.MjModel) -> System:
# do some validation up front
if any(i not in [0, 1] for i in mj.actuator_biastype):
raise NotImplementedError('Only actuator_biastype in [0, 1] are supported.')
+ if any(i != 0 for i in mj.actuator_gaintype):
+ raise NotImplementedError('Only actuator_gaintype in [0] is supported.')
if mj.opt.integrator != 0:
raise NotImplementedError('Only euler integration is supported.')
if mj.opt.cone != 0:
@@ -406,17 +408,21 @@ def load_model(mj: mujoco.MjModel) -> System:
# create actuators
ctrl_range = mj.actuator_ctrlrange
ctrl_range[~(mj.actuator_ctrllimited == 1), :] = np.array([-np.inf, np.inf])
+ force_range = mj.actuator_forcerange
+ force_range[~(mj.actuator_forcelimited == 1), :] = np.array([-np.inf, np.inf])
q_id = np.array([mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]])
qd_id = np.array([mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]])
- bias_q = mj.actuator_biasprm[:, 1]
- bias_qd = mj.actuator_biasprm[:, 2]
+ bias_q = mj.actuator_biasprm[:, 1] * (mj.actuator_biastype != 0)
+ bias_qd = mj.actuator_biasprm[:, 2] * (mj.actuator_biastype != 0)
# TODO: might be nice to add actuator names for debugging
actuator = Actuator( # pytype: disable=wrong-arg-types
q_id=q_id,
qd_id=qd_id,
+ gain=mj.actuator_gainprm[:, 0],
gear=mj.actuator_gear[:, 0],
ctrl_range=ctrl_range,
+ force_range=force_range,
bias_q=bias_q,
bias_qd=bias_qd,
)
@@ -465,6 +471,7 @@ def load_model(mj: mujoco.MjModel) -> System:
joint_scale_ang=custom['joint_scale_ang'],
joint_scale_pos=custom['joint_scale_pos'],
collide_scale=custom['collide_scale'],
+ enable_fluid=(mj.opt.viscosity > 0) | (mj.opt.density > 0),
geom_masks=geom_masks,
link_names=link_names,
link_types=link_types,
diff --git a/brax/positional/base.py b/brax/positional/base.py
index 6e237d7e..956503be 100644
--- a/brax/positional/base.py
+++ b/brax/positional/base.py
@@ -18,6 +18,7 @@
from brax import base
from brax.base import Motion, Transform
from flax import struct
+import jax.numpy as jp
@struct.dataclass
@@ -31,6 +32,7 @@ class State(base.State):
jd: link motion in joint frame
a_p: joint parent anchor in world frame
a_c: joint child anchor in world frame
+ mass: link mass
"""
x_i: Transform
@@ -39,3 +41,4 @@ class State(base.State):
jd: Motion
a_p: Transform
a_c: Transform
+ mass: jp.ndarray
diff --git a/brax/positional/joints.py b/brax/positional/joints.py
index cf7a490f..f3146e41 100644
--- a/brax/positional/joints.py
+++ b/brax/positional/joints.py
@@ -27,7 +27,7 @@
from jax.ops import segment_sum
-def acceleration_update(sys: System, state: State, tau: jp.ndarray) -> Motion:
+def acceleration_update(sys: System, state: State, tau: jp.ndarray) -> Force:
"""Calculates forces to apply to links resulting from joint constraints.
Args:
@@ -36,7 +36,7 @@ def acceleration_update(sys: System, state: State, tau: jp.ndarray) -> Motion:
tau: joint force vector
Returns:
- xdd_i: acceleration to apply to link center of mass in world frame
+ xf_i: force to apply to link center of mass in world frame
"""
def _free_joint(*_) -> Force:
@@ -77,16 +77,7 @@ def j_fn(typ, link, jd, dof, tau):
fp = Transform.create(pos=state.a_p.pos - x_i_parent.pos).vmap().do(xf)
fp = jax.tree_map(lambda x: segment_sum(x, parent_idx, sys.num_links()), fp)
xf_i = fc - fp
-
- # convert to acceleration
- inv_mass = 1 / (sys.link.inertia.mass ** (1 - sys.spring_mass_scale))
- inv_inertia = com.inv_inertia(sys, state.x)
- xdd_i = Motion(
- ang=jax.vmap(lambda x, y: x @ y)(inv_inertia, xf_i.ang),
- vel=jax.vmap(lambda x, y: x * y)(inv_mass, xf_i.vel),
- )
-
- return xdd_i
+ return xf_i
def position_update(sys: System, state: State) -> Transform:
diff --git a/brax/positional/pipeline.py b/brax/positional/pipeline.py
index eb2fcdce..729c00ec 100644
--- a/brax/positional/pipeline.py
+++ b/brax/positional/pipeline.py
@@ -16,6 +16,7 @@
# pylint:disable=g-multiple-import
from brax import actuator
from brax import com
+from brax import fluid
from brax import geometry
from brax import kinematics
from brax.base import Motion, System
@@ -23,6 +24,7 @@
from brax.positional import integrator
from brax.positional import joints
from brax.positional.base import State
+import jax
from jax import numpy as jp
@@ -45,8 +47,8 @@ def init(
j, jd, a_p, a_c = kinematics.world_to_joint(sys, x, xd)
x_i, xd_i = com.from_world(sys, x, xd)
contact = geometry.contact(sys, x) if debug else None
-
- return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c)
+ mass = sys.link.inertia.mass ** (1 - sys.spring_mass_scale)
+ return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c, mass)
def step(
@@ -72,7 +74,15 @@ def step(
# calculate acceleration level updates
tau = actuator.to_tau(sys, act, state.q, state.qd)
xdd_i = Motion.create(vel=sys.gravity)
- xdd_i += joints.acceleration_update(sys, state, tau)
+ # get joint constraint forces
+ xf_i = joints.acceleration_update(sys, state, tau)
+ if sys.enable_fluid:
+ inertia = sys.link.inertia.i ** (1 - sys.spring_inertia_scale)
+ xf_i += fluid.force(sys, state.x, state.xd, state.mass, inertia)
+ xdd_i += Motion(
+ ang=jax.vmap(lambda x, y: x @ y)(com.inv_inertia(sys, state.x), xf_i.ang),
+ vel=jax.vmap(lambda x, y: x * y)(1 / state.mass, xf_i.vel),
+ )
# semi-implicit euler: apply acceleration update before resolving collisions
x_i, xd_i = integrator.integrate_xdd(sys, state.x_i, state.xd_i, xdd_i)
@@ -102,5 +112,4 @@ def step(
q, qd = kinematics.inverse(sys, j, jd)
contact = geometry.contact(sys, x) if debug else None
- return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c)
-
+ return State(q, qd, x, xd, contact, x_i, xd_i, j, jd, a_p, a_c, state.mass)
diff --git a/brax/positional/pipeline_test.py b/brax/positional/pipeline_test.py
index 82ed5d6c..50b66d51 100644
--- a/brax/positional/pipeline_test.py
+++ b/brax/positional/pipeline_test.py
@@ -72,7 +72,8 @@ def test_spherical_pendulum(self):
# from generalized and plug it back into pbd
# TODO: remove this xd override once kinematics.forward is fixed
state_g = g_pipeline.init(sys, init_q, init_qd)
- xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd)
+ off = state_g.x.pos - state_g.root_com
+ xd = Transform.create(pos=off).vmap().do(state_g.cd)
state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1])
j_pos_step = jax.jit(pipeline.step)
diff --git a/brax/spring/joints.py b/brax/spring/joints.py
index 75e88273..423d3191 100644
--- a/brax/spring/joints.py
+++ b/brax/spring/joints.py
@@ -309,7 +309,7 @@ def _three_dof(
return Force(ang=ang, vel=vel)
-def resolve(sys: System, state: State, tau: jp.ndarray) -> Motion:
+def resolve(sys: System, state: State, tau: jp.ndarray) -> Force:
"""Calculates forces to apply to links resulting from joint constraints.
Args:
@@ -318,7 +318,7 @@ def resolve(sys: System, state: State, tau: jp.ndarray) -> Motion:
tau: joint force vector
Returns:
- xdd_i: acceleration to apply to link center of mass in world frame
+ xf_i: force to apply to link center of mass in world frame
"""
def j_fn(typ, link, j, jd, dof, tau):
@@ -346,11 +346,4 @@ def j_fn(typ, link, j, jd, dof, tau):
fp = Transform.create(pos=state.a_p.pos - x_i_parent.pos).vmap().do(xf)
fp = jax.tree_map(lambda x: segment_sum(x, parent_idx, sys.num_links()), fp)
xf_i = fc - fp
-
- # convert to acceleration
- xdd_i = Motion(
- ang=jax.vmap(lambda x, y: x @ y)(state.i_inv, xf_i.ang),
- vel=jax.vmap(lambda x, y: x / y)(xf_i.vel, state.mass),
- )
-
- return xdd_i
+ return xf_i
diff --git a/brax/spring/pipeline.py b/brax/spring/pipeline.py
index 748f76ca..b96ff775 100644
--- a/brax/spring/pipeline.py
+++ b/brax/spring/pipeline.py
@@ -17,6 +17,7 @@
from brax import actuator
from brax import com
+from brax import fluid
from brax import geometry
from brax import kinematics
from brax.base import Motion, System
@@ -24,6 +25,7 @@
from brax.spring import integrator
from brax.spring import joints
from brax.spring.base import State
+import jax
from jax import numpy as jp
@@ -88,7 +90,16 @@ def step(
# calculate acceleration and delta-velocity terms
tau = actuator.to_tau(sys, act, state.q, state.qd)
- xdd_i = joints.resolve(sys, state, tau) + Motion.create(vel=sys.gravity)
+ xdd_i = Motion.create(vel=sys.gravity)
+ xf_i = joints.resolve(sys, state, tau)
+ if sys.enable_fluid:
+ inertia = sys.link.inertia.i ** (1 - sys.spring_inertia_scale)
+ xf_i += fluid.force(sys, state.x, state.xd, state.mass, inertia)
+ xdd_i += Motion(
+ ang=jax.vmap(lambda x, y: x @ y)(state.i_inv, xf_i.ang),
+ vel=jax.vmap(lambda x, y: x / y)(xf_i.vel, state.mass),
+ )
+
# semi-implicit euler: apply acceleration update before resolving collisions
state = state.replace(xd_i=state.xd_i + xdd_i * sys.dt)
xdv_i = collisions.resolve(sys, state)
diff --git a/brax/spring/pipeline_test.py b/brax/spring/pipeline_test.py
index 698107eb..72786d1f 100644
--- a/brax/spring/pipeline_test.py
+++ b/brax/spring/pipeline_test.py
@@ -108,7 +108,8 @@ def test_spherical_pendulum(self):
# from generalized and plug it back into pbd
# TODO: remove this xd override once kinematics.forward is fixed
state_g = g_pipeline.init(sys, init_q, init_qd)
- xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd)
+ off = state_g.x.pos - state_g.root_com
+ xd = Transform.create(pos=off).vmap().do(state_g.cd)
state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1])
j_spring_step = jax.jit(pipeline.step)
diff --git a/brax/test_data/fluid_sphere.xml b/brax/test_data/fluid_sphere.xml
new file mode 100644
index 00000000..13d461e3
--- /dev/null
+++ b/brax/test_data/fluid_sphere.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/fluid_two_spheres.xml b/brax/test_data/fluid_two_spheres.xml
new file mode 100644
index 00000000..4e67a745
--- /dev/null
+++ b/brax/test_data/fluid_two_spheres.xml
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/single_pendulum_position.xml b/brax/test_data/single_pendulum_position.xml
index e469969c..533266ae 100644
--- a/brax/test_data/single_pendulum_position.xml
+++ b/brax/test_data/single_pendulum_position.xml
@@ -15,6 +15,6 @@