Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 537414728
Change-Id: I8f620bc9a9a5b424b5906357f8e1c1c5d1d4c6f7
  • Loading branch information
Brax Team authored and erikfrey committed Jun 2, 2023
1 parent aebd8b8 commit 3b72c95
Show file tree
Hide file tree
Showing 31 changed files with 261 additions and 187 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ If you would like to reference Brax in a publication, please use:
author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
url = {http://github.com/google/brax},
version = {0.9.0},
version = {0.9.1},
year = {2021},
}
```
Expand Down
2 changes: 1 addition & 1 deletion brax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Import top-level classes and functions here for encapsulation/clarity."""

__version__ = '0.9.0'
__version__ = '0.9.1'

from brax.base import Motion
from brax.base import State
Expand Down
19 changes: 14 additions & 5 deletions brax/actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,21 @@ def to_tau(
if sys.act_size() == 0:
return jp.zeros(sys.qd_size())

q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id]
ctrl_range = sys.actuator.ctrl_range
force_range = sys.actuator.force_range

q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id]
act = jp.clip(act, ctrl_range[:, 0], ctrl_range[:, 1])
# TODO: incorporate gain
act = act + q * sys.actuator.bias_q + qd * sys.actuator.bias_qd
act_force = sys.actuator.gear * act
tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(act_force)
# See https://github.com/deepmind/mujoco/discussions/754 for why gear is
# used for the bias term.
bias = sys.actuator.gear * (
q * sys.actuator.bias_q + qd * sys.actuator.bias_qd
)

force = sys.actuator.gain * act + bias
force = jp.clip(force, force_range[:, 0], force_range[:, 1])

force *= sys.actuator.gear
tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(force)

return tau
26 changes: 24 additions & 2 deletions brax/actuator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ def test_velocity(self):
_, qd = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.001, n=200)
np.testing.assert_array_almost_equal(qd, jp.array([1]), 3)

def test_force_limitted(self):
"""Tests that forcerange limits work on actuators."""
sys = test_utils.load_fixture('single_pendulum_position_frclimit.xml')
mj_model = test_utils.load_fixture_mujoco(
'single_pendulum_position_frclimit.xml'
)
mj_data = mujoco.MjData(mj_model)
q, qd = jp.array(mj_data.qpos), jp.array(mj_data.qvel)

for act, frclimit in [(1000, 3.1), (-1000, -2.5)]:
act = jp.array([act])
tau = actuator.to_tau(sys, act, q, qd)
# test that tau matches frclimit * 10, since gear=10
self.assertEqual(tau[0], frclimit * 10)
# test that tau matches MJ qfrc_actuator
mj_data.ctrl = act
mujoco.mj_step(mj_model, mj_data)
self.assertEqual(tau[0], mj_data.qfrc_actuator)

@parameterized.parameters((g_pipeline,), (s_pipeline,), (p_pipeline,))
def test_three_link_pendulum(self, pipeline):
"""Tests a three link pendulum with a motor actuator."""
Expand Down Expand Up @@ -148,14 +167,17 @@ def test_spherical_pendulum_mj_generalized(self, config):
np.testing.assert_array_almost_equal(gq, mq, 3)
np.testing.assert_array_almost_equal(gqd, mqd, 3)

# TODO: test spherical pendulum once it's implemented in positional
@parameterized.parameters(
('single_pendulum_position.xml',), 'single_pendulum_motor.xml'
'single_pendulum_position.xml',
'single_pendulum_motor.xml',
'single_spherical_pendulum_position.xml',
)
def test_single_pendulum_spring_positional(self, config):
sys = test_utils.load_fixture(config)
act = jp.array([0.05, 0.1, 0.15])[: sys.act_size()]

q, qd = sys.init_q, jp.zeros(sys.qd_size())

sq, sqd = _actuator_step(s_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500)
pq, pqd = _actuator_step(p_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500)
np.testing.assert_array_almost_equal(sq, pq, 2)
Expand Down
9 changes: 8 additions & 1 deletion brax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,14 +453,18 @@ class Actuator(Base):
q_id: (num_actuators,) q index associated with an actuator
qd_id: (num_actuators,) qd index associated with an actuator
ctrl_range: (num_actuators, 2) actuator control range
gear: (num_actuators,) scaling factor for each actuator torque output
force_range: (num_actuators, 2) actuator force range
gain: (num_actuators,) scaling factor for each actuator control input
gear: (num_actuators,) scaling factor for each actuator force output
bias_q: (num_actuators,) bias applied by q (e.g. position actuators)
bias_qd: (num_actuators,) bias applied by qd (e.g. velocity actuators)
"""

q_id: jp.ndarray
qd_id: jp.ndarray
ctrl_range: jp.ndarray
force_range: jp.ndarray
gain: jp.ndarray
gear: jp.ndarray
bias_q: jp.ndarray
bias_qd: jp.ndarray
Expand Down Expand Up @@ -508,6 +512,8 @@ class System:
joint_scale_ang: scale for position-based joint rotation update
joint_scale_pos: scale for position-based joint position update
collide_scale: fraction of position based collide update to apply
enable_fluid: (1,) enables or disables fluid forces based on the
default viscosity and density parameters provided in the XML
geom_masks: 64-bit mask determines whether two geoms will be contact tested.
lower 32 bits are type, upper 32 bits are affinity. two geoms
a, b will be contact tested if a.type & b.affinity != 0
Expand Down Expand Up @@ -545,6 +551,7 @@ class System:
joint_scale_pos: jp.float32
collide_scale: jp.float32
# non-pytree nodes
enable_fluid: bool = struct.field(pytree_node=False)
geom_masks: List[int] = struct.field(pytree_node=False)
link_names: List[str] = struct.field(pytree_node=False)
link_types: str = struct.field(pytree_node=False)
Expand Down
15 changes: 8 additions & 7 deletions brax/fluid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
# pylint:disable=g-multiple-import, g-importing-member
"""Functions for forces/torques through fluids."""

from typing import Tuple, Union
from typing import Union
from brax.base import Force, Motion, System, Transform
import jax
import jax.numpy as jp
Expand Down Expand Up @@ -58,13 +58,13 @@ def force(
xd: Motion,
mass: jp.ndarray,
inertia: jp.ndarray,
subtree_com: Union[jp.ndarray, None] = None,
) -> Tuple[Force, jp.ndarray]:
root_com: Union[jp.ndarray, None] = None,
) -> Force:
"""Returns force due to motion through a fluid."""
# get the velocity at the com position/orientation
x_i = x.vmap().do(sys.link.inertia.transform)
# TODO: remove subtree_com when xd is fixed for stacked joints
offset = x_i.pos - x.pos if subtree_com is None else x_i.pos - subtree_com
# TODO: remove root_com when xd is fixed for stacked joints
offset = x_i.pos - x.pos if root_com is None else x_i.pos - root_com
xd_i = x_i.replace(pos=offset).vmap().do(xd)

# TODO: add ellipsoid fluid model from mujoco
Expand All @@ -80,4 +80,5 @@ def force(

# rotate back to the world orientation
frc = Transform.create(rot=x_i.rot).vmap().do(frc)
return frc, x_i.pos

return frc
38 changes: 38 additions & 0 deletions brax/fluid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from brax import test_utils
from brax.generalized import dynamics
from brax.generalized import pipeline as g_pipeline
from brax.positional import pipeline as p_pipeline
from brax.spring import pipeline as s_pipeline
import jax
from jax import numpy as jp
import mujoco
Expand Down Expand Up @@ -77,6 +79,42 @@ def test_fluid_mj_generalized(self, config, density, viscosity):
np.testing.assert_array_almost_equal(gq, mq, 2)
np.testing.assert_array_almost_equal(gqd, mqd, 2)

@parameterized.parameters(
('fluid_sphere.xml', p_pipeline, 0.0, 0.0),
('fluid_sphere.xml', p_pipeline, 0.0, 1.3),
('fluid_sphere.xml', p_pipeline, 2.0, 0.0),
('fluid_sphere.xml', p_pipeline, 2.0, 1.3),
('fluid_sphere.xml', s_pipeline, 0.0, 0.0),
('fluid_sphere.xml', s_pipeline, 0.0, 1.3),
('fluid_sphere.xml', s_pipeline, 2.0, 0.0),
('fluid_sphere.xml', s_pipeline, 2.0, 1.3),
('fluid_two_spheres.xml', p_pipeline, 2.0, 1.3),
('fluid_two_spheres.xml', s_pipeline, 2.0, 1.3),
)
def test_fluid_positional_spring(self, config, pipeline, density, viscosity):
"""Tests fluid interactions for pbd/spring compared to generalized."""
sys = test_utils.load_fixture(config)
sys = sys.replace(density=density, viscosity=viscosity)

q, qd = sys.init_q, jp.zeros(sys.qd_size())
qd = qd.at[:3].set(jp.array([1, 2, 4]))
qd = qd.at[3].set(jp.pi)

state_g = jax.jit(g_pipeline.init)(sys, q, qd)
for _ in range(500):
state_g = jax.jit(g_pipeline.step)(
sys, state_g, jp.zeros((sys.act_size(),))
)
gq, gqd = state_g.q, state_g.qd

state = jax.jit(pipeline.init)(sys, q, qd)
for _ in range(500):
state = jax.jit(pipeline.step)(sys, state, jp.zeros((sys.act_size(),)))
tq, tqd = state.q, state.qd

np.testing.assert_array_almost_equal(tq, gq, 3)
np.testing.assert_array_almost_equal(tqd, gqd, 2)


if __name__ == '__main__':
absltest.main()
18 changes: 9 additions & 9 deletions brax/generalized/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
# pylint:disable=g-multiple-import, g-importing-member
"""Base types for generalized pipeline."""

from brax import base
Expand All @@ -26,24 +26,24 @@ class State(base.State):
"""Dynamic state that changes after every step.
Attributes:
com: center of mass position of the kinematic tree containing the link
cinr: inertia in com frame
cd: body velocities in com frame
cdof: dofs in com frame
cdofd: cdof velocity
root_com: (num_links,) center of mass position of link root kinematic tree
cinr: (num_links,) inertia in com frame
cd: (num_links,) link velocities in com frame
cdof: (qd_size,) dofs in com frame
cdofd: (qd_size,) cdof velocity
mass_mx: (qd_size, qd_size) mass matrix
mass_mx_inv: (qd_size, qd_size) inverse mass matrix
contact: calculated contacts
con_jac: constraint jacobian
con_diag: constraint A diagonal
con_aref: constraint reference acceleration
qf_smooth: smooth dynamics force
qf_smooth: (qd_size,) smooth dynamics force
qf_constraint: (qd_size,) force from constraints (collision etc)
qdd: (qd_size,) joint acceleration vector
"""

# position/velocity based terms are updated at the end of each step:
com: jp.ndarray
root_com: jp.ndarray
cinr: Inertia
cd: Motion
cdof: Motion
Expand Down Expand Up @@ -71,7 +71,7 @@ def init(
x=x,
xd=xd,
contact=None,
com=jp.zeros(3),
root_com=jp.zeros(3),
cinr=Inertia(
Transform.zero((num_links,)),
jp.zeros((num_links, 3, 3)),
Expand Down
13 changes: 6 additions & 7 deletions brax/generalized/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
# pylint:disable=g-multiple-import, g-importing-member
"""Functions for constraint satisfaction."""
from typing import Tuple

Expand Down Expand Up @@ -57,9 +57,8 @@ def _imp_aref(

# See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters
stiffness, damping = params[:2]
use_v2_params = (stiffness < 0) & (damping < 0)
b = jp.where(use_v2_params, -damping / dmax, b)
k = jp.where(use_v2_params, -stiffness / (dmax * dmax), k)
b = jp.where(damping <= 0, -damping / dmax, b)
k = jp.where(stiffness <= 0, -stiffness / (dmax * dmax), k)

aref = -b * vel - k * imp * pos

Expand Down Expand Up @@ -153,9 +152,9 @@ def jac_contact(

def row_fn(contact):
link_a, link_b = contact.link_idx
a = point_jacobian(sys, state.com, state.cdof, contact.pos, link_a).vel
b = point_jacobian(sys, state.com, state.cdof, contact.pos, link_b).vel
diff = b - a
a = point_jacobian(sys, state.root_com, state.cdof, contact.pos, link_a)
b = point_jacobian(sys, state.root_com, state.cdof, contact.pos, link_b)
diff = b.vel - a.vel

# 4 pyramidal friction directions
jac = []
Expand Down
40 changes: 22 additions & 18 deletions brax/generalized/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
# pylint:disable=g-multiple-import, g-importing-member
"""Functions for smooth forward and inverse dynamics."""
import functools

from brax import fluid
from brax import math
from brax import scan
Expand All @@ -42,8 +40,8 @@ def transform_com(sys: System, state: State) -> State:
mass_xi = jax.vmap(jp.multiply)(sys.link.inertia.mass, x_i.pos)
mass_xi_sum = jax.ops.segment_sum(mass_xi, root, sys.num_links())
mass_sum = jax.ops.segment_sum(sys.link.inertia.mass, root, sys.num_links())
com = jax.vmap(jp.divide)(mass_xi_sum[root], mass_sum[root])
cinr = x_i.replace(pos=x_i.pos - com).vmap().do(sys.link.inertia)
root_com = jax.vmap(jp.divide)(mass_xi_sum[root], mass_sum[root])
cinr = x_i.replace(pos=x_i.pos - root_com).vmap().do(sys.link.inertia)

# motion dofs to global frame centered at subtree-CoM
parent_idx = jp.array(
Expand Down Expand Up @@ -89,7 +87,8 @@ def cdof_fn(typ, q, motion):
cdof = scan.link_types(sys, cdof_fn, 'qd', 'd', state.q, sys.dof.motion)
ang = jax.vmap(math.rotate)(cdof.ang, j.take(sys.dof_link()).rot)
cdof = cdof.replace(ang=ang)
cdof = Transform.create(pos=com - j.pos).take(sys.dof_link()).vmap().do(cdof)
off = Transform.create(pos=root_com - j.pos)
cdof = off.take(sys.dof_link()).vmap().do(cdof)
cdof_qd = jax.vmap(lambda x, y: x * y)(cdof, state.qd)

# forward scan down tree: accumulate link center of mass velocity
Expand Down Expand Up @@ -132,7 +131,9 @@ def cdofd_fn(typ, cd, cdof, cdof_qd):
cd_p = cd.concatenate(Motion.zero(shape=(1,))).take(parent_idx)
cdofd = scan.link_types(sys, cdofd_fn, 'ldd', 'd', cd_p, cdof, cdof_qd)

return state.replace(com=com, cinr=cinr, cd=cd, cdof=cdof, cdofd=cdofd)
return state.replace(
root_com=root_com, cinr=cinr, cd=cd, cdof=cdof, cdofd=cdofd
)


def inverse(sys: System, state: State) -> jp.ndarray:
Expand Down Expand Up @@ -194,17 +195,20 @@ def stiffness_fn(typ, q, dof):
frc = scan.link_types(sys, stiffness_fn, 'qd', 'd', state.q, sys.dof)
frc -= sys.dof.damping * state.qd

fluid_frc, pos = fluid.force(
sys,
state.x,
state.cd,
sys.link.inertia.mass,
sys.link.inertia.i,
state.com,
)
cdof_fn = functools.partial(point_jacobian, sys, state.com, state.cdof)
jac = jax.vmap(cdof_fn)(pos, jp.arange(sys.num_links()))
frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0)
if sys.enable_fluid:
fluid_frc = fluid.force(
sys,
state.x,
state.cd,
sys.link.inertia.mass,
sys.link.inertia.i,
state.root_com,
)
link_idx = jp.arange(sys.num_links())
x_i = state.x.vmap().do(sys.link.inertia.transform)
jac_fn = jax.vmap(point_jacobian, in_axes=(None, None, None, 0, 0))
jac = jac_fn(sys, state.root_com, state.cdof, x_i.pos, link_idx)
frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0)

return frc

Expand Down
4 changes: 3 additions & 1 deletion brax/generalized/dynamics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def test_transform_com(self, xml_file):
for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)

np.testing.assert_almost_equal(state.com[0], mj_next.subtree_com[0], 5)
np.testing.assert_almost_equal(
state.root_com[0], mj_next.subtree_com[0], 5
)
mj_cinr_i = np.zeros((state.cinr.i.shape[0], 3, 3))
mj_cinr_i[:, [0, 1, 2], [0, 1, 2]] = mj_next.cinert[1:, 0:3] # diagonal
mj_cinr_i[:, [0, 0, 1], [1, 2, 2]] = mj_next.cinert[1:, 3:6] # upper tri
Expand Down
Loading

0 comments on commit 3b72c95

Please sign in to comment.