diff --git a/MANIFEST.in b/MANIFEST.in
index d82985f0..0e6fec3d 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,3 +1,3 @@
include brax/envs/assets/*.xml
recursive-include brax/test_data *.xml *.stl *.obj *.urdf
-recursive-include brax/visualizer *
+recursive-include brax/visualizer *
\ No newline at end of file
diff --git a/brax/actuator.py b/brax/actuator.py
index 192b99a6..39291c97 100644
--- a/brax/actuator.py
+++ b/brax/actuator.py
@@ -15,18 +15,20 @@
# pylint:disable=g-multiple-import
"""Functions for applying actuators to a physics pipeline."""
-from brax import scan
from brax.base import System
from jax import numpy as jp
-def to_tau(sys: System, act: jp.ndarray, q: jp.ndarray) -> jp.ndarray:
+def to_tau(
+ sys: System, act: jp.ndarray, q: jp.ndarray, qd: jp.ndarray
+) -> jp.ndarray:
"""Convert actuator to a joint force tau.
Args:
sys: system defining the kinematic tree and other properties
act: (act_size,) actuator force input vector
q: joint position vector
+ qd: joint velocity vector
Returns:
tau: (qd_size,) vector of joint forces
@@ -34,21 +36,12 @@ def to_tau(sys: System, act: jp.ndarray, q: jp.ndarray) -> jp.ndarray:
if sys.act_size() == 0:
return jp.zeros(sys.qd_size())
- def act_fn(act_type, act, actuator, q, qd_idx):
- if act_type not in ('p', 'm'):
- raise RuntimeError(f'unrecognized act type: {act_type}')
-
- force = jp.clip(act, actuator.ctrl_range[:, 0], actuator.ctrl_range[:, 1])
- if act_type == 'p':
- force -= q # positional actuators have a bias
- tau = actuator.gear * force
-
- return tau, qd_idx
-
- qd_idx = jp.arange(sys.qd_size())
- tau, qd_idx = scan.actuator_types(
- sys, act_fn, 'aaqd', 'a', act, sys.actuator, q, qd_idx
- )
- tau = jp.zeros(sys.qd_size()).at[qd_idx].add(tau)
+ q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id]
+ ctrl_range = sys.actuator.ctrl_range
+ act = jp.clip(act, ctrl_range[:, 0], ctrl_range[:, 1])
+ # TODO: incorporate gain
+ act = act + q * sys.actuator.bias_q + qd * sys.actuator.bias_qd
+ act_force = sys.actuator.gear * act
+ tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(act_force)
return tau
diff --git a/brax/actuator_test.py b/brax/actuator_test.py
index c2e86508..a5089512 100644
--- a/brax/actuator_test.py
+++ b/brax/actuator_test.py
@@ -55,7 +55,7 @@ def test_motor(self, pipeline, dt, n, decimal):
tau = jp.array([0.5 * 9.81]) # -mgl sin(theta)
act = jp.array([1.0 / 150.0 * 0.5 * 9.81])
- tau2 = actuator.to_tau(sys, act, q)
+ tau2 = actuator.to_tau(sys, act, q, qd)
np.testing.assert_array_almost_equal(tau, tau2, 5)
q2, qd2 = _actuator_step(pipeline, sys, q, qd, act=act, dt=dt, n=n)
@@ -73,7 +73,7 @@ def test_position(self):
# a position actuator at the bottom should not move the pendulum
act = jp.array([theta])
- tau = actuator.to_tau(sys, act, q)
+ tau = actuator.to_tau(sys, act, q, qd)
np.testing.assert_array_almost_equal(tau, jp.array([0]), 5)
# put the pendulum into the horizontal position with the positional actuator
@@ -83,6 +83,24 @@ def test_position(self):
q2, _ = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.01, n=1)
np.testing.assert_array_almost_equal(q2, jp.array([0]), 1)
+ def test_velocity(self):
+ """Tests a single pendulum with velocity actuator."""
+ sys = test_utils.load_fixture('single_pendulum_velocity.xml')
+ mj_model = test_utils.load_fixture_mujoco('single_pendulum_velocity.xml')
+ mj_data = mujoco.MjData(mj_model)
+ q, qd = jp.array(mj_data.qpos), jp.array(mj_data.qvel)  # q is replaced below
+ theta = jp.pi / 2.0 # pendulum is vertical
+ q = jp.array([theta])
+
+ act = jp.array([0])  # zero command: the resulting tau is expected to be zero
+ tau = actuator.to_tau(sys, act, q, qd)
+ np.testing.assert_array_almost_equal(tau, jp.array([0]), 5)
+
+ # with act=1, qd is expected to reach 1 after 200 steps at dt=0.001
+ act = jp.array([1])
+ _, qd = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.001, n=200)
+ np.testing.assert_array_almost_equal(qd, jp.array([1]), 3)
+
@parameterized.parameters((g_pipeline,), (s_pipeline,), (p_pipeline,))
def test_three_link_pendulum(self, pipeline):
"""Tests a three link pendulum with a motor actuator."""
diff --git a/brax/base.py b/brax/base.py
index b886ecae..f3c78c02 100644
--- a/brax/base.py
+++ b/brax/base.py
@@ -269,7 +269,7 @@ class Link(Base):
constraint_stiffness: jp.ndarray
constraint_vel_damping: jp.ndarray
constraint_limit_stiffness: jp.ndarray
- # only used by `brax.physics.spring` and `brax.physics.pbd`:
+ # only used by `brax.physics.spring` and `brax.physics.positional`:
constraint_ang_damping: jp.ndarray
@@ -284,6 +284,7 @@ class DoF(Base):
damping: restorative force back to zero velocity
limit: tuple of min, max angle limits
invweight: diagonal inverse inertia at init_qpos
+ solver_params: (7,) limit constraint solver parameters
"""
motion: Motion
@@ -293,6 +294,7 @@ class DoF(Base):
limit: Tuple[jp.ndarray, jp.ndarray]
# only used by `brax.physics.generalized`:
invweight: jp.ndarray
+ solver_params: jp.ndarray
@struct.dataclass
@@ -305,12 +307,14 @@ class Geometry(Base):
relative to the world frame in the case of unparented geometry
friction: resistance encountered when sliding against another geometry
elasticity: bounce/restitution encountered when hitting another geometry
+ solver_params: (7,) solver parameters (reference, impedance)
"""
link_idx: Optional[jp.ndarray]
transform: Transform
friction: jp.ndarray
elasticity: jp.ndarray
+ solver_params: jp.ndarray
@struct.dataclass
@@ -354,6 +358,21 @@ class Box(Geometry):
rgba: Optional[jp.ndarray] = None
+@struct.dataclass
+class Cylinder(Geometry):
+ """A cylinder geometry: adds a radius and a length to the Geometry fields.
+
+ Attributes:
+ radius: (1,) radius of the top and bottom of the cylinder
+ length: (1,) length of the cylinder
+ rgba: (4,) the rgba to display in the renderer; optional, defaults to None
+ """
+
+ radius: jp.ndarray
+ length: jp.ndarray
+ rgba: Optional[jp.ndarray] = None
+
+
@struct.dataclass
class Plane(Geometry):
"""An infinite plane whose normal points at +z in its coordinate space.
@@ -410,6 +429,7 @@ class Contact(Base):
two geometries are interpenetrating, negative means they are not
friction: resistance encountered when sliding against another geometry
elasticity: bounce/restitution encountered when hitting another geometry
+ solver_params: (7,) collision constraint solver parameters
link_idx: Tuple of link indices participating in contact. The second part
of the tuple can be None if the second geometry is static.
"""
@@ -418,8 +438,9 @@ class Contact(Base):
normal: jp.ndarray
penetration: jp.ndarray
friction: jp.ndarray
- # only used by `brax.physics.spring` and `brax.physics.pbd`:
+ # only used by `brax.physics.spring` and `brax.physics.positional`:
elasticity: jp.ndarray
+ solver_params: jp.ndarray
link_idx: Tuple[jp.ndarray, Optional[jp.ndarray]]
@@ -429,13 +450,20 @@ class Actuator(Base):
"""Actuator, transforms an input signal into a force (motor or thruster).
Attributes:
- ctrl_range: (num_actuators, 2) control range for each actuator
- gear: (num_actuators,) a list of floats used as a scaling factor for each
- actuator torque output
+ q_id: (num_actuators,) q index associated with an actuator
+ qd_id: (num_actuators,) qd index associated with an actuator
+ ctrl_range: (num_actuators, 2) actuator control range
+ gear: (num_actuators,) scaling factor for each actuator torque output
+ bias_q: (num_actuators,) bias applied by q (e.g. position actuators)
+ bias_qd: (num_actuators,) bias applied by qd (e.g. velocity actuators)
"""
+ q_id: jp.ndarray
+ qd_id: jp.ndarray
ctrl_range: jp.ndarray
gear: jp.ndarray
+ bias_q: jp.ndarray
+ bias_qd: jp.ndarray
@struct.dataclass
@@ -464,13 +492,13 @@ class System:
Attributes:
dt: timestep used for the simulation
gravity: (3,) linear universal force applied during forward dynamics
+ viscosity: (1,) viscosity of the medium applied to all links
+ density: (1,) density of the medium applied to all links
link: (num_link,) the links in the system
dof: (qd_size,) every degree of freedom for the system
geoms: list of batched geoms grouped by type
actuator: actuators that can be applied to links
init_q: (q_size,) initial q position for the system
- solver_params_joint: (7,) joint limit constraint solver parameters
- solver_params_contact: (7,) collision constraint solver parameters
vel_damping: (1,) linear vel damping applied to each body.
ang_damping: (1,) angular vel damping applied to each body.
baumgarte_erp: how aggressively interpenetrating bodies should push away\
@@ -492,12 +520,6 @@ class System:
* '3': spherical, 3 dof, like a ball joint
link_parents: (num_link,) int list specifying the index of each link's
parent link, or -1 if the link has no parent
- actuator_types: (num_actuators,) string specifying the actuator types:
- * 't': torque
- * 'p': position
- actuator_link_id: (num_actuators,) the link id associated with each actuator
- actuator_qid: (num_actuators,) the q index associated with each actuator
- actuator_qdid: (num_actuators,) the qd index associated with each actuator
matrix_inv_iterations: maximum number of iterations of the matrix inverse
solver_iterations: maximum number of iterations of the constraint solver
solver_maxls: maximum number of line searches of the constraint solver
@@ -505,33 +527,28 @@ class System:
dt: jp.ndarray
gravity: jp.ndarray
+ viscosity: jp.float32
+ density: jp.float32
link: Link
dof: DoF
geoms: List[Geometry]
actuator: Actuator
init_q: jp.ndarray
- # only used in `brax.physics.generalized`
- solver_params_joint: jp.ndarray
- solver_params_contact: jp.ndarray
- # only used in `brax.physics.spring` and `brax.physics.pbd`:
+ # only used in `brax.physics.spring` and `brax.physics.positional`:
vel_damping: jp.float32
ang_damping: jp.float32
baumgarte_erp: jp.float32
spring_mass_scale: jp.float32
spring_inertia_scale: jp.float32
- # only used in `brax.physics.positional`
+ # only used in `brax.physics.positional`:
joint_scale_ang: jp.float32
joint_scale_pos: jp.float32
collide_scale: jp.float32
-
+ # non-pytree nodes
geom_masks: List[int] = struct.field(pytree_node=False)
link_names: List[str] = struct.field(pytree_node=False)
link_types: str = struct.field(pytree_node=False)
link_parents: Tuple[int, ...] = struct.field(pytree_node=False)
- actuator_types: str = struct.field(pytree_node=False)
- actuator_link_id: List[int] = struct.field(pytree_node=False)
- actuator_qid: List[int] = struct.field(pytree_node=False)
- actuator_qdid: List[int] = struct.field(pytree_node=False)
# only used in `brax.physics.generalized`:
matrix_inv_iterations: int = struct.field(pytree_node=False)
solver_iterations: int = struct.field(pytree_node=False)
@@ -595,7 +612,7 @@ def qd_size(self) -> int:
def act_size(self) -> int:
"""Returns the act dimension for the system."""
- return sum({'m': 1, 'p': 1}[act_typ] for act_typ in self.actuator_types)
+ return self.actuator.q_id.shape[0]
# below are some operation dispatch derivations
diff --git a/brax/envs/humanoid.py b/brax/envs/humanoid.py
index 4d32d8ec..d2bf0590 100644
--- a/brax/envs/humanoid.py
+++ b/brax/envs/humanoid.py
@@ -310,7 +310,8 @@ def _get_obs(
com_ang = xd_i.ang
com_velocity = jp.hstack([com_vel, com_ang])
- qfrc_actuator = actuator.to_tau(self.sys, action, pipeline_state.q)
+ qfrc_actuator = actuator.to_tau(
+ self.sys, action, pipeline_state.q, pipeline_state.qd)
# external_contact_forces are excluded
return jp.concatenate([
diff --git a/brax/envs/humanoidstandup.py b/brax/envs/humanoidstandup.py
index 972d0d77..e7d6ea6d 100644
--- a/brax/envs/humanoidstandup.py
+++ b/brax/envs/humanoidstandup.py
@@ -254,7 +254,8 @@ def _get_obs(
com_ang = xd_i.ang
com_velocity = jp.hstack([com_vel, com_ang])
- qfrc_actuator = actuator.to_tau(self.sys, action, pipeline_state.q)
+ qfrc_actuator = actuator.to_tau(
+ self.sys, action, pipeline_state.q, pipeline_state.qd)
# external_contact_forces are excluded
return jp.concatenate([
diff --git a/brax/fluid.py b/brax/fluid.py
new file mode 100644
index 00000000..b05e4e88
--- /dev/null
+++ b/brax/fluid.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint:disable=g-multiple-import
+"""Functions for forces/torques through fluids."""
+
+from typing import Tuple, Union
+from brax.base import Force, Motion, System, Transform
+import jax
+import jax.numpy as jp
+
+
+def _box_viscosity(
+ box: jp.ndarray, xd_i: Motion, viscosity: jp.ndarray
+) -> Force:
+ """Gets the viscous drag force/torque, linear in velocity, for each box."""
+ diam = jp.mean(box, axis=-1)  # mean over the last axis, used as a diameter
+ ang_scale = -jp.pi * diam**3 * viscosity
+ vel_scale = -3.0 * jp.pi * diam * viscosity
+ frc = Force(
+ ang=ang_scale[:, None] * xd_i.ang, vel=vel_scale[:, None] * xd_i.vel
+ )
+ return frc
+
+
+def _box_density(box: jp.ndarray, xd_i: Motion, density: jp.ndarray) -> Force:
+ """Gets the density drag force/torque, quadratic in velocity, for each box."""
+
+ @jax.vmap
+ def apply(b: jp.ndarray, xd: Motion) -> Force:
+ box_mult_vel = jp.array([b[1] * b[2], b[0] * b[2], b[0] * b[1]])
+ vel = -0.5 * density * box_mult_vel * jp.abs(xd.vel) * xd.vel
+ box_mult_ang = jp.array([
+ b[0] * (b[1] ** 4 + b[2] ** 4),
+ b[1] * (b[0] ** 4 + b[2] ** 4),
+ b[2] * (b[0] ** 4 + b[1] ** 4),
+ ])
+ ang = -1.0 * density * box_mult_ang * jp.abs(xd.ang) * xd.ang / 64.0
+ return Force(vel=vel, ang=ang)
+
+ return apply(box, xd_i)
+
+
+def force(
+ sys: System,
+ x: Transform,
+ xd: Motion,
+ mass: jp.ndarray,
+ inertia: jp.ndarray,
+ subtree_com: Union[jp.ndarray, None] = None,
+) -> Tuple[Force, jp.ndarray]:
+ """Returns the fluid force on each link and the link inertia-frame positions."""
+ # get the velocity at the com position/orientation
+ x_i = x.vmap().do(sys.link.inertia.transform)
+ # TODO: remove subtree_com when xd is fixed for stacked joints
+ offset = x_i.pos - x.pos if subtree_com is None else x_i.pos - subtree_com
+ xd_i = x_i.replace(pos=offset).vmap().do(xd)
+
+ # TODO: add ellipsoid fluid model from mujoco
+ # TODO: consider adding wind from mj.opt.wind
+ diag_inertia = jax.vmap(jp.diag)(inertia)
+ diag_inertia_v = jp.repeat(diag_inertia, 3, axis=-2).reshape((-1, 3, 3))
+ diag_inertia_v *= jp.ones((3, 3)) - 2 * jp.eye(3)  # row i: -I_i on diagonal
+ box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), a_min=1e-12)
+ box = jp.sqrt(box / mass[:, None])  # sides of a box with this mass/inertia
+
+ frc = _box_viscosity(box, xd_i, sys.viscosity)
+ frc += _box_density(box, xd_i, sys.density)
+
+ # rotate back to the world orientation
+ frc = Transform.create(rot=x_i.rot).vmap().do(frc)
+ return frc, x_i.pos
diff --git a/brax/fluid_test.py b/brax/fluid_test.py
new file mode 100644
index 00000000..5a9e54df
--- /dev/null
+++ b/brax/fluid_test.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint:disable=g-multiple-import
+"""Tests for fluid."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from brax import test_utils
+from brax.generalized import dynamics
+from brax.generalized import pipeline as g_pipeline
+import jax
+from jax import numpy as jp
+import mujoco
+import numpy as np
+
+assert_almost_equal = np.testing.assert_almost_equal
+
+
+class FluidTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ ('fluid_box.xml', 0.0, 0.0),
+ ('fluid_box.xml', 0.0, 1.3),
+ ('fluid_box.xml', 2.0, 0.0),
+ ('fluid_box.xml', 2.0, 1.3),
+ ('fluid_box_offset_com.xml', 0.0, 1.3),
+ ('fluid_box_offset_com.xml', 2.0, 0.0),
+ ('fluid_box_offset_com.xml', 2.0, 1.3),
+ )
+ def test_fluid_mj_generalized(self, config, density, viscosity):
+ """Tests generalized passive fluid forces and rollouts against MuJoCo."""
+ sys = test_utils.load_fixture(config)
+ sys = sys.replace(density=density, viscosity=viscosity)
+
+ mj_model = test_utils.load_fixture_mujoco(config)
+ mj_model.opt.density = density
+ mj_model.opt.viscosity = viscosity
+ mj_data = mujoco.MjData(mj_model)
+ # initialize qd so that the object interacts with the fluid
+ mj_data.qvel[:3] = [1, 2, 4]
+ mj_data.qvel[3] = jp.pi
+ q, qd = jp.asarray(mj_data.qpos), jp.asarray(mj_data.qvel)
+
+ # compare passive forces: MuJoCo after one mj_step vs brax at the init state
+ mujoco.mj_step(mj_model, mj_data)
+ state = jax.jit(g_pipeline.init)(sys, q, qd)
+ qfrc_passive = jax.jit(dynamics._passive)(sys, state)
+ np.testing.assert_array_almost_equal(
+ qfrc_passive[3:], mj_data.qfrc_passive[3:], 2
+ )
+ m = max(jp.abs(mj_data.qfrc_passive[:3])) + 1e-6  # scale; 1e-6 keeps it > 0
+ np.testing.assert_array_almost_equal(
+ qfrc_passive[:3] / m, mj_data.qfrc_passive[:3] / m, 1
+ )
+
+ # check q/qd after 500 more steps (NOTE: MuJoCo total is 501, brax is 500)
+ for _ in range(500):
+ mujoco.mj_step(mj_model, mj_data)
+ mq, mqd = jp.asarray(mj_data.qpos), jp.asarray(mj_data.qvel)
+
+ for _ in range(500):
+ state = jax.jit(g_pipeline.step)(sys, state, jp.zeros((sys.act_size(),)))
+ gq, gqd = state.q, state.qd
+
+ np.testing.assert_array_almost_equal(gq, mq, 2)
+ np.testing.assert_array_almost_equal(gqd, mqd, 2)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/brax/generalized/constraint.py b/brax/generalized/constraint.py
index 5b9ef367..be59d59a 100644
--- a/brax/generalized/constraint.py
+++ b/brax/generalized/constraint.py
@@ -26,38 +26,6 @@
import jaxopt
-def _pt_jac(
- sys: System,
- com: jp.ndarray,
- cdof: Motion,
- pos: jp.ndarray,
- link_idx: jp.ndarray,
-) -> jp.ndarray:
- """Calculates the point jacobian.
-
- Args:
- sys: a brax system
- com: center of mass position
- cdof: dofs in com frame
- pos: position in world frame
- link_idx: index of link frame to transform point jacobian
-
- Returns:
- pt: point jacobian
- """
- # backward scan up tree: build the link mask corresponding to link_idx
- def mask_fn(mask_child, link):
- mask = link == link_idx
- if mask_child is not None:
- mask += mask_child
- return mask
-
- mask = scan.tree(sys, mask_fn, 'l', jp.arange(sys.num_links()), reverse=True)
- cdof = jax.vmap(lambda a, b: a * b)(cdof, jp.take(mask, sys.dof_link()))
- off = Transform.create(pos=pos - com[link_idx])
- return off.vmap(in_axes=(None, 0)).do(cdof).vel
-
-
def _imp_aref(
params: jp.ndarray, pos: jp.ndarray, vel: jp.ndarray
) -> Tuple[jp.ndarray, jp.ndarray]:
@@ -87,11 +55,49 @@ def _imp_aref(
b = 2 / (dmax * timeconst)
k = 1 / (dmax * dmax * timeconst * timeconst * dampratio * dampratio)
+ # See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters
+ stiffness, damping = params[:2]
+ use_v2_params = (stiffness < 0) & (damping < 0)
+ b = jp.where(use_v2_params, -damping / dmax, b)
+ k = jp.where(use_v2_params, -stiffness / (dmax * dmax), k)
+
aref = -b * vel - k * imp * pos
return imp, aref
+def point_jacobian(
+ sys: System,
+ com: jp.ndarray,
+ cdof: Motion,
+ pos: jp.ndarray,
+ link_idx: jp.ndarray,
+) -> Motion:
+ """Calculates the jacobian of a point on a link.
+
+ Args:
+ sys: a brax system
+ com: center of mass position
+ cdof: dofs in com frame
+ pos: position in world frame to calculate the jacobian
+ link_idx: index of link frame to transform point jacobian
+
+ Returns:
+ pt: point jacobian as a Motion (callers needing one part take e.g. `.vel`)
+ """
+ # backward scan up tree: mask marks link_idx and the links above it
+ def mask_fn(mask_child, link):
+ mask = link == link_idx
+ if mask_child is not None:
+ mask += mask_child
+ return mask
+
+ mask = scan.tree(sys, mask_fn, 'l', jp.arange(sys.num_links()), reverse=True)
+ cdof = jax.vmap(lambda a, b: a * b)(cdof, jp.take(mask, sys.dof_link()))
+ off = Transform.create(pos=pos - com[link_idx])
+ return off.vmap(in_axes=(None, 0)).do(cdof)
+
+
def jac_limit(
sys: System, state: State
) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]:
@@ -118,7 +124,8 @@ def jac_limit(
side = ((pos_min < pos_max) * 2 - 1) * (pos < 0)
jac = jax.vmap(jp.multiply)(jp.eye(sys.qd_size())[qd_idx], side)
- imp, aref = _imp_aref(sys.solver_params_joint, pos, jac @ state.qd)
+ params = sys.dof.solver_params[qd_idx]
+ imp, aref = jax.vmap(_imp_aref)(params, pos, jac @ state.qd)
diag = sys.dof.invweight[qd_idx] * (pos < 0) * (1 - imp) / (imp + 1e-8)
aref = jax.vmap(lambda x, y: x * y)(aref, (pos < 0))
@@ -146,8 +153,8 @@ def jac_contact(
def row_fn(contact):
link_a, link_b = contact.link_idx
- a = _pt_jac(sys, state.com, state.cdof, contact.pos, link_a)
- b = _pt_jac(sys, state.com, state.cdof, contact.pos, link_b)
+ a = point_jacobian(sys, state.com, state.cdof, contact.pos, link_a).vel
+ b = point_jacobian(sys, state.com, state.cdof, contact.pos, link_b).vel
diff = b - a
# 4 pyramidal friction directions
@@ -158,7 +165,7 @@ def row_fn(contact):
jac = jp.stack(jac)
pos = -jp.tile(contact.penetration, 4)
- imp, aref = _imp_aref(sys.solver_params_contact, pos, jac @ state.qd)
+ imp, aref = _imp_aref(contact.solver_params, pos, jac @ state.qd)
t = sys.link.invweight[link_a] + sys.link.invweight[link_b] * (link_b > -1)
diag = jp.tile(t + contact.friction * contact.friction * t, 4)
diag *= 2 * contact.friction * contact.friction * (1 - imp) / (imp + 1e-8)
diff --git a/brax/generalized/constraint_test.py b/brax/generalized/constraint_test.py
index 057918da..d10de72d 100644
--- a/brax/generalized/constraint_test.py
+++ b/brax/generalized/constraint_test.py
@@ -20,6 +20,7 @@
from brax import test_utils
from brax.generalized import pipeline
import jax
+from jax import numpy as jp
import numpy as np
@@ -30,6 +31,7 @@ class ConstraintTest(parameterized.TestCase):
('triple_pendulum.xml',),
('humanoid.xml',),
('half_cheetah.xml',),
+ ('solver_params_v2.xml',)
)
def test_jacobian(self, xml_file):
"""Test constraint jacobian."""
@@ -56,6 +58,7 @@ def test_jacobian(self, xml_file):
('triple_pendulum.xml',),
('humanoid.xml',),
('half_cheetah.xml',),
+ ('solver_params_v2.xml',)
)
def test_force(self, xml_file):
"""Test constraint force."""
@@ -68,8 +71,9 @@ def test_force(self, xml_file):
# force PGS so we can reference efc_AR:
samples = test_utils.sample_mujoco_states(xml_file, force_pgs=True)
for mj_prev, mj_next in samples:
+ act = jp.zeros(sys.act_size())
state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)
- state = jax.jit(pipeline.step)(sys, state, mj_prev.qfrc_applied)
+ state = jax.jit(pipeline.step)(sys, state, act)
efc_jt = np.reshape(mj_next.efc_J, (-1, sys.qd_size())).T
# recover con_frc by backing it out from qf_constraint
con_frc = np.linalg.lstsq(efc_jt, state.qf_constraint, None)[0]
diff --git a/brax/generalized/dynamics.py b/brax/generalized/dynamics.py
index 6743ca11..781f58af 100644
--- a/brax/generalized/dynamics.py
+++ b/brax/generalized/dynamics.py
@@ -14,10 +14,14 @@
# pylint:disable=g-multiple-import
"""Functions for smooth forward and inverse dynamics."""
+import functools
+
+from brax import fluid
from brax import math
from brax import scan
from brax.base import Motion, System, Transform
from brax.generalized.base import State
+from brax.generalized.constraint import point_jacobian
import jax
from jax import numpy as jp
@@ -180,16 +184,27 @@ def cfrc_fn(cfrc_child, cfrc):
return tau
-def _passive(sys: System, q: jp.ndarray, qd: jp.ndarray) -> jp.ndarray:
+def _passive(sys: System, state: State) -> jp.ndarray:
"""Calculates the system's passive forces given input motion and position."""
-
def stiffness_fn(typ, q, dof):
if typ in 'fb':
return jp.zeros_like(dof.stiffness)
return -q * dof.stiffness
- frc = scan.link_types(sys, stiffness_fn, 'qd', 'd', q, sys.dof)
- frc -= sys.dof.damping * qd
+ frc = scan.link_types(sys, stiffness_fn, 'qd', 'd', state.q, sys.dof)
+ frc -= sys.dof.damping * state.qd
+
+ fluid_frc, pos = fluid.force(
+ sys,
+ state.x,
+ state.cd,
+ sys.link.inertia.mass,
+ sys.link.inertia.i,
+ state.com,
+ )
+ cdof_fn = functools.partial(point_jacobian, sys, state.com, state.cdof)
+ jac = jax.vmap(cdof_fn)(pos, jp.arange(sys.num_links()))
+ frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0)
return frc
@@ -210,7 +225,7 @@ def forward(sys: System, state: State, tau: jp.ndarray) -> jp.ndarray:
Returns:
qfrc: joint force vector
"""
- qfrc_passive = _passive(sys, state.q, state.qd)
+ qfrc_passive = _passive(sys, state)
qfrc_bias = inverse(sys, state)
qfrc = qfrc_passive - qfrc_bias + tau
diff --git a/brax/generalized/dynamics_test.py b/brax/generalized/dynamics_test.py
index daefdf89..fb0487cc 100644
--- a/brax/generalized/dynamics_test.py
+++ b/brax/generalized/dynamics_test.py
@@ -20,6 +20,7 @@
from brax import test_utils
from brax.generalized import pipeline
import jax
+from jax import numpy as jp
import numpy as np
@@ -55,8 +56,9 @@ def test_forward(self, xml_file):
"""Test dynamics forward."""
sys = test_utils.load_fixture(xml_file)
for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
+ act = jp.zeros(sys.act_size())
state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)
- state = jax.jit(pipeline.step)(sys, state, mj_prev.qfrc_applied)
+ state = jax.jit(pipeline.step)(sys, state, act)
np.testing.assert_allclose(
state.qf_smooth, mj_next.qfrc_smooth, rtol=1e-4, atol=1e-4
diff --git a/brax/generalized/perf_test.py b/brax/generalized/perf_test.py
index bacbb317..2f6de4e0 100644
--- a/brax/generalized/perf_test.py
+++ b/brax/generalized/perf_test.py
@@ -35,7 +35,7 @@ def init_fn(rng):
return pipeline.init(sys, q, qd)
def step_fn(state):
- return pipeline.step(sys, state, jp.zeros(sys.qd_size()))
+ return pipeline.step(sys, state, jp.zeros(sys.act_size()))
test_utils.benchmark('generalized pipeline ant', init_fn, step_fn)
diff --git a/brax/generalized/pipeline.py b/brax/generalized/pipeline.py
index 9c9425e8..638f8c1f 100644
--- a/brax/generalized/pipeline.py
+++ b/brax/generalized/pipeline.py
@@ -67,7 +67,7 @@ def step(
state: physics state after step
"""
# calculate acceleration terms
- tau = actuator.to_tau(sys, act, state.q)
+ tau = actuator.to_tau(sys, act, state.q, state.qd)
state = state.replace(qf_smooth=dynamics.forward(sys, state, tau))
state = state.replace(qf_constraint=constraint.force(sys, state))
diff --git a/brax/generalized/pipeline_test.py b/brax/generalized/pipeline_test.py
index 1c17350d..b34f6817 100644
--- a/brax/generalized/pipeline_test.py
+++ b/brax/generalized/pipeline_test.py
@@ -20,6 +20,7 @@
from brax import test_utils
from brax.generalized import pipeline
import jax
+from jax import numpy as jp
import numpy as np
@@ -38,7 +39,7 @@ def test_forward(self, xml_file):
sys = sys.replace(solver_iterations=500)
for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)
- state = jax.jit(pipeline.step)(sys, state, mj_prev.qfrc_applied)
+ state = jax.jit(pipeline.step)(sys, state, jp.zeros(sys.act_size()))
np.testing.assert_allclose(state.q, mj_next.qpos, atol=0.002)
np.testing.assert_allclose(state.qd, mj_next.qvel, atol=0.5)
diff --git a/brax/geometry/contact.py b/brax/geometry/contact.py
index 997c061e..b6c2a37c 100644
--- a/brax/geometry/contact.py
+++ b/brax/geometry/contact.py
@@ -20,6 +20,7 @@
from brax import math
from brax.base import (
Capsule,
+ Cylinder,
Contact,
Convex,
Geometry,
@@ -35,19 +36,6 @@
from jax import numpy as jp
-def _combine(
- geom_a: Geometry, geom_b: Geometry
-) -> Tuple[float, float, Tuple[int, int]]:
- # default is to take maximum, but can override
- friction = jp.maximum(geom_a.friction, geom_b.friction)
- elasticity = jp.maximum(geom_a.elasticity, geom_b.elasticity)
- link_idx = (
- geom_a.link_idx,
- geom_b.link_idx if geom_b.link_idx is not None else -1,
- )
- return friction, elasticity, link_idx # pytype: disable=bad-return-type # jax-ndarray
-
-
def _sphere_plane(sphere: Sphere, plane: Plane) -> Contact:
"""Calculates one contact between a sphere and a plane."""
n = math.rotate(jp.array([0.0, 0.0, 1.0]), plane.transform.rot)
@@ -55,8 +43,8 @@ def _sphere_plane(sphere: Sphere, plane: Plane) -> Contact:
penetration = sphere.radius - t
# halfway between contact points on sphere and on plane
pos = sphere.transform.pos - n * (sphere.radius - 0.5 * penetration)
- c = Contact(pos, n, penetration, *_combine(sphere, plane)) # pytype: disable=wrong-arg-types # jax-ndarray
- # add a batch dimension of size 1
+ c = Contact(pos, n, penetration, *_combine(sphere, plane))
+ # returns 1 contact, so add a batch dimension of size 1
return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c)
@@ -67,8 +55,8 @@ def _sphere_sphere(s_a: Sphere, s_b: Sphere) -> Contact:
s_a_pos = s_a.transform.pos - n * s_a.radius
s_b_pos = s_b.transform.pos + n * s_b.radius
pos = (s_a_pos + s_b_pos) * 0.5
- c = Contact(pos, n, penetration, *_combine(s_a, s_b)) # pytype: disable=wrong-arg-types # jax-ndarray
- # add a batch dimension of size 1
+ c = Contact(pos, n, penetration, *_combine(s_a, s_b))
+ # returns 1 contact, so add a batch dimension of size 1
return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c)
@@ -88,8 +76,57 @@ def _sphere_capsule(sphere: Sphere, capsule: Capsule) -> Contact:
cap_pos = pt + n * capsule.radius
pos = (sphere_pos + cap_pos) * 0.5
- c = Contact(pos, n, penetration, *_combine(sphere, capsule)) # pytype: disable=wrong-arg-types # jax-ndarray
- # add a batch dimension of size 1
+ c = Contact(pos, n, penetration, *_combine(sphere, capsule))
+ # returns 1 contact, so add a batch dimension of size 1
+ return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c)
+
+
+def _sphere_circle(sphere: Sphere, circle: Cylinder) -> Contact:
+ """Calculates one contact between a sphere and a circle (passed as Cylinder)."""
+ n = math.rotate(jp.array([0.0, 0.0, 1.0]), circle.transform.rot)
+
+ # orient n toward the sphere center; NOTE: jp.sign gives 0 if it is in-plane
+ normal_dir = jp.sign(
+ (sphere.transform.pos - circle.transform.pos).dot(n))
+ n = n * normal_dir
+
+ pos = sphere.transform.pos - n * sphere.radius
+ plane_pt = circle.transform.pos
+ penetration = jp.dot(plane_pt - pos, n)
+
+ # check if the sphere center is within circle.radius of the line through the
+ # circle center along n (presumably what closest_line_point projects onto)
+ plane_pt2 = plane_pt + n
+ line_pt = geom_math.closest_line_point(
+ plane_pt, plane_pt2, sphere.transform.pos
+ )
+ in_cylinder = (sphere.transform.pos - line_pt).dot(
+ sphere.transform.pos - line_pt
+ ) <= circle.radius**2
+
+ # get the point on the circle's edge closest to the sphere center
+ perp_dir = jp.cross(n, sphere.transform.pos - plane_pt)
+ perp_dir = math.rotate(perp_dir, math.quat_rot_axis(n, -jp.pi / 2.0))
+ perp_dir, _ = math.normalize(perp_dir)
+ edge_pt = plane_pt + perp_dir * circle.radius
+ edge_contact = (sphere.transform.pos - edge_pt).dot(
+ sphere.transform.pos - edge_pt
+ ) <= sphere.radius**2
+ edge_to_sphere = sphere.transform.pos - edge_pt
+ edge_to_sphere = math.normalize(edge_to_sphere)[0]
+
+ penetration = jp.where(in_cylinder, penetration, -jp.ones_like(penetration))
+ penetration = jp.where(
+ edge_contact,
+ sphere.radius
+ - jp.sqrt(
+ (sphere.transform.pos - edge_pt).dot(sphere.transform.pos - edge_pt)
+ ),
+ penetration,
+ )
+ n = jp.where(edge_contact, edge_to_sphere, n)
+ pos = jp.where(edge_contact, edge_pt, pos)
+ c = Contact(pos, n, penetration, *_combine(sphere, circle)) # pytype: disable=wrong-arg-types # jax-ndarray
return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c)
@@ -154,7 +191,7 @@ def get_support(faces, normal):
penetration = sphere.radius - d
pos = (pt + spt) * 0.5
- c = Contact(pos, n, penetration, *_combine(sphere, convex)) # pytype: disable=wrong-arg-types # jax-ndarray
+ c = Contact(pos, n, penetration, *_combine(sphere, convex))
return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c)
@@ -174,7 +211,7 @@ def sphere_face(face):
penetration = sphere.radius - dist
sph_p = sphere.transform.pos - n * sphere.radius
pos = (tri_p + sph_p) * 0.5
- return Contact(pos, n, penetration, *_combine(sphere, mesh)) # pytype: disable=wrong-arg-types # jax-ndarray
+ return Contact(pos, n, penetration, *_combine(sphere, mesh))
return sphere_face(jp.take(mesh.vert, mesh.face, axis=0))
@@ -187,11 +224,12 @@ def _capsule_plane(capsule: Capsule, plane: Plane) -> Contact:
results = []
for off in [segment, -segment]:
sphere = Sphere(
+ radius=capsule.radius,
link_idx=capsule.link_idx,
transform=Transform.create(pos=capsule.transform.pos + off),
friction=capsule.friction,
elasticity=capsule.elasticity,
- radius=capsule.radius,
+ solver_params=capsule.solver_params,
)
results.append(_sphere_plane(sphere, plane))
@@ -217,8 +255,8 @@ def _capsule_capsule(cap_a: Capsule, cap_b: Capsule) -> Contact:
cap_b_pos = pt_b + n * cap_b.radius
pos = (cap_a_pos + cap_b_pos) * 0.5
- c = Contact(pos, n, penetration, *_combine(cap_a, cap_b)) # pytype: disable=wrong-arg-types # jax-ndarray
- # add a batch dimension of size 1
+ c = Contact(pos, n, penetration, *_combine(cap_a, cap_b))
+ # returns 1 contact, so add a batch dimension of size 1
return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c)
@@ -306,10 +344,9 @@ def get_support(face, normal):
penetration = jp.where(
has_edge_contact, penetration.at[0].set(edge_penetration), penetration
)
- friction, elasticity, link_idx = jax.tree_map(
- lambda x: jp.repeat(x, 2), _combine(capsule, convex)
- )
- return Contact(pos, norm, penetration, friction, elasticity, link_idx)
+ tile_fn = lambda x: jp.tile(x, (2,) + tuple([1 for _ in x.shape]))
+ params = jax.tree_map(tile_fn, _combine(capsule, convex))
+ return Contact(pos, norm, penetration, *params)
def _capsule_mesh(capsule: Capsule, mesh: Mesh) -> Contact:
@@ -335,7 +372,7 @@ def capsule_face(face, face_norm):
penetration = capsule.radius - dist
cap_p = seg_p - n * capsule.radius
pos = (tri_p + cap_p) * 0.5
- return Contact(pos, n, penetration, *_combine(capsule, mesh)) # pytype: disable=wrong-arg-types # jax-ndarray
+ return Contact(pos, n, penetration, *_combine(capsule, mesh))
face_vert = jp.take(mesh.vert, mesh.face, axis=0)
face_norm = geom_mesh.get_face_norm(mesh.vert, mesh.face)
@@ -360,10 +397,9 @@ def transform_verts(vertices):
normal = jp.stack([n] * 4, axis=0)
unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1
penetration = jp.where(unique, support[idx], -1)
- friction, elasticity, link_idx = jax.tree_map(
- lambda x: jp.repeat(x, 4), _combine(convex, plane)
- )
- return Contact(pos, normal, penetration, friction, elasticity, link_idx)
+ tile_fn = lambda x: jp.tile(x, (4,) + tuple([1 for _ in x.shape]))
+ params = jax.tree_map(tile_fn, _combine(convex, plane))
+ return Contact(pos, normal, penetration, *params)
def _convex_convex(convex_a: Convex, convex_b: Convex) -> Contact:
@@ -416,18 +452,10 @@ def transform_verts(convex, vertices):
unique_edges_a,
unique_edges_b,
)
- friction, elasticity, link_idx = jax.tree_map(
- lambda x: jp.repeat(x, 4), _combine(convex_a, convex_b)
- )
+ tile_fn = lambda x: jp.tile(x, (4,) + tuple([1 for _ in x.shape]))
+ params = jax.tree_map(tile_fn, _combine(convex_a, convex_b))
- return Contact(
- c.pos,
- c.normal,
- c.penetration,
- friction,
- elasticity,
- link_idx,
- )
+ return Contact(c.pos, c.normal, c.penetration, *params)
def _mesh_plane(mesh: Mesh, plane: Plane) -> Contact:
@@ -438,15 +466,30 @@ def point_plane(vert):
n = math.rotate(jp.array([0.0, 0.0, 1.0]), plane.transform.rot)
pos = mesh.transform.pos + math.rotate(vert, mesh.transform.rot)
penetration = jp.dot(plane.transform.pos - pos, n)
- return Contact(pos, n, penetration, *_combine(mesh, plane)) # pytype: disable=wrong-arg-types # jax-ndarray
+ return Contact(pos, n, penetration, *_combine(mesh, plane))
return point_plane(mesh.vert)
+def _combine(
+ geom_a: Geometry, geom_b: Geometry
+) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray, Tuple[jp.ndarray, jp.ndarray]]:
+ # friction takes the max, elasticity/solver_params the mean (can override)
+ friction = jp.maximum(geom_a.friction, geom_b.friction)
+ elasticity = (geom_a.elasticity + geom_b.elasticity) * 0.5
+ solver_params = (geom_a.solver_params + geom_b.solver_params) * 0.5
+ link_idx = (
+ jp.array(geom_a.link_idx if geom_a.link_idx is not None else -1),
+ jp.array(geom_b.link_idx if geom_b.link_idx is not None else -1),
+ )
+ return friction, elasticity, solver_params, link_idx
+
+
_TYPE_FUN = {
(Sphere, Plane): _sphere_plane,
(Sphere, Sphere): _sphere_sphere,
(Sphere, Capsule): _sphere_capsule,
+ (Sphere, Cylinder): _sphere_circle,
(Sphere, Convex): _sphere_convex,
(Sphere, Mesh): _sphere_mesh,
(Capsule, Plane): _capsule_plane,
diff --git a/brax/geometry/contact_test.py b/brax/geometry/contact_test.py
index cccd7373..95532d59 100644
--- a/brax/geometry/contact_test.py
+++ b/brax/geometry/contact_test.py
@@ -91,6 +91,28 @@ def test_sphere_capsule(self):
np.testing.assert_array_almost_equal(c.pos, jp.array([0.0, 0.3, 0.045]))
np.testing.assert_array_almost_equal(c.normal, jp.array([0, 0.0, -1.0]))
+ _SPHERE_CYLINDER = """
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+
+ def test_sphere_cylinder(self):
+ sys = mjcf.loads(self._SPHERE_CYLINDER)
+ x, _ = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size()))
+ c = geometry.contact(sys, x).take(0)
+
+ np.testing.assert_array_almost_equal(c.penetration, 0.01)
+
_SPHERE_CONVEX = """
diff --git a/brax/geometry/math.py b/brax/geometry/math.py
index 24d5041d..bfb06a71 100644
--- a/brax/geometry/math.py
+++ b/brax/geometry/math.py
@@ -40,6 +40,15 @@ def closest_segment_point_and_dist(
return closest, dist
+def closest_line_point(
+ a: jp.ndarray, b: jp.ndarray, pt: jp.ndarray
+) -> jp.ndarray:
+ """Returns the point on the infinite line through a and b closest to pt."""
+ ab = b - a
+ t = jp.dot(pt - a, ab) / (jp.dot(ab, ab) + 1e-6)
+ return a + t * ab
+
+
def closest_segment_to_segment_points(
a0: jp.ndarray, a1: jp.ndarray, b0: jp.ndarray, b1: jp.ndarray
) -> Tuple[jp.ndarray, jp.ndarray]:
@@ -432,6 +441,7 @@ def _create_contact_manifold(
penetration=penetration,
friction=jp.array([]),
elasticity=jp.array([]),
+ solver_params=jp.array([]),
link_idx=jp.array([]),
)
diff --git a/brax/geometry/mesh.py b/brax/geometry/mesh.py
index 89b25cba..9a9afb41 100644
--- a/brax/geometry/mesh.py
+++ b/brax/geometry/mesh.py
@@ -128,6 +128,7 @@ def box_tri(b: Box) -> Mesh:
transform=b.transform,
friction=b.friction,
elasticity=b.elasticity,
+ solver_params=b.solver_params,
rgba=b.rgba,
)
@@ -142,6 +143,7 @@ def _box_hull(b: Box) -> Convex:
transform=b.transform,
friction=b.friction,
elasticity=b.elasticity,
+ solver_params=b.solver_params,
unique_edge=get_unique_edges(vert, face),
rgba=b.rgba,
)
@@ -225,6 +227,7 @@ def _convex_hull(m: Mesh) -> Convex:
transform=m.transform,
friction=m.friction,
elasticity=m.elasticity,
+ solver_params=m.solver_params,
unique_edge=get_unique_edges(vert, face),
rgba=m.rgba,
)
@@ -244,6 +247,7 @@ def convex_hull(obj: Union[Box, Mesh]) -> Convex:
transform=obj.transform,
friction=obj.friction,
elasticity=obj.elasticity,
+ solver_params=obj.solver_params,
rgba=obj.rgba,
)
return convex
diff --git a/brax/geometry/mesh_test.py b/brax/geometry/mesh_test.py
index 6a0f5cdc..d84772a2 100644
--- a/brax/geometry/mesh_test.py
+++ b/brax/geometry/mesh_test.py
@@ -31,6 +31,7 @@ def test_box(self):
transform=None,
friction=0.42,
elasticity=1,
+ solver_params=None,
)
m = mesh.box_tri(b)
self.assertIsInstance(m, Mesh)
@@ -64,6 +65,7 @@ def test_box_hull(self):
transform=None,
friction=0.42,
elasticity=1,
+ solver_params=None,
)
h = mesh.convex_hull(b)
self.assertIsInstance(h, Convex)
@@ -106,6 +108,7 @@ def test_pyramid(self):
face=face,
friction=1,
elasticity=0,
+ solver_params=None,
)
h = mesh.convex_hull(pyramid)
diff --git a/brax/io/json.py b/brax/io/json.py
index 56d7208b..015eee58 100644
--- a/brax/io/json.py
+++ b/brax/io/json.py
@@ -18,7 +18,6 @@
import json
from typing import List, Text
-from brax import geometry
from brax.base import State, System
from etils import epath
import jax
diff --git a/brax/io/mjcf.py b/brax/io/mjcf.py
index c646bba9..c1f086c4 100644
--- a/brax/io/mjcf.py
+++ b/brax/io/mjcf.py
@@ -25,6 +25,7 @@
Actuator,
Box,
Capsule,
+ Cylinder,
DoF,
Inertia,
Link,
@@ -43,13 +44,6 @@
import numpy as np
-# map from mujoco bias type to brax actuator type string
-_ACT_TYPE_STR = {
- 0: 'm', # motor
- 1: 'p', # position
-}
-
-
def _transform_do(
pos: np.ndarray, quat: np.ndarray, cpos: np.ndarray, cquat: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
@@ -251,6 +245,10 @@ def load_model(mj: mujoco.MjModel) -> System:
raise NotImplementedError(
'Only joint transmission types are supported for actuators.'
)
+ if (mj.geom_solmix[0] != mj.geom_solmix).any():
+ raise NotImplementedError('geom_solmix parameter not supported.')
+ if (mj.geom_priority[0] != mj.geom_priority).any():
+ raise NotImplementedError('geom_priority parameter not supported.')
if mj.opt.collision == 1:
raise NotImplementedError('Predefined collisions not supported.')
q_width = {0: 7, 1: 4, 2: 1, 3: 1}
@@ -258,7 +256,10 @@ def load_model(mj: mujoco.MjModel) -> System:
if mj.qpos0[non_free].any():
raise NotImplementedError(
'The `ref` attribute on joint types is not supported.')
-
+ if (mj.geom_fluid != 0).any():
+ raise NotImplementedError('Ellipsoid fluid model not implemented.')
+ if mj.opt.wind.any():
+ raise NotImplementedError('option.wind is not implemented.')
custom = _get_custom(mj)
# create links
@@ -326,22 +327,25 @@ def load_model(mj: mujoco.MjModel) -> System:
if np.any(mj.jnt_limited):
limit = jax.tree_map(lambda *x: np.concatenate(x), *limits)
stiffness = np.concatenate(stiffnesses)
+ solver_params_jnt = np.concatenate((mj.jnt_solref, mj.jnt_solimp), axis=1)
+ solver_params_dof = solver_params_jnt[mj.dof_jntid]
- dof = DoF( # pytype: disable=wrong-arg-types # jax-ndarray
+ dof = DoF( # pytype: disable=wrong-arg-types
motion=motion,
armature=mj.dof_armature,
stiffness=stiffness,
damping=mj.dof_damping,
limit=limit,
invweight=mj.dof_invweight0,
+ solver_params=solver_params_dof,
)
+ solver_params_geom = np.concatenate((mj.geom_solref, mj.geom_solimp), axis=1)
# group geoms so that they can be stacked. two geoms can be stacked if:
# - they have the same type
# - their fields have the same shape (e.g. Mesh verts might vary)
# - they have the same mask
key_fn = lambda g, m: (jax.tree_map(np.shape, g), m)
-
geom_groups = {}
for i, typ in enumerate(mj.geom_type):
rgba = mj.geom_rgba[i]
@@ -353,12 +357,21 @@ def load_model(mj: mujoco.MjModel) -> System:
'transform': Transform(pos=mj.geom_pos[i], rot=mj.geom_quat[i]),
'friction': mj.geom_friction[i, 0],
'elasticity': custom['elasticity'][i],
+ 'solver_params': solver_params_geom[i],
'rgba': rgba,
}
mask = mj.geom_contype[i] | mj.geom_conaffinity[i] << 32
if typ == 0: # Plane
geom = Plane(**kwargs)
geom_groups.setdefault(key_fn(geom, mask), []).append(geom)
+ elif typ == 5: # Cylinder
+ radius, halflength = mj.geom_size[i, 0:2]
+ if halflength > 0.001 and mask > 0:
+ # TODO: support cylinders with depth.
+ raise NotImplementedError(
+ 'Cylinders of half-length>0.001 are not supported for collision.')
+ geom = Cylinder(radius=radius, length=halflength * 2, **kwargs)
+ geom_groups.setdefault(key_fn(geom, mask), []).append(geom)
elif typ == 2: # Sphere
geom = Sphere(radius=mj.geom_size[i, 0], **kwargs)
geom_groups.setdefault(key_fn(geom, mask), []).append(geom)
@@ -386,30 +399,28 @@ def load_model(mj: mujoco.MjModel) -> System:
continue
geoms = [
- jax.tree_map(lambda *x: jp.stack(x), *g) for g in geom_groups.values()
+ jax.tree_map(lambda *x: np.stack(x), *g) for g in geom_groups.values()
]
geom_masks = [m for _, m in geom_groups.keys()]
# create actuators
ctrl_range = mj.actuator_ctrlrange
ctrl_range[~(mj.actuator_ctrllimited == 1), :] = np.array([-np.inf, np.inf])
- actuator = Actuator( # pytype: disable=wrong-arg-types # jax-ndarray
+ q_id = np.array([mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]])
+ qd_id = np.array([mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]])
+ bias_q = mj.actuator_biasprm[:, 1]
+ bias_qd = mj.actuator_biasprm[:, 2]
+
+ # TODO: might be nice to add actuator names for debugging
+ actuator = Actuator( # pytype: disable=wrong-arg-types
+ q_id=q_id,
+ qd_id=qd_id,
gear=mj.actuator_gear[:, 0],
ctrl_range=ctrl_range,
+ bias_q=bias_q,
+ bias_qd=bias_qd,
)
- # create generalized solver params
- params_joint = jp.concatenate((mj.jnt_solref, mj.jnt_solimp), axis=1)
- params_geom = jp.concatenate((mj.geom_solref, mj.geom_solimp), axis=1)
- params_pair = jp.concatenate((mj.pair_solref, mj.pair_solimp), axis=1)
- params_contact = jp.concatenate((params_geom, params_pair))
- if (params_joint[0] != params_joint).any():
- raise NotImplementedError('brax only supports one joint solver params')
- if (params_contact[0] != params_contact).any():
- raise NotImplementedError('brax only supports one contact solver params')
- solver_params_joint = params_joint[0]
- solver_params_contact = params_contact[0]
-
# create non-pytree params. these do not live on device directly, and they
# cannot be differentiated, but they do change the emitted control flow
link_names = [_get_name(mj, i) for i in mj.name_bodyadr[1:]]
@@ -430,21 +441,6 @@ def load_model(mj: mujoco.MjModel) -> System:
link_types += typ
link_parents = tuple(mj.body_parentid - 1)[1:]
- # create non-pytree params for actuators.
- actuator_types = ''.join([_ACT_TYPE_STR[bt] for bt in mj.actuator_biastype])
- actuator_link_id = [mj.jnt_bodyid[i] - 1 for i in mj.actuator_trnid[:, 0]]
- unsupported_act_links = set(link_types[i] for i in actuator_link_id) - {
- '1',
- '2',
- '3',
- }
- if unsupported_act_links:
- raise NotImplementedError(
- f'Link types {unsupported_act_links} are not supported for actuators.'
- )
- actuator_qid = [mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]]
- actuator_qdid = [mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]]
-
# mujoco stores free q in world frame, so clear link transform for free links
if 'f' in link_types:
free_idx = np.array([i for i, typ in enumerate(link_types) if typ == 'f'])
@@ -454,13 +450,13 @@ def load_model(mj: mujoco.MjModel) -> System:
sys = System( # pytype: disable=wrong-arg-types # jax-ndarray
dt=mj.opt.timestep,
gravity=mj.opt.gravity,
+ viscosity=mj.opt.viscosity,
+ density=mj.opt.density,
link=link,
dof=dof,
geoms=geoms,
actuator=actuator,
init_q=custom['init_qpos'] if 'init_qpos' in custom else mj.qpos0,
- solver_params_joint=solver_params_joint,
- solver_params_contact=solver_params_contact,
vel_damping=custom['vel_damping'],
ang_damping=custom['ang_damping'],
baumgarte_erp=custom['baumgarte_erp'],
@@ -473,10 +469,6 @@ def load_model(mj: mujoco.MjModel) -> System:
link_names=link_names,
link_types=link_types,
link_parents=link_parents,
- actuator_types=actuator_types,
- actuator_link_id=actuator_link_id,
- actuator_qid=actuator_qid,
- actuator_qdid=actuator_qdid,
matrix_inv_iterations=int(custom['matrix_inv_iterations']),
solver_iterations=mj.opt.iterations,
solver_maxls=int(custom['solver_maxls']),
diff --git a/brax/io/mjcf_test.py b/brax/io/mjcf_test.py
index dde7a0b5..02e2434e 100644
--- a/brax/io/mjcf_test.py
+++ b/brax/io/mjcf_test.py
@@ -171,6 +171,34 @@ def test_world_body_transform(self):
sys.init_q, np.array([1.245, 0.0, 0.0, 0.5, 0.5, 0.5, -0.5])
)
+ def test_load_flat_cylinder(self):
+ sys = test_utils.load_fixture('flat_cylinder.xml')
+ self.assertEqual(sys.geoms[1].radius, 0.25)
+ self.assertEqual(sys.geoms[1].length, 0.002)
+
+ def test_load_fat_cylinder(self):
+ with self.assertRaisesRegex(
+ NotImplementedError, 'Cylinders of half-length'
+ ):
+ test_utils.load_fixture('fat_cylinder.xml')
+
+ def test_load_fluid_box(self):
+ sys = test_utils.load_fixture('fluid_box.xml')
+ assert_almost_equal(sys.density, 1.2)
+ assert_almost_equal(sys.viscosity, 0.15)
+
+ def test_load_fluid_ellipsoid(self):
+ with self.assertRaisesRegex(
+ NotImplementedError, 'Ellipsoid fluid model not implemented'
+ ):
+ test_utils.load_fixture('fluid_ellipsoid.xml')
+
+ def test_load_wind(self):
+ with self.assertRaisesRegex(
+ NotImplementedError, 'option.wind is not implemented'
+ ):
+ test_utils.load_fixture('fluid_wind.xml')
+
if __name__ == '__main__':
absltest.main()
diff --git a/brax/kinematics.py b/brax/kinematics.py
index f52d5e19..a3b89497 100644
--- a/brax/kinematics.py
+++ b/brax/kinematics.py
@@ -70,6 +70,7 @@ def jcalc(typ, q, qd, motion):
for i in range(1, num_dofs):
j_i, jd_i = j_stack.take(i, axis=1), jd_stack.take(i, axis=1)
j = j.vmap().do(j_i)
+ # TODO: fix qd->jd calculation for stacked/offset joints
jd = jd + Motion(
ang=jax.vmap(math.rotate)(jd_i.ang, j_i.rot),
vel=jax.vmap(math.rotate)(
@@ -88,16 +89,16 @@ def jcalc(typ, q, qd, motion):
def world(parent, j, jd):
"""Convert transform/motion from joint frame to world frame."""
if parent is None:
+ jd = jd.replace(ang=jax.vmap(math.rotate)(jd.ang, j.rot))
return j, jd
- x, xd = parent
- # TODO: determine why the motion `do` is inverted
- x = x.vmap().do(j)
- xd = xd + Motion(
- ang=jax.vmap(math.rotate)(jd.ang, x.rot),
- vel=jax.vmap(math.rotate)(
- jd.vel + jax.vmap(jp.cross)(x.pos, jd.ang), x.rot
- ),
- )
+ x_p, xd_p = parent
+ x = x_p.vmap().do(j)
+ # get the linear velocity at the tip of the parent
+ vel = xd_p.vel + jax.vmap(jp.cross)(xd_p.ang, x.pos - x_p.pos)
+ # add in the child linear velocity in the world frame
+ vel += jax.vmap(math.rotate)(jd.vel, x_p.rot)
+ ang = xd_p.ang + jax.vmap(math.rotate)(jd.ang, x.rot)
+ xd = Motion(vel=vel, ang=ang)
return x, xd
x, xd = scan.tree(sys, world, 'll', j, jd)
@@ -126,7 +127,7 @@ def world_to_joint(
# move into joint coordinates
xd_joint = xd - xd_wj
- inv_rotate = jax.vmap(lambda x, y: math.rotate(x, math.quat_inv(y)))
+ inv_rotate = jax.vmap(math.inv_rotate)
jd = jax.tree_map(lambda x: inv_rotate(x, a_p.rot), xd_joint)
return j, jd, a_p, a_c
@@ -158,6 +159,9 @@ def link_to_joint_frame(motion: Motion) -> Tuple[Motion, float]:
We also need translational components because the prismatic components of a
joint might not be aligned with the rotational components of the joint.
"""
+ if motion.ang.shape[0] > 3 or motion.ang.shape[0] == 0:
+ raise AssertionError('Motion shape must be in (0, 3], '
+ f'got {motion.ang.shape[0]}')
# 1-dof
if motion.ang.shape[0] == 1:
@@ -282,16 +286,14 @@ def axis_angle_ang(
child_frame = v_rot(joint_motion.ang, j.rot)
line_of_nodes = jp.cross(child_frame[2], joint_motion.ang[0])
- line_of_nodes = line_of_nodes / (1e-10 + math.safe_norm(line_of_nodes))
+ line_of_nodes, _ = math.normalize(line_of_nodes)
y_n_normal = joint_motion.ang[0]
psi = math.signed_angle(y_n_normal, joint_motion.ang[1], line_of_nodes)
axis_1_p_in_xz_c = (
jp.dot(joint_motion.ang[0], child_frame[0]) * child_frame[0]
+ jp.dot(joint_motion.ang[0], child_frame[1]) * child_frame[1]
)
- axis_1_p_in_xz_c = axis_1_p_in_xz_c / (
- 1e-10 + math.safe_norm(axis_1_p_in_xz_c)
- )
+ axis_1_p_in_xz_c, _ = math.normalize(axis_1_p_in_xz_c)
ang_between_1_p_xz_c = jp.dot(axis_1_p_in_xz_c, joint_motion.ang[0])
theta = math.safe_arccos(jp.clip(ang_between_1_p_xz_c, -1, 1)) * jp.sign(
jp.dot(joint_motion.ang[0], child_frame[2])
@@ -331,10 +333,13 @@ def inverse(
) -> Tuple[jp.ndarray, jp.ndarray]:
"""Translates maximal coordinates into reduced coordinates."""
- def free(x, xd, _):
- return jp.concatenate([x.pos, x.rot]), jp.concatenate([xd.vel, xd.ang])
+ def free(x, xd, *_):
+ ang = math.inv_rotate(xd.ang, x.rot)
+ return jp.concatenate([x.pos, x.rot]), jp.concatenate([xd.vel, ang])
- def x_dof(j, jd, motion, x):
+ def x_dof(j, jd, parent_idx, motion, x):
+ j_rot = jp.where(parent_idx == -1, j.rot, jp.array([1.0, 0.0, 0.0, 0.0]))
+ jd = jd.replace(ang=math.inv_rotate(jd.ang, j_rot))
joint_frame, parity = link_to_joint_frame(motion)
axis, angles, _ = axis_angle_ang(j, joint_frame, parity)
angle_vels = jax.tree_map(lambda x: jp.dot(x, jd.ang), axis)
@@ -350,7 +355,7 @@ def x_dof(j, jd, motion, x):
)
return q, qd
- def q_fn(typ, j, jd, motion):
+ def q_fn(typ, j, jd, parent_idx, motion):
motion = jax.tree_map(
lambda y: y.reshape((-1, base.QD_WIDTHS[typ], 3)), motion
)
@@ -361,10 +366,12 @@ def q_fn(typ, j, jd, motion):
'3': functools.partial(x_dof, x=3),
}
- q, qd = jax.vmap(q_fn_map[typ])(j, jd, motion)
+ q, qd = jax.vmap(q_fn_map[typ])(j, jd, parent_idx, motion)
# transposed to preserve order of outputs
return jp.array(q).reshape(-1), jp.array(qd).reshape(-1)
- q, qd = scan.link_types(sys, q_fn, 'lld', 'qd', j, jd, sys.dof.motion)
+ parent_idx = jp.array(sys.link_parents)
+ q, qd = scan.link_types(sys, q_fn, 'llld', 'qd', j, jd, parent_idx,
+ sys.dof.motion)
return q, qd
diff --git a/brax/kinematics_test.py b/brax/kinematics_test.py
index 7101074b..e3d39fb8 100644
--- a/brax/kinematics_test.py
+++ b/brax/kinematics_test.py
@@ -21,6 +21,7 @@
from brax import kinematics
from brax import scan
from brax import test_utils
+from brax.base import Motion
import jax
import jax.numpy as jp
import numpy as np
@@ -31,14 +32,33 @@ class KinematicsTest(parameterized.TestCase):
@parameterized.parameters(
('ant.xml',), ('humanoid.xml',), ('reacher.xml',), ('half_cheetah.xml',)
)
- def test_forward_q(self, xml_file):
+ def test_forward(self, xml_file):
"""Test dynamics forward q."""
sys = test_utils.load_fixture(xml_file)
- for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
- x, _ = jax.jit(kinematics.forward)(sys, mj_prev.qpos, mj_prev.qvel)
+
+ for mj_prev, mj_next in test_utils.sample_mujoco_states(
+ xml_file, random_init=True, vel_to_local=False):
+ x, xd = jax.jit(kinematics.forward)(sys, mj_prev.qpos, mj_prev.qvel)
+
np.testing.assert_almost_equal(x.pos, mj_next.xpos[1:], 3)
+ # handle quat rotations +/- 2pi
+ quat_sign = np.allclose(
+ np.sum(mj_next.xquat[1:]) - np.sum(x.rot), 0, atol=1e-2)
+ quat_sign = 1 if quat_sign else -1
+ x = x.replace(rot=x.rot * quat_sign)
np.testing.assert_almost_equal(x.rot, mj_next.xquat[1:], 3)
+ # xd vel/ang were added to linvel/angmom in `sample_mujoco_states`
+ xd_mj = Motion(
+ vel=mj_next.subtree_linvel[1:], ang=mj_next.subtree_angmom[1:])
+
+ if xml_file == 'humanoid.xml':
+ # TODO: get forward to match MJ for stacked/offset joints
+ return
+
+ np.testing.assert_array_almost_equal(xd.ang, xd_mj.ang, 3)
+ np.testing.assert_array_almost_equal(xd.vel, xd_mj.vel, 3)
+
def test_init_q(self):
sys = test_utils.load_fixture('ant.xml')
np.testing.assert_almost_equal(
@@ -53,10 +73,11 @@ def test_init_q(self):
def test_inverse(self, xml_file):
np.random.seed(0)
sys = test_utils.load_fixture(xml_file)
- # # test at random init
+ # test at random init
rand_q = np.random.rand(sys.init_q.shape[0])
- # normalize quaternion part of init_q
- rand_q[3:7] = rand_q[3:7] / np.linalg.norm(rand_q[3:7])
+ if sys.link_types[0] == 'f':
+ # normalize quaternion part of init_q
+ rand_q[3:7] = rand_q[3:7] / np.linalg.norm(rand_q[3:7])
rand_q = jp.array(rand_q)
rand_qd = jp.array(np.random.rand(sys.qd_size())) * 0.1
@@ -94,10 +115,11 @@ def _collect_frame(typ, motion):
),
),
armature=np.array([0.0, 0.0, 0.0, 0.0]),
- invweight=np.array([0.0, 0.0, 0.0, 0.0]),
stiffness=np.array([0.0, 0.0, 0.0, 0.0]),
damping=np.array([0.0, 0.0, 0.0, 0.0]),
limit=None,
+ invweight=np.array([0.0, 0.0, 0.0, 0.0]),
+ solver_params=np.zeros((4, 7)),
)
)
@@ -129,10 +151,11 @@ def _collect_frame(typ, motion):
]),
),
armature=np.array([0.0, 0.0, 0.0, 0.0]),
- invweight=np.array([0.0, 0.0, 0.0, 0.0]),
stiffness=np.array([0.0, 0.0, 0.0, 0.0]),
damping=np.array([0.0, 0.0, 0.0, 0.0]),
limit=None,
+ invweight=np.array([0.0, 0.0, 0.0, 0.0]),
+ solver_params=np.zeros((4, 7)),
)
)
diff --git a/brax/math.py b/brax/math.py
index da3f33fa..9725d4e8 100644
--- a/brax/math.py
+++ b/brax/math.py
@@ -40,6 +40,19 @@ def rotate(vec: jp.ndarray, quat: jp.ndarray):
return r
+def inv_rotate(vec: jp.ndarray, quat: jp.ndarray):
+ """Rotates a vector vec by the inverse of a unit quaternion quat.
+
+ Args:
+ vec: (3,) a vector
+ quat: (4,) a quaternion
+
+ Returns:
+ ndarray(3) containing vec rotated by the inverse of quat.
+ """
+ return rotate(vec, quat_inv(quat))
+
+
def rotate_np(vec: np.ndarray, quat: np.ndarray):
"""Rotates a vector vec by a unit quaternion quat.
diff --git a/brax/positional/perf_test.py b/brax/positional/perf_test.py
index 65d10a0c..ef1b4f47 100644
--- a/brax/positional/perf_test.py
+++ b/brax/positional/perf_test.py
@@ -34,7 +34,7 @@ def init_fn(rng):
return pipeline.init(sys, q, qd)
def step_fn(state):
- return pipeline.step(sys, state, jp.zeros(sys.qd_size()))
+ return pipeline.step(sys, state, jp.zeros(sys.act_size()))
test_utils.benchmark('pbd pipeline ant', init_fn, step_fn)
diff --git a/brax/positional/pipeline.py b/brax/positional/pipeline.py
index e00a0648..eb2fcdce 100644
--- a/brax/positional/pipeline.py
+++ b/brax/positional/pipeline.py
@@ -70,7 +70,7 @@ def step(
x_i_prev = state.x_i
# calculate acceleration level updates
- tau = actuator.to_tau(sys, act, state.q)
+ tau = actuator.to_tau(sys, act, state.q, state.qd)
xdd_i = Motion.create(vel=sys.gravity)
xdd_i += joints.acceleration_update(sys, state, tau)
diff --git a/brax/positional/pipeline_test.py b/brax/positional/pipeline_test.py
index 12a74709..82ed5d6c 100644
--- a/brax/positional/pipeline_test.py
+++ b/brax/positional/pipeline_test.py
@@ -16,8 +16,10 @@
"""Tests for spring physics pipeline."""
from absl.testing import absltest
+from brax import com
from brax import kinematics
from brax import test_utils
+from brax.base import Transform
from brax.generalized import pipeline as g_pipeline
from brax.positional import pipeline
import jax
@@ -36,7 +38,7 @@ def test_pendulum(self):
state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size()))
j_pos_step = jax.jit(pipeline.step)
for _ in range(2_000):
- state = j_pos_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_pos_step(sys, state, jp.zeros(sys.act_size()))
x = state.x
# compare against generalized step
@@ -45,7 +47,7 @@ def test_pendulum(self):
j_g_step = jax.jit(g_pipeline.step)
j_forward = jax.jit(kinematics.forward)
for _ in range(2_000):
- state = j_g_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_g_step(sys, state, jp.zeros(sys.act_size()))
x_g, _ = j_forward(sys, state.q, state.qd)
# trajectories should be close after .1 second of simulation
@@ -66,18 +68,24 @@ def test_spherical_pendulum(self):
sys = sys.replace(solver_iterations=500)
state = pipeline.init(sys, init_q, init_qd)
+ # the qd calculation for pbd/spring doesn't match generalized, so we get xd
+ # from generalized and plug it back into pbd
+ # TODO: remove this xd override once kinematics.forward is fixed
+ state_g = g_pipeline.init(sys, init_q, init_qd)
+ xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd)
+ state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1])
+
j_pos_step = jax.jit(pipeline.step)
for _ in range(1000):
- state = j_pos_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_pos_step(sys, state, jp.zeros(sys.act_size()))
x = state.x
# compare against generalized step
- q, qd = init_q, init_qd
- state = g_pipeline.init(sys, q, qd)
j_g_step = jax.jit(g_pipeline.step)
j_forward = jax.jit(kinematics.forward)
+ state = state_g
for _ in range(1000):
- state = j_g_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_g_step(sys, state, jp.zeros(sys.act_size()))
x_g, _ = j_forward(sys, state.q, state.qd)
# trajectories should be close after 1 second of simulation
@@ -97,7 +105,7 @@ def test_3d_sliding_joint(self):
j_pos_step = jax.jit(pipeline.step)
states = []
for _ in range(1000):
- state = j_pos_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_pos_step(sys, state, jp.zeros(sys.act_size()))
states.append(state)
x, xd = state.x, state.xd
@@ -124,7 +132,7 @@ def test_3d_prismaversal_joint(self):
j_pos_step = jax.jit(pipeline.step)
states = []
for _ in range(1000):
- state = j_pos_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_pos_step(sys, state, jp.zeros(sys.act_size()))
states.append(state)
# reflects off limits and is still traveling close to 2.5 m/s
@@ -148,7 +156,7 @@ def test_sliding_capsule(self):
state = pipeline.init(sys, sys.init_q, qd)
j_pos_step = jax.jit(pipeline.step)
for _ in range(1000):
- state = j_pos_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_pos_step(sys, state, jp.zeros(sys.act_size()))
x, xd = state.x, state.xd
# capsule slides to a stop
diff --git a/brax/scan.py b/brax/scan.py
index 98e03cda..96e3837f 100644
--- a/brax/scan.py
+++ b/brax/scan.py
@@ -191,56 +191,3 @@ def f(typ, *args) -> y
y = out_ys[0] if len(out_types) == 1 else out_ys
return y
-
-
-def actuator_types(
- sys: System, f: Callable[..., Y], in_types: str, out_type: str, *args
-) -> Y:
- r"""Scan a function over System actuator type ranges.
-
- Args:
- sys: system defining the kinematic tree and other properties
- f: a function to be scanned with the following type signature:\
- def f(typ, link, q, qd) -> y
- where
- ``typ`` is the actuator, link type string
- ``*args`` are input arguments with types matching ``in_types``
- ``y`` is an output arguments with types matching ``out_type``
- in_types: string specifying the type of each input arg:
- 'a' is an input to be split according to act ranges
- 'l' is an input to be split according to link ranges
- 'q' is an input to be split according to q ranges
- 'd' is an input to be split according to qd ranges
- out_type: string specifying the type of the output
- *args: the input arguments corresponding to ``in_types``
-
- Returns:
- The stacked outputs of ``f`` matching the system actuator order.
- """
- typ_order = sorted(set(sys.actuator_types), key=sys.actuator_types.find)
-
- typ_order_idxs = []
- for i, t in enumerate(sys.actuator_types):
- order = typ_order.index(t)
- while order >= len(typ_order_idxs):
- typ_order_idxs.append({'a': [], 'l': [], 'q': [], 'd': []})
- typ_order_idxs[order]['a'].append(i)
- typ_order_idxs[order]['l'].append(sys.actuator_link_id[i])
- typ_order_idxs[order]['q'].append(sys.actuator_qid[i])
- typ_order_idxs[order]['d'].append(sys.actuator_qdid[i])
-
- ys = []
- for typ, typ_idxs in zip(typ_order, typ_order_idxs):
- in_args = [_take(a, typ_idxs[t]) for a, t in zip(args, in_types)]
- ys.append(f(typ, *in_args))
-
- y = jax.tree_map(lambda *x: jp.concatenate(x), *ys)
-
- # we concatenated results out of order, so put back in order if needed
-
- order = sum([t[out_type] for t in typ_order_idxs], [])
-
- if order != list(range(len(order))):
- y = _take(y, [order.index(i) for i in range(len(order))])
-
- return y
diff --git a/brax/scan_test.py b/brax/scan_test.py
index df8aa346..4821276b 100644
--- a/brax/scan_test.py
+++ b/brax/scan_test.py
@@ -16,7 +16,6 @@
"""Tests for scan functions."""
from absl.testing import absltest
-from absl.testing import parameterized
from brax import scan
from brax import test_utils
import numpy as np
@@ -155,65 +154,5 @@ def f(typ, link, q, qd):
np.testing.assert_array_equal(qds[1], np.arange(6, 14))
-class ParametrizedScanTest(parameterized.TestCase):
-
- @parameterized.parameters(
- (
- 'single_spherical_pendulum_position.xml',
- ['p'],
- [0, 1, 2], # act_id
- [0, 0, 0], # act_link_id
- [2, 0, 1], # q_id
- [2, 0, 1], # qd_id
- ),
- (
- 'ant.xml',
- ['m'],
- list(range(8)),
- [7, 8, 1, 2, 3, 4, 5, 6],
- [13, 14, 7, 8, 9, 10, 11, 12],
- [12, 13, 6, 7, 8, 9, 10, 11],
- ),
- )
- def test_scan_actuator_types(
- self, fname, act_typs, act_id, act_link_id, q_id, qd_id
- ):
- """Test scanning actuators."""
- sys = test_utils.load_fixture(fname)
-
- typs, links, qs, qds = [], [], [], []
-
- def f(typ, act, link, q, qd):
- typs.append(typ)
- links.append(link)
- qs.append(q)
- qds.append(qd)
- return act
-
- out = scan.actuator_types(
- sys,
- f,
- 'alqd',
- 'a',
- np.arange(sys.act_size()),
- np.arange(sys.num_links()),
- np.arange(sys.q_size()),
- np.arange(sys.qd_size()),
- )
-
- self.assertSequenceEqual(typs, act_typs)
- np.testing.assert_array_equal(out, np.array(act_id))
-
- self.assertLen(links, 1)
- self.assertSequenceEqual(sys.actuator_link_id, act_link_id)
- np.testing.assert_array_equal(links[0], np.array(sys.actuator_link_id))
-
- self.assertLen(qs, 1)
- np.testing.assert_array_equal(qs[0], np.array(q_id))
-
- self.assertLen(qds, 1)
- np.testing.assert_array_equal(qds[0], np.array(qd_id))
-
-
if __name__ == '__main__':
absltest.main()
diff --git a/brax/spring/perf_test.py b/brax/spring/perf_test.py
index d7d7ec61..2b33e042 100644
--- a/brax/spring/perf_test.py
+++ b/brax/spring/perf_test.py
@@ -34,7 +34,7 @@ def init_fn(rng):
return pipeline.init(sys, q, qd)
def step_fn(state):
- return pipeline.step(sys, state, jp.zeros(sys.qd_size()))
+ return pipeline.step(sys, state, jp.zeros(sys.act_size()))
test_utils.benchmark('spring pipeline ant', init_fn, step_fn)
diff --git a/brax/spring/pipeline.py b/brax/spring/pipeline.py
index 3a4a8456..748f76ca 100644
--- a/brax/spring/pipeline.py
+++ b/brax/spring/pipeline.py
@@ -87,7 +87,7 @@ def step(
state = state.replace(i_inv=com.inv_inertia(sys, state.x))
# calculate acceleration and delta-velocity terms
- tau = actuator.to_tau(sys, act, state.q)
+ tau = actuator.to_tau(sys, act, state.q, state.qd)
xdd_i = joints.resolve(sys, state, tau) + Motion.create(vel=sys.gravity)
# semi-implicit euler: apply acceleration update before resolving collisions
state = state.replace(xd_i=state.xd_i + xdd_i * sys.dt)
diff --git a/brax/spring/pipeline_test.py b/brax/spring/pipeline_test.py
index 5043d75d..698107eb 100644
--- a/brax/spring/pipeline_test.py
+++ b/brax/spring/pipeline_test.py
@@ -16,8 +16,10 @@
"""Tests for spring physics pipeline."""
from absl.testing import absltest
+from brax import com
from brax import kinematics
from brax import test_utils
+from brax.base import Transform
from brax.generalized import pipeline as g_pipeline
from brax.spring import pipeline
import jax
@@ -37,7 +39,7 @@ def test_pendulum(self):
state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size()))
j_spring_step = jax.jit(pipeline.step)
for _ in range(10_000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
x = state.x
# compare against generalized step
@@ -46,7 +48,7 @@ def test_pendulum(self):
j_g_step = jax.jit(g_pipeline.step)
j_forward = jax.jit(kinematics.forward)
for _ in range(10_000):
- state = j_g_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_g_step(sys, state, jp.zeros(sys.act_size()))
x_g, _ = j_forward(sys, state.q, state.qd)
# trajectories should be close after 1 second of simulation
@@ -71,12 +73,11 @@ def test_universal_pendulum(self):
state = pipeline.init(sys, init_q, init_qd)
j_spring_step = jax.jit(pipeline.step)
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
x = state.x
# compare against generalized step
- q, qd = init_q, init_qd
- state = g_pipeline.init(sys, q, qd)
+ state = g_pipeline.init(sys, init_q, init_qd)
j_g_step = jax.jit(g_pipeline.step)
j_forward = jax.jit(kinematics.forward)
for _ in range(1000):
@@ -84,7 +85,7 @@ def test_universal_pendulum(self):
x_g, _ = j_forward(sys, state.q, state.qd)
# trajectories should be close after 1 second of simulation
- self.assertLess(jp.linalg.norm(x_g.rot - x.rot), 1.5e-2)
+ self.assertLess(jp.linalg.norm(x_g.rot - x.rot), 1.51e-2)
def test_spherical_pendulum(self):
sys = test_utils.load_fixture('single_spherical_pendulum.xml')
@@ -103,18 +104,24 @@ def test_spherical_pendulum(self):
sys = sys.replace(solver_iterations=500)
state = pipeline.init(sys, init_q, init_qd)
+ # the qd calculation for pbd/spring doesn't match generalized, so we get xd
+ # from generalized and plug it back into pbd
+ # TODO: remove this xd override once kinematics.forward is fixed
+ state_g = g_pipeline.init(sys, init_q, init_qd)
+ xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd)
+ state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1])
+
j_spring_step = jax.jit(pipeline.step)
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
x = state.x
# compare against generalized step
- q, qd = init_q, init_qd
- state = g_pipeline.init(sys, q, qd)
+ state = state_g
j_g_step = jax.jit(g_pipeline.step)
j_forward = jax.jit(kinematics.forward)
for _ in range(1000):
- state = j_g_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_g_step(sys, state, jp.zeros(sys.act_size()))
x_g, _ = j_forward(sys, state.q, state.qd)
# trajectories should be close after 1 second of simulation
@@ -140,7 +147,7 @@ def test_prismatic_joint(self):
state = pipeline.init(sys, init_q, init_qd)
j_spring_step = jax.jit(pipeline.step)
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
x = state.x
# compare against generalized step
@@ -172,7 +179,7 @@ def test_2d_sliding_joint(self):
j_spring_step = jax.jit(pipeline.step)
states = []
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
states.append(state)
x, xd = state.x, state.xd
@@ -201,7 +208,7 @@ def test_3d_sliding_joint(self):
j_spring_step = jax.jit(pipeline.step)
states = []
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
states.append(state)
x, xd = state.x, state.xd
@@ -228,7 +235,7 @@ def test_2d_prismaversal_joint(self):
j_spring_step = jax.jit(pipeline.step)
states = []
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
states.append(state)
# reflects off limits and is still traveling 2.5 m/s
@@ -251,7 +258,7 @@ def test_3d_prismaversal_joint(self):
j_spring_step = jax.jit(pipeline.step)
states = []
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
states.append(state)
# reflects off limits and is still traveling 2.5 m/s
@@ -272,7 +279,7 @@ def test_sliding_capsule(self):
state = pipeline.init(sys, sys.init_q, qd)
j_spring_step = jax.jit(pipeline.step)
for _ in range(1000):
- state = j_spring_step(sys, state, jp.zeros(sys.qd_size()))
+ state = j_spring_step(sys, state, jp.zeros(sys.act_size()))
x, xd = state.x, state.xd
# capsule slides to a stop
diff --git a/brax/test_data/fat_cylinder.xml b/brax/test_data/fat_cylinder.xml
new file mode 100644
index 00000000..e3fd71dc
--- /dev/null
+++ b/brax/test_data/fat_cylinder.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/flat_cylinder.xml b/brax/test_data/flat_cylinder.xml
new file mode 100644
index 00000000..266f5df3
--- /dev/null
+++ b/brax/test_data/flat_cylinder.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/fluid_box.xml b/brax/test_data/fluid_box.xml
new file mode 100644
index 00000000..59fa968f
--- /dev/null
+++ b/brax/test_data/fluid_box.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/fluid_box_offset_com.xml b/brax/test_data/fluid_box_offset_com.xml
new file mode 100644
index 00000000..38ef6fcc
--- /dev/null
+++ b/brax/test_data/fluid_box_offset_com.xml
@@ -0,0 +1,23 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/fluid_ellipsoid.xml b/brax/test_data/fluid_ellipsoid.xml
new file mode 100644
index 00000000..c52f6006
--- /dev/null
+++ b/brax/test_data/fluid_ellipsoid.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/fluid_wind.xml b/brax/test_data/fluid_wind.xml
new file mode 100644
index 00000000..27cdf2dc
--- /dev/null
+++ b/brax/test_data/fluid_wind.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_data/single_pendulum_velocity.xml b/brax/test_data/single_pendulum_velocity.xml
new file mode 100644
index 00000000..5eafb9ec
--- /dev/null
+++ b/brax/test_data/single_pendulum_velocity.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/brax/test_data/solver_params_v2.xml b/brax/test_data/solver_params_v2.xml
new file mode 100644
index 00000000..e52f0300
--- /dev/null
+++ b/brax/test_data/solver_params_v2.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/brax/test_utils.py b/brax/test_utils.py
index 05d6cf2c..6d2f3c35 100644
--- a/brax/test_utils.py
+++ b/brax/test_utils.py
@@ -37,10 +37,26 @@ def load_fixture_mujoco(path: str) -> mujoco.MjModel:
return model
+def _normalize_q(model: mujoco.MjModel, q: np.ndarray) -> np.ndarray:
+  """Returns a copy of q with every joint quaternion scaled to unit norm.
+
+  Args:
+    model: mujoco model whose jnt_type / jnt_qposadr describe the layout of q
+    q: joint position vector laid out like mujoco qpos
+
+  Returns:
+    a new array; the input q is not modified
+  """
+  q = np.array(q)
+  # jnt_qposadr is the first qpos index of each joint. A free joint (type 0)
+  # stores pos(3) + quat(4) and a ball joint (type 1) stores quat(4) only;
+  # slide and hinge joints carry no quaternion.
+  for typ, adr in zip(model.jnt_type, model.jnt_qposadr):
+    if typ in (0, 1):
+      start = adr + 3 if typ == 0 else adr
+      quat = q[start:start + 4]
+      q[start:start + 4] = quat / np.linalg.norm(quat)
+  return q
+
+
def sample_mujoco_states(
- path: str, count: int = 500, modulo: int = 20, force_pgs: bool = False
+ path: str, count: int = 500, modulo: int = 20, force_pgs: bool = False,
+ random_init: bool = False, random_q_scale: float = 1.0,
+ random_qd_scale: float = 0.1, vel_to_local: bool = True, seed: int = 42
) -> Iterable[Tuple[mujoco.MjData, mujoco.MjData]]:
"""Samples count / modulo states from mujoco for comparison."""
+ np.random.seed(seed)
model = load_fixture_mujoco(path)
model.opt.iterations = 50 # return to default for high-precision comparison
if force_pgs:
@@ -48,6 +64,10 @@ def sample_mujoco_states(
data = mujoco.MjData(model)
# give a little kick to avoid symmetry
data.qvel = np.random.uniform(low=-0.01, high=0.01, size=(model.nv,))
+  if random_init:
+    # size=(nq,) draws one value per coordinate; a bare positional argument
+    # would be taken as `low` and yield a single scalar for all of qpos
+    qpos = np.random.uniform(size=(model.nq,)) * random_q_scale
+    data.qpos = _normalize_q(model, qpos)
+    data.qvel = np.random.uniform(size=(model.nv,)) * random_qd_scale
for i in range(count):
before = copy.deepcopy(data)
mujoco.mj_step(model, data)
@@ -55,7 +75,8 @@ def sample_mujoco_states(
# hijack subtree_angmom, subtree_linvel (unused) to store xang, xvel
for i in range(model.nbody):
vel = np.zeros((6,))
- mujoco.mj_objectVelocity(model, data, 2, i, vel, 1)
+ mujoco.mj_objectVelocity(
+ model, data, mujoco.mjtObj.mjOBJ_XBODY.value, i, vel, vel_to_local)
data.subtree_angmom[i] = vel[:3]
data.subtree_linvel[i] = vel[3:]
yield before, data
diff --git a/brax/training/agents/ars/train.py b/brax/training/agents/ars/train.py
index 1d12ce3e..bee52124 100644
--- a/brax/training/agents/ars/train.py
+++ b/brax/training/agents/ars/train.py
@@ -141,8 +141,9 @@ def add_noise(params: Params, key: PRNGKey) -> Tuple[Params, Params, Params]:
noise = jax.tree_util.tree_map(
lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), params,
jax.tree_util.tree_unflatten(treedef, all_keys))
- params_with_noise = jax.tree_util.tree_map(lambda g, n: g + n * exploration_noise_std,
- params, noise)
+ params_with_noise = jax.tree_util.tree_map(
+ lambda g, n: g + n * exploration_noise_std, params, noise
+ )
params_with_anti_noise = jax.tree_util.tree_map(
lambda g, n: g - n * exploration_noise_std, params, noise)
return params_with_noise, params_with_anti_noise, noise
@@ -154,15 +155,20 @@ def training_epoch(training_state: TrainingState,
key: PRNGKey) -> Tuple[TrainingState, Metrics]:
params = jax.tree_util.tree_map(
lambda x: jnp.repeat(
- jnp.expand_dims(x, axis=0), number_of_directions, axis=0),
- training_state.policy_params)
+ jnp.expand_dims(x, axis=0), number_of_directions, axis=0
+ ),
+ training_state.policy_params,
+ )
key, key_noise, key_es_eval = jax.random.split(key, 3)
# generate perturbations
params_with_noise, params_with_anti_noise, noise = add_noise(
params, key_noise)
- pparams = jax.tree_util.tree_map(lambda a, b: jnp.concatenate([a, b], axis=0),
- params_with_noise, params_with_anti_noise)
+ pparams = jax.tree_util.tree_map(
+ lambda a, b: jnp.concatenate([a, b], axis=0),
+ params_with_noise,
+ params_with_anti_noise,
+ )
pparams = jax.tree_util.tree_map(
lambda x: jnp.reshape(x, (local_devices_to_use, -1) + x.shape[1:]),
@@ -187,13 +193,17 @@ def training_epoch(training_state: TrainingState,
reward_weight_double = jnp.concatenate([reward_weight, reward_weight],
axis=0)
reward_std = jnp.std(eval_scores, where=reward_weight_double)
+ reward_std += (reward_std == 0.0) * 1e-6
noise = jax.tree_util.tree_map(
lambda x: jnp.sum(
jnp.transpose(
- jnp.transpose(x) * reward_weight *
- (reward_plus - reward_minus)),
- axis=0), noise)
+ jnp.transpose(x) * reward_weight * (reward_plus - reward_minus)
+ ),
+ axis=0,
+ ),
+ noise,
+ )
policy_params = jax.tree_util.tree_map(
lambda x, y: x + step_size * y / (top_directions * reward_std),
diff --git a/brax/training/agents/ars/train_test.py b/brax/training/agents/ars/train_test.py
index 750d33e2..229642ad 100644
--- a/brax/training/agents/ars/train_test.py
+++ b/brax/training/agents/ars/train_test.py
@@ -21,7 +21,6 @@
from brax.training.acme import running_statistics
from brax.training.agents.ars import networks as ars_networks
from brax.training.agents.ars import train as ars
-from brax.v1 import envs as envs_v1
import jax
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index 051dfc3d..b4b4d9f6 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -314,8 +314,7 @@ def training_step(
(training_state, training_key),
transitions)
- metrics['buffer_current_size'] = buffer_state.current_size
- metrics['buffer_current_position'] = buffer_state.current_position
+ metrics['buffer_current_size'] = replay_buffer.size(buffer_state)
return training_state, env_state, buffer_state, metrics
def prefill_replay_buffer(
diff --git a/brax/training/learner.py b/brax/training/learner.py
index 9d2e54d4..d85caf6a 100644
--- a/brax/training/learner.py
+++ b/brax/training/learner.py
@@ -118,8 +118,7 @@
'Std of a random noise added by ARS.')
flags.DEFINE_float('reward_shift', 0.,
'A reward shift to get rid of "stay alive" bonus.')
-flags.DEFINE_enum('head_type', '', ['', 'clip', 'tanh'],
- 'Which policy head to use.')
+
# ARS hps.
flags.DEFINE_integer('truncation_length', None,
'Truncation for gradient propagation in APG.')
diff --git a/brax/training/replay_buffers.py b/brax/training/replay_buffers.py
index 891efac2..1fe18b6d 100644
--- a/brax/training/replay_buffers.py
+++ b/brax/training/replay_buffers.py
@@ -71,8 +71,8 @@ class ReplayBufferState:
"""Contains data related to a replay buffer."""
data: jnp.ndarray
- current_position: jnp.ndarray
- current_size: jnp.ndarray
+ insert_position: jnp.ndarray
+ sample_position: jnp.ndarray
key: PRNGKey
@@ -109,14 +109,14 @@ def __init__(
def init(self, key: PRNGKey) -> ReplayBufferState:
return ReplayBufferState(
data=jnp.zeros(self._data_shape, self._data_dtype),
- current_size=jnp.zeros((), jnp.int32),
- current_position=jnp.zeros((), jnp.int32),
+ sample_position=jnp.zeros((), jnp.int32),
+ insert_position=jnp.zeros((), jnp.int32),
key=key,
)
def check_can_insert(self, buffer_state, samples, shards):
"""Checks whether insert operation can be performed."""
- assert isinstance(shards, int), "This method should not be JITed."
+ assert isinstance(shards, int), 'This method should not be JITed.'
insert_size = jax.tree_flatten(samples)[0][0].shape[0] // shards
if self._data_shape[0] < insert_size:
raise ValueError(
@@ -141,14 +141,15 @@ def insert_internal(
if buffer_state.data.shape != self._data_shape:
raise ValueError(
f'buffer_state.data.shape ({buffer_state.data.shape}) '
- f'doesn\'t match the expected value ({self._data_shape})')
+ f"doesn't match the expected value ({self._data_shape})"
+ )
update = self._flatten_fn(samples)
data = buffer_state.data
# If needed, roll the buffer to make sure there's enough space to fit
# `update` after the current position.
- position = buffer_state.current_position
+ position = buffer_state.insert_position
roll = jnp.minimum(0, len(data) - position - len(update))
data = jax.lax.cond(
roll, lambda: jnp.roll(data, roll, axis=0), lambda: data
@@ -157,11 +158,13 @@ def insert_internal(
# Update the buffer and the control numbers.
data = jax.lax.dynamic_update_slice_in_dim(data, update, position, axis=0)
- position = (position + len(update)) % len(data)
- size = jnp.minimum(buffer_state.current_size + len(update), len(data))
+ position = (position + len(update)) % (len(data) + 1)
+ sample_position = jnp.maximum(0, buffer_state.sample_position + roll)
return buffer_state.replace(
- data=data, current_position=position, current_size=size
+ data=data,
+ insert_position=position,
+ sample_position=sample_position,
)
def sample_internal(
@@ -170,21 +173,46 @@ def sample_internal(
raise NotImplementedError(f'{self.__class__}.sample() is not implemented.')
def size(self, buffer_state: ReplayBufferState) -> int:
- return buffer_state.current_size # pytype: disable=bad-return-type # jax-ndarray
+ return buffer_state.insert_position - buffer_state.sample_position # pytype: disable=bad-return-type # jax-ndarray
class Queue(QueueBase[Sample], Generic[Sample]):
"""Implements a limited-size queue replay buffer."""
+  def __init__(
+      self,
+      max_replay_size: int,
+      dummy_data_sample: Sample,
+      sample_batch_size: int,
+      cyclic: bool = False,
+  ):
+    """Initializes the queue.
+
+    Args:
+      max_replay_size: Maximum number of elements the queue can hold.
+      dummy_data_sample: Example record to be stored in the queue; it is used
+        to derive shapes.
+      sample_batch_size: How many elements sampling from the queue should
+        return in a batch.
+      cyclic: Whether sampling from the queue behaves cyclically, i.e. once
+        the most recently inserted element has been sampled, sampling starts
+        again from the beginning of the buffer. For example, if the current
+        queue content is [0, 1, 2] and `sample_batch_size` is 2, then
+        consecutive calls to sample will give: [0, 1], [2, 0], [1, 2]...
+    """
+    super().__init__(max_replay_size, dummy_data_sample, sample_batch_size)
+    self._cyclic = cyclic
+
def check_can_sample(self, buffer_state, shards):
"""Checks whether sampling can be performed. Do not JIT this method."""
- assert isinstance(shards, int), "This method should not be JITed."
+ assert isinstance(shards, int), 'This method should not be JITed.'
if self._size < self._sample_batch_size:
raise ValueError(
f'Trying to sample {self._sample_batch_size * shards} elements, but'
f' only {self._size * shards} available.'
)
- self._size -= self._sample_batch_size
+ if not self._cyclic:
+ self._size -= self._sample_batch_size
def sample_internal(
self, buffer_state: ReplayBufferState
@@ -205,34 +233,24 @@ def sample_internal(
# Note that this may be out of bound, but the operations below would still
# work fine as they take this number modulo the buffer size.
- first_element_idx = (
- buffer_state.current_position - buffer_state.current_size
- )
- idx = jnp.arange(self._sample_batch_size) + first_element_idx
+ idx = (jnp.arange(self._sample_batch_size) + buffer_state.sample_position) % buffer_state.insert_position
flat_batch = jnp.take(buffer_state.data, idx, axis=0, mode='wrap')
- # TODO: Raise an error instead of padding with zeros
- # when the buffer does not contain enough elements.
- # If the sample batch size is larger than the number of elements in the
- # queue, `mask` would contain 0s for all elements that are past the current
- # position. Otherwise, `mask` will be only ones.
- # mask.shape = (self._sample_batch_size,)
- mask = idx < buffer_state.current_position
- # mask.shape = (self._sample_batch_size, 1)
- mask = jnp.expand_dims(mask, axis=range(1, flat_batch.ndim))
- flat_batch = flat_batch * mask
-
- # The effective size of the sampled batch.
- sample_size = jnp.minimum(
- self._sample_batch_size, buffer_state.current_size
- )
# Remove the sampled batch from the queue.
- new_state = buffer_state.replace(
- current_size=buffer_state.current_size - sample_size
- )
+ sample_position = buffer_state.sample_position + self._sample_batch_size
+ if self._cyclic:
+ sample_position = sample_position % buffer_state.insert_position
+
+ new_state = buffer_state.replace(sample_position=sample_position)
return new_state, self._unflatten_fn(flat_batch)
+  def size(self, buffer_state: ReplayBufferState) -> int:
+    """Counts the elements currently available for sampling."""
+    inserted = buffer_state.insert_position
+    if not self._cyclic:
+      # non-cyclic sampling advances sample_position past consumed elements,
+      # so only the span between the two positions is still available
+      return inserted - buffer_state.sample_position  # pytype: disable=bad-return-type  # jax-ndarray
+    return inserted  # pytype: disable=bad-return-type  # jax-ndarray
+
class UniformSamplingQueue(QueueBase[Sample], Generic[Sample]):
"""Implements an uniform sampling limited-size replay queue.
@@ -257,8 +275,8 @@ def sample_internal(
idx = jax.random.randint(
sample_key,
(self._sample_batch_size,),
- minval=buffer_state.current_position - buffer_state.current_size,
- maxval=buffer_state.current_position,
+ minval=buffer_state.sample_position,
+ maxval=buffer_state.insert_position,
)
batch = jnp.take(buffer_state.data, idx, axis=0, mode='wrap')
return buffer_state.replace(key=key), self._unflatten_fn(batch)
diff --git a/brax/training/replay_buffers_test.py b/brax/training/replay_buffers_test.py
index e023a958..9aa11cc0 100644
--- a/brax/training/replay_buffers_test.py
+++ b/brax/training/replay_buffers_test.py
@@ -299,9 +299,9 @@ def testUniformSamplingQueueCyclicSample(self, wrapper):
)
assert_equal(self, replay_buffer.size(buffer_state), 10)
if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
- assert_equal(self, buffer_state.current_position, 0)
+ assert_equal(self, buffer_state.insert_position, 10)
else:
- assert_equal(self, buffer_state.current_position, [5, 5])
+ assert_equal(self, buffer_state.insert_position, [5, 5])
@parameterized.parameters(WRAPPERS)
def testQueueSamplePyTree(self, wrapper):
@@ -357,14 +357,14 @@ def testQueueSample(self, wrapper):
buffer_state.data,
[[0], [1], [2], [3], [0], [0], [0], [0], [0], [0]],
)
- assert_equal(self, buffer_state.current_position, 4)
+ assert_equal(self, buffer_state.insert_position, 4)
else:
assert_equal(
self,
buffer_state.data,
[[[0], [2], [0], [0], [0]], [[1], [3], [0], [0], [0]]],
)
- assert_equal(self, buffer_state.current_position, [2, 2])
+ assert_equal(self, buffer_state.insert_position, [2, 2])
assert_equal(self, replay_buffer.size(buffer_state), 4)
buffer_state = replay_buffer.insert(buffer_state, jnp.arange(4, 10))
@@ -374,14 +374,14 @@ def testQueueSample(self, wrapper):
buffer_state.data,
[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
)
- assert_equal(self, buffer_state.current_position, [0, 0])
+ assert_equal(self, buffer_state.insert_position, 10)
else:
assert_equal(
self,
buffer_state.data,
[[[0], [2], [4], [6], [8]], [[1], [3], [5], [7], [9]]],
)
- assert_equal(self, buffer_state.current_position, 0)
+ assert_equal(self, buffer_state.insert_position, [5, 5])
assert_equal(self, replay_buffer.size(buffer_state), 10)
buffer_state, samples = replay_buffer.sample(buffer_state)
@@ -402,6 +402,92 @@ def testQueueSample(self, wrapper):
assert_equal(self, samples, [8, 9, 20, 21])
assert_equal(self, replay_buffer.size(buffer_state), 2)
+  @parameterized.parameters(WRAPPERS)
+  def testCyclicQueueSample(self, wrapper):
+    """Cyclic sampling wraps around the buffer and never shrinks its size."""
+    mesh = get_mesh()
+    size_denominator = (
+        1 if wrapper in [no_wrap, jit_wrap] else mesh.shape[AXIS_NAME]
+    )
+    replay_buffer = wrapper(
+        replay_buffers.Queue(
+            max_replay_size=10 // size_denominator,
+            dummy_data_sample=0,
+            sample_batch_size=4 // size_denominator,
+            cyclic=True,
+        )
+    )
+    rng = jax.random.PRNGKey(0)
+
+    buffer_state = replay_buffer.init(rng)
+
+    # Partially fill the buffer: 6 of 10 slots.
+    buffer_state = replay_buffer.insert(buffer_state, jnp.arange(6))
+    if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
+      assert_equal(
+          self,
+          buffer_state.data,
+          [[0], [1], [2], [3], [4], [5], [0], [0], [0], [0]],
+      )
+      assert_equal(self, buffer_state.insert_position, 6)
+    else:
+      assert_equal(
+          self,
+          buffer_state.data,
+          [[[0], [2], [4], [0], [0]], [[1], [3], [5], [0], [0]]],
+      )
+      assert_equal(self, buffer_state.insert_position, [3, 3])
+    assert_equal(self, replay_buffer.size(buffer_state), 6)
+
+    buffer_state, samples = replay_buffer.sample(buffer_state)
+    assert_equal(self, samples, [0, 1, 2, 3])
+    assert_equal(self, replay_buffer.size(buffer_state), 6)
+
+    # The second batch runs past the last inserted element and wraps around.
+    buffer_state, samples = replay_buffer.sample(buffer_state)
+    assert_equal(self, samples, [4, 5, 0, 1])
+    assert_equal(self, replay_buffer.size(buffer_state), 6)
+
+    # Fill the buffer completely.
+    buffer_state = replay_buffer.insert(buffer_state, jnp.arange(6, 10))
+    if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
+      assert_equal(
+          self,
+          buffer_state.data,
+          [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
+      )
+      assert_equal(self, buffer_state.insert_position, 10)
+      assert_equal(self, buffer_state.sample_position, 2)
+    else:
+      assert_equal(
+          self,
+          buffer_state.data,
+          [[[0], [2], [4], [6], [8]], [[1], [3], [5], [7], [9]]],
+      )
+      assert_equal(self, buffer_state.insert_position, [5, 5])
+      assert_equal(self, buffer_state.sample_position, [1, 1])
+    assert_equal(self, replay_buffer.size(buffer_state), 10)
+
+    buffer_state, samples = replay_buffer.sample(buffer_state)
+    assert_equal(self, samples, [2, 3, 4, 5])
+    assert_equal(self, replay_buffer.size(buffer_state), 10)
+
+    buffer_state, samples = replay_buffer.sample(buffer_state)
+    assert_equal(self, samples, [6, 7, 8, 9])
+    assert_equal(self, replay_buffer.size(buffer_state), 10)
+
+    # Inserting into a full buffer evicts the oldest elements.
+    buffer_state = replay_buffer.insert(buffer_state, jnp.arange(20, 24))
+    assert_equal(self, replay_buffer.size(buffer_state), 10)
+
+    buffer_state, samples = replay_buffer.sample(buffer_state)
+    assert_equal(self, samples, [4, 5, 6, 7])
+    assert_equal(self, replay_buffer.size(buffer_state), 10)
+
+    buffer_state, samples = replay_buffer.sample(buffer_state)
+    assert_equal(self, samples, [8, 9, 20, 21])
+    assert_equal(self, replay_buffer.size(buffer_state), 10)
+
+    buffer_state, samples = replay_buffer.sample(buffer_state)
+    assert_equal(self, samples, [22, 23, 4, 5])
+    assert_equal(self, replay_buffer.size(buffer_state), 10)
+
@parameterized.parameters(WRAPPERS)
def testQueueInsertWhenFull(self, wrapper):
mesh = get_mesh()
@@ -426,9 +512,9 @@ def testQueueInsertWhenFull(self, wrapper):
assert_equal(
self,
buffer_state.data,
- [[10], [11], [2], [3], [4], [5], [6], [7], [8], [9]],
+ [[2], [3], [4], [5], [6], [7], [8], [9], [10], [11]],
)
- assert_equal(self, buffer_state.current_position, 2)
+ assert_equal(self, buffer_state.insert_position, 10)
assert_equal(self, replay_buffer.size(buffer_state), 10)
@parameterized.parameters(WRAPPERS)
@@ -453,13 +539,13 @@ def testQueueWrappedSample(self, wrapper):
assert_equal(
self,
buffer_state.data,
- [[10], [11], [12], [13], [14], [15], [6], [7], [8], [9]],
+ [[6], [7], [8], [9], [10], [11], [12], [13], [14], [15]],
)
assert_equal(self, replay_buffer.size(buffer_state), 10)
if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
- assert_equal(self, buffer_state.current_position, 6)
+ assert_equal(self, buffer_state.insert_position, 10)
else:
- assert_equal(self, buffer_state.current_position, [3, 3])
+ assert_equal(self, buffer_state.insert_position, [5, 5])
# This sample contains elements from both the beggining and the end of
# the buffer.
@@ -467,9 +553,9 @@ def testQueueWrappedSample(self, wrapper):
assert_equal(self, samples, jnp.array([6, 7, 8, 9, 10, 11, 12, 13]))
assert_equal(self, replay_buffer.size(buffer_state), 2)
if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
- assert_equal(self, buffer_state.current_position, 6)
+ assert_equal(self, buffer_state.insert_position, 10)
else:
- assert_equal(self, buffer_state.current_position, [3, 3])
+ assert_equal(self, buffer_state.insert_position, [5, 5])
@parameterized.parameters(WRAPPERS)
def testQueueBatchSizeEqualsMaxSize(self, wrapper):
@@ -493,7 +579,10 @@ def testQueueBatchSizeEqualsMaxSize(self, wrapper):
buffer_state = replay_buffer.insert(buffer_state, jnp.arange(batch_size))
buffer_state, samples = replay_buffer.sample(buffer_state)
assert_equal(self, samples, range(batch_size))
- assert_equal(self, buffer_state.current_size, 0)
+ if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
+ assert_equal(self, buffer_state.sample_position, 8)
+ else:
+ assert_equal(self, buffer_state.sample_position, [4, 4])
buffer_state = replay_buffer.insert(
buffer_state, jnp.zeros(batch_size, dtype=jnp.int32)
@@ -503,7 +592,10 @@ def testQueueBatchSizeEqualsMaxSize(self, wrapper):
)
buffer_state, samples = replay_buffer.sample(buffer_state)
assert_equal(self, samples, [1] * batch_size)
- assert_equal(self, buffer_state.current_size, 0)
+ if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
+ assert_equal(self, buffer_state.sample_position, 8)
+ else:
+ assert_equal(self, buffer_state.sample_position, [4, 4])
@parameterized.parameters(WRAPPERS)
def testQueueSampleFromEmpty(self, wrapper) -> None:
@@ -526,19 +618,25 @@ def testQueueSampleFromEmpty(self, wrapper) -> None:
buffer_state = replay_buffer.insert(buffer_state, jnp.arange(batch_size))
buffer_state, samples = replay_buffer.sample(buffer_state)
assert_equal(self, samples, range(batch_size))
- assert_equal(self, buffer_state.current_size, 0)
+ if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
+ assert_equal(self, buffer_state.sample_position, 10)
+ else:
+ assert_equal(self, buffer_state.sample_position, [5, 5])
with self.assertRaisesRegex(
ValueError, 'Trying to sample 10 elements, but only 0 available.'
):
replay_buffer.sample(buffer_state)
- assert_equal(self, buffer_state.current_size, 0)
+ if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
+ assert_equal(self, buffer_state.sample_position, 10)
+ else:
+ assert_equal(self, buffer_state.sample_position, [5, 5])
buffer_state = replay_buffer.insert(buffer_state, jnp.arange(10, 14))
if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]:
- assert_equal(self, buffer_state.current_size, 4)
+ assert_equal(self, buffer_state.sample_position, 6)
else:
- assert_equal(self, buffer_state.current_size, [2, 2])
+ assert_equal(self, buffer_state.sample_position, [3, 3])
with self.assertRaisesRegex(
ValueError, 'Trying to sample 10 elements, but only 4 available.'
):
diff --git a/brax/visualizer/js/system.js b/brax/visualizer/js/system.js
index 713231d1..15280c64 100644
--- a/brax/visualizer/js/system.js
+++ b/brax/visualizer/js/system.js
@@ -51,6 +51,16 @@ function getMeshAxisSize(geom) {
return size * 2;
}
+function createCylinder(radius, height, mat) {
+  // Builds a cylinder mesh with equal top and bottom radii and 32 radial
+  // segments. The material passed in is switched to double-sided rendering.
+  mat.side = THREE.DoubleSide;
+  const mesh = new THREE.Mesh(
+      new THREE.CylinderGeometry(radius, radius, height, 32), mat);
+  mesh.baseMaterial = mesh.material;
+  mesh.castShadow = true;
+  mesh.layers.enable(1);
+  return mesh;
+}
+
function createCapsule(capsule, mat) {
const sphere_geom = new THREE.SphereGeometry(capsule.radius, 16, 16);
const cylinder_geom = new THREE.CylinderGeometry(
@@ -80,7 +90,7 @@ function createCapsule(capsule, mat) {
}
function createBox(box, mat) {
- const geom = new THREE.BoxBufferGeometry(
+ const geom = new THREE.BoxGeometry(
2 * box.halfsize[0], 2 * box.halfsize[1], 2 * box.halfsize[2]);
const mesh = new THREE.Mesh(geom, mat);
mesh.castShadow = true;
@@ -141,6 +151,7 @@ function createScene(system) {
// Add a world axis for debugging.
const worldAxis = new THREE.AxesHelper(100);
+ const qRotx90 = new THREE.Quaternion(0.70710677, 0.0, 0.0, 0.7071067);
worldAxis.visible = false;
scene.add(worldAxis);
@@ -176,6 +187,9 @@ function createScene(system) {
} else if (collider.name == 'Mesh') {
child = createMesh(collider, mat);
axisSize = getMeshAxisSize(collider);
+ } else if (collider.name == 'Cylinder') {
+ child = createCylinder(collider.radius, collider.length, mat);
+ axisSize = 2 * Math.max(collider.radius, collider.length);
} else if ('clippedPlane' in collider) {
console.log('clippedPlane not implemented');
return;
@@ -184,9 +198,13 @@ function createScene(system) {
return;
}
if (collider.transform.rot) {
- child.quaternion.set(
- collider.transform.rot[1], collider.transform.rot[2],
- collider.transform.rot[3], collider.transform.rot[0]);
+ const quat = new THREE.Quaternion(
+ collider.transform.rot[1], collider.transform.rot[2],
+ collider.transform.rot[3], collider.transform.rot[0]);
+ if (collider.name == 'Cylinder') {
+ quat.multiply(qRotx90)
+ }
+ child.quaternion.fromArray(quat.toArray());
}
if (collider.transform.pos) {
child.position.set(