diff --git a/MANIFEST.in b/MANIFEST.in index d82985f0..0e6fec3d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ include brax/envs/assets/*.xml recursive-include brax/test_data *.xml *.stl *.obj *.urdf -recursive-include brax/visualizer * +recursive-include brax/visualizer * \ No newline at end of file diff --git a/brax/actuator.py b/brax/actuator.py index 192b99a6..39291c97 100644 --- a/brax/actuator.py +++ b/brax/actuator.py @@ -15,18 +15,20 @@ # pylint:disable=g-multiple-import """Functions for applying actuators to a physics pipeline.""" -from brax import scan from brax.base import System from jax import numpy as jp -def to_tau(sys: System, act: jp.ndarray, q: jp.ndarray) -> jp.ndarray: +def to_tau( + sys: System, act: jp.ndarray, q: jp.ndarray, qd: jp.ndarray +) -> jp.ndarray: """Convert actuator to a joint force tau. Args: sys: system defining the kinematic tree and other properties act: (act_size,) actuator force input vector q: joint position vector + qd: joint velocity vector Returns: tau: (qd_size,) vector of joint forces @@ -34,21 +36,12 @@ def to_tau(sys: System, act: jp.ndarray, q: jp.ndarray) -> jp.ndarray: if sys.act_size() == 0: return jp.zeros(sys.qd_size()) - def act_fn(act_type, act, actuator, q, qd_idx): - if act_type not in ('p', 'm'): - raise RuntimeError(f'unrecognized act type: {act_type}') - - force = jp.clip(act, actuator.ctrl_range[:, 0], actuator.ctrl_range[:, 1]) - if act_type == 'p': - force -= q # positional actuators have a bias - tau = actuator.gear * force - - return tau, qd_idx - - qd_idx = jp.arange(sys.qd_size()) - tau, qd_idx = scan.actuator_types( - sys, act_fn, 'aaqd', 'a', act, sys.actuator, q, qd_idx - ) - tau = jp.zeros(sys.qd_size()).at[qd_idx].add(tau) + q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id] + ctrl_range = sys.actuator.ctrl_range + act = jp.clip(act, ctrl_range[:, 0], ctrl_range[:, 1]) + # TODO: incorporate gain + act = act + q * sys.actuator.bias_q + qd * sys.actuator.bias_qd + act_force = sys.actuator.gear * act + tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(act_force) return tau diff --git a/brax/actuator_test.py b/brax/actuator_test.py index c2e86508..a5089512 100644 --- a/brax/actuator_test.py +++ b/brax/actuator_test.py @@ -55,7 +55,7 @@ def test_motor(self, pipeline, dt, n, decimal): tau = jp.array([0.5 * 9.81]) # -mgl sin(theta) act = jp.array([1.0 / 150.0 * 0.5 * 9.81]) - tau2 = actuator.to_tau(sys, act, q) + tau2 = actuator.to_tau(sys, act, q, qd) np.testing.assert_array_almost_equal(tau, tau2, 5) q2, qd2 = _actuator_step(pipeline, sys, q, qd, act=act, dt=dt, n=n) @@ -73,7 +73,7 @@ def test_position(self): # a position actuator at the bottom should not move the pendulum act = jp.array([theta]) - tau = actuator.to_tau(sys, act, q) + tau = actuator.to_tau(sys, act, q, qd) np.testing.assert_array_almost_equal(tau, jp.array([0]), 5) # put the pendulum into the horizontal position with the positional actuator @@ -83,6 +83,24 @@ def test_position(self): q2, _ = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.01, n=1) np.testing.assert_array_almost_equal(q2, jp.array([0]), 1) + def test_velocity(self): + """Tests a single pendulum with velocity actuator.""" + sys = test_utils.load_fixture('single_pendulum_velocity.xml') + mj_model = test_utils.load_fixture_mujoco('single_pendulum_velocity.xml') + mj_data = mujoco.MjData(mj_model) + q, qd = jp.array(mj_data.qpos), jp.array(mj_data.qvel) + theta = jp.pi / 2.0 # pendulum is vertical + q = jp.array([theta]) + + act = jp.array([0]) + tau = actuator.to_tau(sys, act, q, qd) + np.testing.assert_array_almost_equal(tau, jp.array([0]), 5) + + # set the act to rotate at 1/s + act = jp.array([1]) + _, qd = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.001, n=200) + np.testing.assert_array_almost_equal(qd, jp.array([1]), 3) + @parameterized.parameters((g_pipeline,), (s_pipeline,), (p_pipeline,)) def test_three_link_pendulum(self, pipeline): """Tests a three link pendulum with a motor actuator.""" diff --git a/brax/base.py b/brax/base.py index b886ecae..f3c78c02 100644 --- a/brax/base.py +++ b/brax/base.py @@ -269,7 +269,7 @@ class Link(Base): constraint_stiffness: jp.ndarray constraint_vel_damping: jp.ndarray constraint_limit_stiffness: jp.ndarray - # only used by `brax.physics.spring` and `brax.physics.pbd`: + # only used by `brax.physics.spring` and `brax.physics.positional`: constraint_ang_damping: jp.ndarray @@ -284,6 +284,7 @@ class DoF(Base): damping: restorative force back to zero velocity limit: tuple of min, max angle limits invweight: diagonal inverse inertia at init_qpos + solver_params: (7,) limit constraint solver parameters """ motion: Motion @@ -293,6 +294,7 @@ class DoF(Base): limit: Tuple[jp.ndarray, jp.ndarray] # only used by `brax.physics.generalized`: invweight: jp.ndarray + solver_params: jp.ndarray @struct.dataclass @@ -305,12 +307,14 @@ class Geometry(Base): relative to the world frame in the case of unparented geometry friction: resistance encountered when sliding against another geometry elasticity: bounce/restitution encountered when hitting another geometry + solver_params: (7,) solver parameters (reference, impedance) """ link_idx: Optional[jp.ndarray] transform: Transform friction: jp.ndarray elasticity: jp.ndarray + solver_params: jp.ndarray @struct.dataclass @@ -354,6 +358,21 @@ class Box(Geometry): rgba: Optional[jp.ndarray] = None +@struct.dataclass +class Cylinder(Geometry): + """A cylinder. + + Attributes: + radius: (1,) radius of the top and bottom of the cylinder + length: (1,) length of the cylinder + rgba: (4,) the rgba to display in the renderer + """ + + radius: jp.ndarray + length: jp.ndarray + rgba: Optional[jp.ndarray] = None + + @struct.dataclass class Plane(Geometry): """An infinite plane whose normal points at +z in its coordinate space. @@ -410,6 +429,7 @@ class Contact(Base): two geometries are interpenetrating, negative means they are not friction: resistance encountered when sliding against another geometry elasticity: bounce/restitution encountered when hitting another geometry + solver_params: (7,) collision constraint solver parameters link_idx: Tuple of link indices participating in contact. The second part of the tuple can be None if the second geometry is static. """ @@ -418,8 +438,9 @@ class Contact(Base): normal: jp.ndarray penetration: jp.ndarray friction: jp.ndarray - # only used by `brax.physics.spring` and `brax.physics.pbd`: + # only used by `brax.physics.spring` and `brax.physics.positional`: elasticity: jp.ndarray + solver_params: jp.ndarray link_idx: Tuple[jp.ndarray, Optional[jp.ndarray]] @@ -429,13 +450,20 @@ class Actuator(Base): """Actuator, transforms an input signal into a force (motor or thruster). Attributes: - ctrl_range: (num_actuators, 2) control range for each actuator - gear: (num_actuators,) a list of floats used as a scaling factor for each - actuator torque output + q_id: (num_actuators,) q index associated with an actuator + qd_id: (num_actuators,) qd index associated with an actuator + ctrl_range: (num_actuators, 2) actuator control range + gear: (num_actuators,) scaling factor for each actuator torque output + bias_q: (num_actuators,) bias applied by q (e.g. position actuators) + bias_qd: (num_actuators,) bias applied by qd (e.g. velocity actuators) """ + q_id: jp.ndarray + qd_id: jp.ndarray ctrl_range: jp.ndarray gear: jp.ndarray + bias_q: jp.ndarray + bias_qd: jp.ndarray @struct.dataclass @@ -464,13 +492,13 @@ class System: Attributes: dt: timestep used for the simulation gravity: (3,) linear universal force applied during forward dynamics + viscosity: (1,) viscosity of the medium applied to all links + density: (1,) density of the medium applied to all links link: (num_link,) the links in the system dof: (qd_size,) every degree of freedom for the system geoms: list of batched geoms grouped by type actuator: actuators that can be applied to links init_q: (q_size,) initial q position for the system - solver_params_joint: (7,) joint limit constraint solver parameters - solver_params_contact: (7,) collision constraint solver parameters vel_damping: (1,) linear vel damping applied to each body. ang_damping: (1,) angular vel damping applied to each body. baumgarte_erp: how aggressively interpenetrating bodies should push away\ @@ -492,12 +520,6 @@ class System: * '3': spherical, 3 dof, like a ball joint link_parents: (num_link,) int list specifying the index of each link's parent link, or -1 if the link has no parent - actuator_types: (num_actuators,) string specifying the actuator types: - * 't': torque - * 'p': position - actuator_link_id: (num_actuators,) the link id associated with each actuator - actuator_qid: (num_actuators,) the q index associated with each actuator - actuator_qdid: (num_actuators,) the qd index associated with each actuator matrix_inv_iterations: maximum number of iterations of the matrix inverse solver_iterations: maximum number of iterations of the constraint solver solver_maxls: maximum number of line searches of the constraint solver @@ -505,33 +527,28 @@ class System: dt: jp.ndarray gravity: jp.ndarray + viscosity: jp.float32 + density: jp.float32 link: Link dof: DoF geoms: List[Geometry] actuator: Actuator init_q: jp.ndarray - # only used in `brax.physics.generalized` - solver_params_joint: jp.ndarray - solver_params_contact: jp.ndarray - # only used in `brax.physics.spring` and `brax.physics.pbd`: + # only used in `brax.physics.spring` and `brax.physics.positional`: vel_damping: jp.float32 ang_damping: jp.float32 baumgarte_erp: jp.float32 spring_mass_scale: jp.float32 spring_inertia_scale: jp.float32 - # only used in `brax.physics.positional` + # only used in `brax.physics.positional`: joint_scale_ang: jp.float32 joint_scale_pos: jp.float32 collide_scale: jp.float32 - + # non-pytree nodes geom_masks: List[int] = struct.field(pytree_node=False) link_names: List[str] = struct.field(pytree_node=False) link_types: str = struct.field(pytree_node=False) link_parents: Tuple[int, ...] = struct.field(pytree_node=False) - actuator_types: str = struct.field(pytree_node=False) - actuator_link_id: List[int] = struct.field(pytree_node=False) - actuator_qid: List[int] = struct.field(pytree_node=False) - actuator_qdid: List[int] = struct.field(pytree_node=False) # only used in `brax.physics.generalized`: matrix_inv_iterations: int = struct.field(pytree_node=False) solver_iterations: int = struct.field(pytree_node=False) @@ -595,7 +612,7 @@ def qd_size(self) -> int: def act_size(self) -> int: """Returns the act dimension for the system.""" - return sum({'m': 1, 'p': 1}[act_typ] for act_typ in self.actuator_types) + return self.actuator.q_id.shape[0] # below are some operation dispatch derivations diff --git a/brax/envs/humanoid.py b/brax/envs/humanoid.py index 4d32d8ec..d2bf0590 100644 --- a/brax/envs/humanoid.py +++ b/brax/envs/humanoid.py @@ -310,7 +310,8 @@ def _get_obs( com_ang = xd_i.ang com_velocity = jp.hstack([com_vel, com_ang]) - qfrc_actuator = actuator.to_tau(self.sys, action, pipeline_state.q) + qfrc_actuator = actuator.to_tau( + self.sys, action, pipeline_state.q, pipeline_state.qd) # external_contact_forces are excluded return jp.concatenate([ diff --git a/brax/envs/humanoidstandup.py b/brax/envs/humanoidstandup.py index 972d0d77..e7d6ea6d 100644 --- a/brax/envs/humanoidstandup.py +++ b/brax/envs/humanoidstandup.py @@ -254,7 +254,8 @@ def _get_obs( com_ang = xd_i.ang com_velocity = jp.hstack([com_vel, com_ang]) - qfrc_actuator = actuator.to_tau(self.sys, action, pipeline_state.q) + qfrc_actuator = actuator.to_tau( + self.sys, action, pipeline_state.q, pipeline_state.qd) # external_contact_forces are excluded return jp.concatenate([ diff --git a/brax/fluid.py b/brax/fluid.py new file mode 100644 index 00000000..b05e4e88 --- /dev/null +++ b/brax/fluid.py @@ -0,0 +1,83 @@ +# Copyright 2023 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=g-multiple-import +"""Functions for forces/torques through fluids.""" + +from typing import Tuple, Union +from brax.base import Force, Motion, System, Transform +import jax +import jax.numpy as jp + + +def _box_viscosity( + box: jp.ndarray, xd_i: Motion, viscosity: jp.ndarray +) -> Force: + """Gets force due to motion through a viscous fluid.""" + diam = jp.mean(box, axis=-1) + ang_scale = -jp.pi * diam**3 * viscosity + vel_scale = -3.0 * jp.pi * diam * viscosity + frc = Force( + ang=ang_scale[:, None] * xd_i.ang, vel=vel_scale[:, None] * xd_i.vel + ) + return frc + + +def _box_density(box: jp.ndarray, xd_i: Motion, density: jp.ndarray) -> Force: + """Gets force due to motion through dense fluid.""" + + @jax.vmap + def apply(b: jp.ndarray, xd: Motion) -> Force: + box_mult_vel = jp.array([b[1] * b[2], b[0] * b[2], b[0] * b[1]]) + vel = -0.5 * density * box_mult_vel * jp.abs(xd.vel) * xd.vel + box_mult_ang = jp.array([ + b[0] * (b[1] ** 4 + b[2] ** 4), + b[1] * (b[0] ** 4 + b[2] ** 4), + b[2] * (b[0] ** 4 + b[1] ** 4), + ]) + ang = -1.0 * density * box_mult_ang * jp.abs(xd.ang) * xd.ang / 64.0 + return Force(vel=vel, ang=ang) + + return apply(box, xd_i) + + +def force( + sys: System, + x: Transform, + xd: Motion, + mass: jp.ndarray, + inertia: jp.ndarray, + subtree_com: Union[jp.ndarray, None] = None, +) -> Tuple[Force, jp.ndarray]: + """Returns force due to motion through a fluid.""" + # get the velocity at the com position/orientation + x_i = x.vmap().do(sys.link.inertia.transform) + # TODO: remove subtree_com when xd is fixed for stacked joints + offset = x_i.pos - x.pos if subtree_com is None else x_i.pos - subtree_com + xd_i = x_i.replace(pos=offset).vmap().do(xd) + + # TODO: add ellipsoid fluid model from mujoco + # TODO: consider adding wind from mj.opt.wind + diag_inertia = jax.vmap(jp.diag)(inertia) + diag_inertia_v = jp.repeat(diag_inertia, 3, axis=-2).reshape((-1, 3, 3)) + diag_inertia_v *= jp.ones((3, 3)) - 2 * jp.eye(3) + box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), a_min=1e-12) + box = jp.sqrt(box / mass[:, None]) + + frc = _box_viscosity(box, xd_i, sys.viscosity) + frc += _box_density(box, xd_i, sys.density) + + # rotate back to the world orientation + frc = Transform.create(rot=x_i.rot).vmap().do(frc) + return frc, x_i.pos diff --git a/brax/fluid_test.py b/brax/fluid_test.py new file mode 100644 index 00000000..5a9e54df --- /dev/null +++ b/brax/fluid_test.py @@ -0,0 +1,82 @@ +# Copyright 2023 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=g-multiple-import +"""Tests for fluid.""" + +from absl.testing import absltest +from absl.testing import parameterized +from brax import test_utils +from brax.generalized import dynamics +from brax.generalized import pipeline as g_pipeline +import jax +from jax import numpy as jp +import mujoco +import numpy as np + +assert_almost_equal = np.testing.assert_almost_equal + + +class FluidTest(parameterized.TestCase): + + @parameterized.parameters( + ('fluid_box.xml', 0.0, 0.0), + ('fluid_box.xml', 0.0, 1.3), + ('fluid_box.xml', 2.0, 0.0), + ('fluid_box.xml', 2.0, 1.3), + ('fluid_box_offset_com.xml', 0.0, 1.3), + ('fluid_box_offset_com.xml', 2.0, 0.0), + ('fluid_box_offset_com.xml', 2.0, 1.3), + ) + def test_fluid_mj_generalized(self, config, density, viscosity): + """Tests fluid interactions.""" + sys = test_utils.load_fixture(config) + sys = sys.replace(density=density, viscosity=viscosity) + + mj_model = test_utils.load_fixture_mujoco(config) + mj_model.opt.density = density + mj_model.opt.viscosity = viscosity + mj_data = mujoco.MjData(mj_model) + # initialize qd so that the object interacts with the fluid + mj_data.qvel[:3] = [1, 2, 4] + mj_data.qvel[3] = jp.pi + q, qd = jp.asarray(mj_data.qpos), jp.asarray(mj_data.qvel) + + # check qfrc_passive after the first step + mujoco.mj_step(mj_model, mj_data) + state = jax.jit(g_pipeline.init)(sys, q, qd) + qfrc_passive = jax.jit(dynamics._passive)(sys, state) + np.testing.assert_array_almost_equal( + qfrc_passive[3:], mj_data.qfrc_passive[3:], 2 + ) + m = max(jp.abs(mj_data.qfrc_passive[:3])) + 1e-6 + np.testing.assert_array_almost_equal( + qfrc_passive[:3] / m, mj_data.qfrc_passive[:3] / m, 1 + ) + + # check q/qd after multiple steps + for _ in range(500): + mujoco.mj_step(mj_model, mj_data) + mq, mqd = jp.asarray(mj_data.qpos), jp.asarray(mj_data.qvel) + + for _ in range(500): + state = jax.jit(g_pipeline.step)(sys, state, jp.zeros((sys.act_size(),))) + gq, gqd = state.q, state.qd + + np.testing.assert_array_almost_equal(gq, mq, 2) + np.testing.assert_array_almost_equal(gqd, mqd, 2) + + +if __name__ == '__main__': + absltest.main() diff --git a/brax/generalized/constraint.py b/brax/generalized/constraint.py index 5b9ef367..be59d59a 100644 --- a/brax/generalized/constraint.py +++ b/brax/generalized/constraint.py @@ -26,38 +26,6 @@ import jaxopt -def _pt_jac( - sys: System, - com: jp.ndarray, - cdof: Motion, - pos: jp.ndarray, - link_idx: jp.ndarray, -) -> jp.ndarray: - """Calculates the point jacobian. - - Args: - sys: a brax system - com: center of mass position - cdof: dofs in com frame - pos: position in world frame - link_idx: index of link frame to transform point jacobian - - Returns: - pt: point jacobian - """ - # backward scan up tree: build the link mask corresponding to link_idx - def mask_fn(mask_child, link): - mask = link == link_idx - if mask_child is not None: - mask += mask_child - return mask - - mask = scan.tree(sys, mask_fn, 'l', jp.arange(sys.num_links()), reverse=True) - cdof = jax.vmap(lambda a, b: a * b)(cdof, jp.take(mask, sys.dof_link())) - off = Transform.create(pos=pos - com[link_idx]) - return off.vmap(in_axes=(None, 0)).do(cdof).vel - - def _imp_aref( params: jp.ndarray, pos: jp.ndarray, vel: jp.ndarray ) -> Tuple[jp.ndarray, jp.ndarray]: @@ -87,11 +55,49 @@ def _imp_aref( b = 2 / (dmax * timeconst) k = 1 / (dmax * dmax * timeconst * timeconst * dampratio * dampratio) + # See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters + stiffness, damping = params[:2] + use_v2_params = (stiffness < 0) & (damping < 0) + b = jp.where(use_v2_params, -damping / dmax, b) + k = jp.where(use_v2_params, -stiffness / (dmax * dmax), k) + aref = -b * vel - k * imp * pos return imp, aref +def point_jacobian( + sys: System, + com: jp.ndarray, + cdof: Motion, + pos: jp.ndarray, + link_idx: jp.ndarray, +) -> Motion: + """Calculates the jacobian of a point on a link. + + Args: + sys: a brax system + com: center of mass position + cdof: dofs in com frame + pos: position in world frame to calculate the jacobian + link_idx: index of link frame to transform point jacobian + + Returns: + pt: point jacobian + """ + # backward scan up tree: build the link mask corresponding to link_idx + def mask_fn(mask_child, link): + mask = link == link_idx + if mask_child is not None: + mask += mask_child + return mask + + mask = scan.tree(sys, mask_fn, 'l', jp.arange(sys.num_links()), reverse=True) + cdof = jax.vmap(lambda a, b: a * b)(cdof, jp.take(mask, sys.dof_link())) + off = Transform.create(pos=pos - com[link_idx]) + return off.vmap(in_axes=(None, 0)).do(cdof) + + def jac_limit( sys: System, state: State ) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]: @@ -118,7 +124,8 @@ def jac_limit( side = ((pos_min < pos_max) * 2 - 1) * (pos < 0) jac = jax.vmap(jp.multiply)(jp.eye(sys.qd_size())[qd_idx], side) - imp, aref = _imp_aref(sys.solver_params_joint, pos, jac @ state.qd) + params = sys.dof.solver_params[qd_idx] + imp, aref = jax.vmap(_imp_aref)(params, pos, jac @ state.qd) diag = sys.dof.invweight[qd_idx] * (pos < 0) * (1 - imp) / (imp + 1e-8) aref = jax.vmap(lambda x, y: x * y)(aref, (pos < 0)) @@ -146,8 +153,8 @@ def jac_contact( def row_fn(contact): link_a, link_b = contact.link_idx - a = _pt_jac(sys, state.com, state.cdof, contact.pos, link_a) - b = _pt_jac(sys, state.com, state.cdof, contact.pos, link_b) + a = point_jacobian(sys, state.com, state.cdof, contact.pos, link_a).vel + b = point_jacobian(sys, state.com, state.cdof, contact.pos, link_b).vel diff = b - a # 4 pyramidal friction directions @@ -158,7 +165,7 @@ def row_fn(contact): jac = jp.stack(jac) pos = -jp.tile(contact.penetration, 4) - imp, aref = _imp_aref(sys.solver_params_contact, pos, jac @ state.qd) + imp, aref = _imp_aref(contact.solver_params, pos, jac @ state.qd) t = sys.link.invweight[link_a] + sys.link.invweight[link_b] * (link_b > -1) diag = jp.tile(t + contact.friction * contact.friction * t, 4) diag *= 2 * contact.friction * contact.friction * (1 - imp) / (imp + 1e-8) diff --git a/brax/generalized/constraint_test.py b/brax/generalized/constraint_test.py index 057918da..d10de72d 100644 --- a/brax/generalized/constraint_test.py +++ b/brax/generalized/constraint_test.py @@ -20,6 +20,7 @@ from brax import test_utils from brax.generalized import pipeline import jax +from jax import numpy as jp import numpy as np @@ -30,6 +31,7 @@ class ConstraintTest(parameterized.TestCase): ('triple_pendulum.xml',), ('humanoid.xml',), ('half_cheetah.xml',), + ('solver_params_v2.xml',) ) def test_jacobian(self, xml_file): """Test constraint jacobian.""" @@ -56,6 +58,7 @@ def test_jacobian(self, xml_file): ('triple_pendulum.xml',), ('humanoid.xml',), ('half_cheetah.xml',), + ('solver_params_v2.xml',) ) def test_force(self, xml_file): """Test constraint force.""" @@ -68,8 +71,9 @@ def test_force(self, xml_file): # force PGS so we can reference efc_AR: samples = test_utils.sample_mujoco_states(xml_file, force_pgs=True) for mj_prev, mj_next in samples: + act = jp.zeros(sys.act_size()) state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel) - state = jax.jit(pipeline.step)(sys, state, mj_prev.qfrc_applied) + state = jax.jit(pipeline.step)(sys, state, act) efc_jt = np.reshape(mj_next.efc_J, (-1, sys.qd_size())).T # recover con_frc by backing it out from qf_constraint con_frc = np.linalg.lstsq(efc_jt, state.qf_constraint, None)[0] diff --git a/brax/generalized/dynamics.py b/brax/generalized/dynamics.py index 6743ca11..781f58af 100644 --- a/brax/generalized/dynamics.py +++ b/brax/generalized/dynamics.py @@ -14,10 +14,14 @@ # pylint:disable=g-multiple-import """Functions for smooth forward and inverse dynamics.""" +import functools + +from brax import fluid from brax import math from brax import scan from brax.base import Motion, System, Transform from brax.generalized.base import State +from brax.generalized.constraint import point_jacobian import jax from jax import numpy as jp @@ -180,16 +184,27 @@ def cfrc_fn(cfrc_child, cfrc): return tau -def _passive(sys: System, q: jp.ndarray, qd: jp.ndarray) -> jp.ndarray: +def _passive(sys: System, state: State) -> jp.ndarray: """Calculates the system's passive forces given input motion and position.""" - def stiffness_fn(typ, q, dof): if typ in 'fb': return jp.zeros_like(dof.stiffness) return -q * dof.stiffness - frc = scan.link_types(sys, stiffness_fn, 'qd', 'd', q, sys.dof) - frc -= sys.dof.damping * qd + frc = scan.link_types(sys, stiffness_fn, 'qd', 'd', state.q, sys.dof) + frc -= sys.dof.damping * state.qd + + fluid_frc, pos = fluid.force( + sys, + state.x, + state.cd, + sys.link.inertia.mass, + sys.link.inertia.i, + state.com, + ) + cdof_fn = functools.partial(point_jacobian, sys, state.com, state.cdof) + jac = jax.vmap(cdof_fn)(pos, jp.arange(sys.num_links())) + frc += jax.vmap(lambda x, y: x.dot(y))(jac, fluid_frc).sum(axis=0) return frc @@ -210,7 +225,7 @@ def forward(sys: System, state: State, tau: jp.ndarray) -> jp.ndarray: Returns: qfrc: joint force vector """ - qfrc_passive = _passive(sys, state.q, state.qd) + qfrc_passive = _passive(sys, state) qfrc_bias = inverse(sys, state) qfrc = qfrc_passive - qfrc_bias + tau diff --git a/brax/generalized/dynamics_test.py b/brax/generalized/dynamics_test.py index daefdf89..fb0487cc 100644 --- a/brax/generalized/dynamics_test.py +++ b/brax/generalized/dynamics_test.py @@ -20,6 +20,7 @@ from brax import test_utils from brax.generalized import pipeline import jax +from jax import numpy as jp import numpy as np @@ -55,8 +56,9 @@ def test_forward(self, xml_file): """Test dynamics forward.""" sys = test_utils.load_fixture(xml_file) for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file): + act = jp.zeros(sys.act_size()) state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel) - state = jax.jit(pipeline.step)(sys, state, mj_prev.qfrc_applied) + state = jax.jit(pipeline.step)(sys, state, act) np.testing.assert_allclose( state.qf_smooth, mj_next.qfrc_smooth, rtol=1e-4, atol=1e-4 diff --git a/brax/generalized/perf_test.py b/brax/generalized/perf_test.py index bacbb317..2f6de4e0 100644 --- a/brax/generalized/perf_test.py +++ b/brax/generalized/perf_test.py @@ -35,7 +35,7 @@ def init_fn(rng): return pipeline.init(sys, q, qd) def step_fn(state): - return pipeline.step(sys, state, jp.zeros(sys.qd_size())) + return pipeline.step(sys, state, jp.zeros(sys.act_size())) test_utils.benchmark('generalized pipeline ant', init_fn, step_fn) diff --git a/brax/generalized/pipeline.py b/brax/generalized/pipeline.py index 9c9425e8..638f8c1f 100644 --- a/brax/generalized/pipeline.py +++ b/brax/generalized/pipeline.py @@ -67,7 +67,7 @@ def step( state: physics state after step """ # calculate acceleration terms - tau = actuator.to_tau(sys, act, state.q) + tau = actuator.to_tau(sys, act, state.q, state.qd) state = state.replace(qf_smooth=dynamics.forward(sys, state, tau)) state = state.replace(qf_constraint=constraint.force(sys, state)) diff --git a/brax/generalized/pipeline_test.py b/brax/generalized/pipeline_test.py index 1c17350d..b34f6817 100644 --- a/brax/generalized/pipeline_test.py +++ b/brax/generalized/pipeline_test.py @@ -20,6 +20,7 @@ from brax import test_utils from brax.generalized import pipeline import jax +from jax import numpy as jp import numpy as np @@ -38,7 +39,7 @@ def test_forward(self, xml_file): sys = sys.replace(solver_iterations=500) for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file): state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel) - state = jax.jit(pipeline.step)(sys, state, mj_prev.qfrc_applied) + state = jax.jit(pipeline.step)(sys, state, jp.zeros(sys.act_size())) np.testing.assert_allclose(state.q, mj_next.qpos, atol=0.002) np.testing.assert_allclose(state.qd, mj_next.qvel, atol=0.5) diff --git a/brax/geometry/contact.py b/brax/geometry/contact.py index 997c061e..b6c2a37c 100644 --- a/brax/geometry/contact.py +++ b/brax/geometry/contact.py @@ -20,6 +20,7 @@ from brax import math from brax.base import ( Capsule, + Cylinder, Contact, Convex, Geometry, @@ -35,19 +36,6 @@ from jax import numpy as jp -def _combine( - geom_a: Geometry, geom_b: Geometry -) -> Tuple[float, float, Tuple[int, int]]: - # default is to take maximum, but can override - friction = jp.maximum(geom_a.friction, geom_b.friction) - elasticity = jp.maximum(geom_a.elasticity, geom_b.elasticity) - link_idx = ( - geom_a.link_idx, - geom_b.link_idx if geom_b.link_idx is not None else -1, - ) - return friction, elasticity, link_idx # pytype: disable=bad-return-type # jax-ndarray - - def _sphere_plane(sphere: Sphere, plane: Plane) -> Contact: """Calculates one contact between a sphere and a plane.""" n = math.rotate(jp.array([0.0, 0.0, 1.0]), plane.transform.rot) @@ -55,8 +43,8 @@ def _sphere_plane(sphere: Sphere, plane: Plane) -> Contact: penetration = sphere.radius - t # halfway between contact points on sphere and on plane pos = sphere.transform.pos - n * (sphere.radius - 0.5 * penetration) - c = Contact(pos, n, penetration, *_combine(sphere, plane)) # pytype: disable=wrong-arg-types # jax-ndarray - # add a batch dimension of size 1 + c = Contact(pos, n, penetration, *_combine(sphere, plane)) + # returns 1 contact, so add a batch dimension of size 1 return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c) @@ -67,8 +55,8 @@ def _sphere_sphere(s_a: Sphere, s_b: Sphere) -> Contact: s_a_pos = s_a.transform.pos - n * s_a.radius s_b_pos = s_b.transform.pos + n * s_b.radius pos = (s_a_pos + s_b_pos) * 0.5 - c = Contact(pos, n, penetration, *_combine(s_a, s_b)) # pytype: disable=wrong-arg-types # jax-ndarray - # add a batch dimension of size 1 + c = Contact(pos, n, penetration, *_combine(s_a, s_b)) + # returns 1 contact, so add a batch dimension of size 1 return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c) @@ -88,8 +76,57 @@ def _sphere_capsule(sphere: Sphere, capsule: Capsule) -> Contact: cap_pos = pt + n * capsule.radius pos = (sphere_pos + cap_pos) * 0.5 - c = Contact(pos, n, penetration, *_combine(sphere, capsule)) # pytype: disable=wrong-arg-types # jax-ndarray - # add a batch dimension of size 1 + c = Contact(pos, n, penetration, *_combine(sphere, capsule)) + # returns 1 contact, so add a batch dimension of size 1 + return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c) + + +def _sphere_circle(sphere: Sphere, circle: Cylinder) -> Contact: + """Calculates one contact between a sphere and a circle.""" + n = math.rotate(jp.array([0.0, 0.0, 1.0]), circle.transform.rot) + + # orient the normal s.t. it points at the CoM of the sphere + normal_dir = jp.sign( + (sphere.transform.pos - circle.transform.pos).dot(n)) + n = n * normal_dir + + pos = sphere.transform.pos - n * sphere.radius + plane_pt = circle.transform.pos + penetration = jp.dot(plane_pt - pos, n) + + # check if the sphere radius is within the cylinder in the normal dir of the + # circle + plane_pt2 = plane_pt + n + line_pt = geom_math.closest_line_point( + plane_pt, plane_pt2, sphere.transform.pos + ) + in_cylinder = (sphere.transform.pos - line_pt).dot( + sphere.transform.pos - line_pt + ) <= circle.radius**2 + + # get closest point on circle edge + perp_dir = jp.cross(n, sphere.transform.pos - plane_pt) + perp_dir = math.rotate(perp_dir, math.quat_rot_axis(n, -jp.pi / 2.0)) + perp_dir, _ = math.normalize(perp_dir) + edge_pt = plane_pt + perp_dir * circle.radius + edge_contact = (sphere.transform.pos - edge_pt).dot( + sphere.transform.pos - edge_pt + ) <= sphere.radius**2 + edge_to_sphere = sphere.transform.pos - edge_pt + edge_to_sphere = math.normalize(edge_to_sphere)[0] + + penetration = jp.where(in_cylinder, penetration, -jp.ones_like(penetration)) + penetration = jp.where( + edge_contact, + sphere.radius + - jp.sqrt( + (sphere.transform.pos - edge_pt).dot(sphere.transform.pos - edge_pt) + ), + penetration, + ) + n = jp.where(edge_contact, edge_to_sphere, n) + pos = jp.where(edge_contact, edge_pt, pos) + c = Contact(pos, n, penetration, *_combine(sphere, circle)) # pytype: disable=wrong-arg-types # jax-ndarray return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c) @@ -154,7 +191,7 @@ def get_support(faces, normal): penetration = sphere.radius - d pos = (pt + spt) * 0.5 - c = Contact(pos, n, penetration, *_combine(sphere, convex)) # pytype: disable=wrong-arg-types # jax-ndarray + c = Contact(pos, n, penetration, *_combine(sphere, convex)) return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c) @@ -174,7 +211,7 @@ def sphere_face(face): penetration = sphere.radius - dist sph_p = sphere.transform.pos - n * sphere.radius pos = (tri_p + sph_p) * 0.5 - return Contact(pos, n, penetration, *_combine(sphere, mesh)) # pytype: disable=wrong-arg-types # jax-ndarray + return Contact(pos, n, penetration, *_combine(sphere, mesh)) return sphere_face(jp.take(mesh.vert, mesh.face, axis=0)) @@ -187,11 +224,12 @@ def _capsule_plane(capsule: Capsule, plane: Plane) -> Contact: results = [] for off in [segment, -segment]: sphere = Sphere( + radius=capsule.radius, link_idx=capsule.link_idx, transform=Transform.create(pos=capsule.transform.pos + off), friction=capsule.friction, elasticity=capsule.elasticity, - radius=capsule.radius, + solver_params=capsule.solver_params, ) results.append(_sphere_plane(sphere, plane)) @@ -217,8 +255,8 @@ def _capsule_capsule(cap_a: Capsule, cap_b: Capsule) -> Contact: cap_b_pos = pt_b + n * cap_b.radius pos = (cap_a_pos + cap_b_pos) * 0.5 - c = Contact(pos, n, penetration, *_combine(cap_a, cap_b)) # pytype: disable=wrong-arg-types # jax-ndarray - # add a batch dimension of size 1 + c = Contact(pos, n, penetration, *_combine(cap_a, cap_b)) + # returns 1 contact, so add a batch dimension of size 1 return jax.tree_map(lambda x: jp.expand_dims(x, axis=0), c) @@ -306,10 +344,9 @@ def get_support(face, normal): penetration = jp.where( has_edge_contact, penetration.at[0].set(edge_penetration), penetration ) - friction, elasticity, link_idx = jax.tree_map( - lambda x: jp.repeat(x, 2), _combine(capsule, convex) - ) - return Contact(pos, norm, penetration, friction, elasticity, link_idx) + tile_fn = lambda x: jp.tile(x, (2,) + tuple([1 for _ in x.shape])) + params = jax.tree_map(tile_fn, _combine(capsule, convex)) + return Contact(pos, norm, penetration, *params) def _capsule_mesh(capsule: Capsule, mesh: Mesh) -> Contact: @@ -335,7 +372,7 @@ def capsule_face(face, face_norm): penetration = capsule.radius - dist cap_p = seg_p - n * capsule.radius pos = (tri_p + cap_p) * 0.5 - return Contact(pos, n, penetration, *_combine(capsule, mesh)) # pytype: disable=wrong-arg-types # jax-ndarray + return Contact(pos, n, penetration, *_combine(capsule, mesh)) face_vert = jp.take(mesh.vert, mesh.face, axis=0) face_norm = geom_mesh.get_face_norm(mesh.vert, mesh.face) @@ -360,10 +397,9 @@ def transform_verts(vertices): normal = jp.stack([n] * 4, axis=0) unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1 penetration = jp.where(unique, support[idx], -1) - friction, elasticity, link_idx = jax.tree_map( - lambda x: jp.repeat(x, 4), _combine(convex, plane) - ) - return Contact(pos, normal, penetration, friction, elasticity, link_idx) + tile_fn = lambda x: jp.tile(x, (4,) + tuple([1 for _ in x.shape])) + params = jax.tree_map(tile_fn, _combine(convex, plane)) + return Contact(pos, normal, penetration, *params) def _convex_convex(convex_a: Convex, convex_b: Convex) -> Contact: @@ -416,18 +452,10 @@ def transform_verts(convex, vertices): unique_edges_a, unique_edges_b, ) - friction, elasticity, link_idx = jax.tree_map( - lambda x: jp.repeat(x, 4), _combine(convex_a, convex_b) - ) + tile_fn = lambda x: jp.tile(x, (4,) + tuple([1 for _ in x.shape])) + params = jax.tree_map(tile_fn, _combine(convex_a, convex_b)) - return Contact( - c.pos, - c.normal, - c.penetration, - friction, - elasticity, - link_idx, - ) + return Contact(c.pos, c.normal, c.penetration, *params) def _mesh_plane(mesh: Mesh, plane: Plane) -> Contact: @@ -438,15 +466,30 @@ def point_plane(vert): n = math.rotate(jp.array([0.0, 0.0, 1.0]), plane.transform.rot) pos = mesh.transform.pos + math.rotate(vert, mesh.transform.rot) penetration = jp.dot(plane.transform.pos - pos, n) - return Contact(pos, n, penetration, *_combine(mesh, plane)) # pytype: disable=wrong-arg-types # jax-ndarray + return Contact(pos, n, penetration, *_combine(mesh, plane)) return point_plane(mesh.vert) +def _combine( + geom_a: Geometry, geom_b: Geometry +) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray, Tuple[jp.ndarray, jp.ndarray]]: + # default is to take maximum, but can override + friction = jp.maximum(geom_a.friction, geom_b.friction) + elasticity = (geom_a.elasticity + geom_b.elasticity) * 0.5 + solver_params = (geom_a.solver_params + geom_b.solver_params) * 0.5 + link_idx = ( + jp.array(geom_a.link_idx if geom_a.link_idx is not None else -1), + jp.array(geom_b.link_idx if geom_b.link_idx is not None else -1), + ) + return friction, elasticity, solver_params, link_idx + + _TYPE_FUN = { (Sphere, Plane): _sphere_plane, (Sphere, Sphere): _sphere_sphere, (Sphere, Capsule): _sphere_capsule, + (Sphere, Cylinder): _sphere_circle, (Sphere, Convex): _sphere_convex, (Sphere, Mesh): _sphere_mesh, (Capsule, Plane): _capsule_plane, diff --git a/brax/geometry/contact_test.py b/brax/geometry/contact_test.py index cccd7373..95532d59 100644 --- a/brax/geometry/contact_test.py +++ b/brax/geometry/contact_test.py @@ -91,6 +91,28 @@ def test_sphere_capsule(self): np.testing.assert_array_almost_equal(c.pos, jp.array([0.0, 0.3, 0.045])) np.testing.assert_array_almost_equal(c.normal, jp.array([0, 0.0, -1.0])) + _SPHERE_CYLINDER = """ + + + + + + + + + + + + + """ + + def test_sphere_cylinder(self): + sys = mjcf.loads(self._SPHERE_CYLINDER) + x, _ = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size())) + c = geometry.contact(sys, x).take(0) + + np.testing.assert_array_almost_equal(c.penetration, 0.01) + _SPHERE_CONVEX = """ diff --git a/brax/geometry/math.py b/brax/geometry/math.py index 24d5041d..bfb06a71 100644 --- a/brax/geometry/math.py +++ b/brax/geometry/math.py @@ -40,6 +40,15 @@ def closest_segment_point_and_dist( return closest, dist +def closest_line_point( + a: jp.ndarray, b: jp.ndarray, pt: jp.ndarray +) -> jp.ndarray: + """Returns the closest point on the a-b line to a point pt.""" + ab = b - a + t = jp.dot(pt - a, ab) / (jp.dot(ab, ab) + 1e-6) + return a + t * ab + + def closest_segment_to_segment_points( a0: jp.ndarray, a1: jp.ndarray, b0: jp.ndarray, b1: jp.ndarray ) -> Tuple[jp.ndarray, jp.ndarray]: @@ -432,6 +441,7 @@ def _create_contact_manifold( penetration=penetration, friction=jp.array([]), elasticity=jp.array([]), + solver_params=jp.array([]), link_idx=jp.array([]), ) diff --git a/brax/geometry/mesh.py b/brax/geometry/mesh.py index 89b25cba..9a9afb41 100644 --- a/brax/geometry/mesh.py +++ b/brax/geometry/mesh.py @@ -128,6 +128,7 @@ def box_tri(b: Box) -> Mesh: transform=b.transform, friction=b.friction, elasticity=b.elasticity, + solver_params=b.solver_params, rgba=b.rgba, ) @@ -142,6 +143,7 @@ def _box_hull(b: Box) -> Convex: transform=b.transform, friction=b.friction, elasticity=b.elasticity, + solver_params=b.solver_params, unique_edge=get_unique_edges(vert, face), rgba=b.rgba, ) @@ -225,6 +227,7 @@ def _convex_hull(m: Mesh) -> Convex: transform=m.transform, friction=m.friction, elasticity=m.elasticity, + solver_params=m.solver_params, unique_edge=get_unique_edges(vert, face), rgba=m.rgba, ) @@ -244,6 +247,7 @@ def convex_hull(obj: Union[Box, Mesh]) -> Convex: transform=obj.transform, friction=obj.friction, elasticity=obj.elasticity, + solver_params=obj.solver_params, rgba=obj.rgba, ) return convex diff --git a/brax/geometry/mesh_test.py b/brax/geometry/mesh_test.py index 6a0f5cdc..d84772a2 100644 --- a/brax/geometry/mesh_test.py +++ b/brax/geometry/mesh_test.py @@ -31,6 +31,7 @@ def test_box(self): transform=None, friction=0.42, elasticity=1, + solver_params=None, ) m = mesh.box_tri(b) self.assertIsInstance(m, Mesh) @@ -64,6 +65,7 @@ def test_box_hull(self): transform=None, friction=0.42, elasticity=1, + solver_params=None, ) h = mesh.convex_hull(b) self.assertIsInstance(h, Convex) @@ -106,6 +108,7 @@ def test_pyramid(self): face=face, friction=1, elasticity=0, + solver_params=None, ) h = mesh.convex_hull(pyramid) diff --git a/brax/io/json.py b/brax/io/json.py index 56d7208b..015eee58 100644 --- a/brax/io/json.py +++ b/brax/io/json.py @@ -18,7 +18,6 @@ import json from typing import List, Text -from brax import geometry from brax.base import State, System from etils import epath import jax diff --git a/brax/io/mjcf.py b/brax/io/mjcf.py index c646bba9..c1f086c4 100644 --- a/brax/io/mjcf.py +++ b/brax/io/mjcf.py @@ -25,6 +25,7 @@ Actuator, Box, Capsule, + Cylinder, DoF, Inertia, Link, @@ -43,13 +44,6 @@ import numpy as np -# map from mujoco bias type to brax actuator type string -_ACT_TYPE_STR = { - 0: 'm', # motor - 1: 'p', # position -} - - def _transform_do( pos: np.ndarray, quat: np.ndarray, cpos: np.ndarray, cquat: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: @@ -251,6 +245,10 @@ def load_model(mj: mujoco.MjModel) -> System: raise NotImplementedError( 'Only joint transmission types are supported for actuators.' ) + if (mj.geom_solmix[0] != mj.geom_solmix).any(): + raise NotImplementedError('geom_solmix parameter not supported.') + if (mj.geom_priority[0] != mj.geom_priority).any(): + raise NotImplementedError('geom_priority parameter not supported.') if mj.opt.collision == 1: raise NotImplementedError('Predefined collisions not supported.') q_width = {0: 7, 1: 4, 2: 1, 3: 1} @@ -258,7 +256,10 @@ def load_model(mj: mujoco.MjModel) -> System: if mj.qpos0[non_free].any(): raise NotImplementedError( 'The `ref` attribute on joint types is not supported.') - + if (mj.geom_fluid != 0).any(): + raise NotImplementedError('Ellipsoid fluid model not implemented.') + if mj.opt.wind.any(): + raise NotImplementedError('option.wind is not implemented.') custom = _get_custom(mj) # create links @@ -326,22 +327,25 @@ def load_model(mj: mujoco.MjModel) -> System: if np.any(mj.jnt_limited): limit = jax.tree_map(lambda *x: np.concatenate(x), *limits) stiffness = np.concatenate(stiffnesses) + solver_params_jnt = np.concatenate((mj.jnt_solref, mj.jnt_solimp), axis=1) + solver_params_dof = solver_params_jnt[mj.dof_jntid] - dof = DoF( # pytype: disable=wrong-arg-types # jax-ndarray + dof = DoF( # pytype: disable=wrong-arg-types motion=motion, armature=mj.dof_armature, stiffness=stiffness, damping=mj.dof_damping, limit=limit, invweight=mj.dof_invweight0, + solver_params=solver_params_dof, ) + solver_params_geom = np.concatenate((mj.geom_solref, mj.geom_solimp), axis=1) # group geoms so that they can be stacked. two geoms can be stacked if: # - they have the same type # - their fields have the same shape (e.g. Mesh verts might vary) # - they have the same mask key_fn = lambda g, m: (jax.tree_map(np.shape, g), m) - geom_groups = {} for i, typ in enumerate(mj.geom_type): rgba = mj.geom_rgba[i] @@ -353,12 +357,21 @@ def load_model(mj: mujoco.MjModel) -> System: 'transform': Transform(pos=mj.geom_pos[i], rot=mj.geom_quat[i]), 'friction': mj.geom_friction[i, 0], 'elasticity': custom['elasticity'][i], + 'solver_params': solver_params_geom[i], 'rgba': rgba, } mask = mj.geom_contype[i] | mj.geom_conaffinity[i] << 32 if typ == 0: # Plane geom = Plane(**kwargs) geom_groups.setdefault(key_fn(geom, mask), []).append(geom) + elif typ == 5: # Cylinder + radius, halflength = mj.geom_size[i, 0:2] + if halflength > 0.001 and mask > 0: + # TODO: support cylinders with depth. + raise NotImplementedError( + 'Cylinders of half-length>0.001 are not supported for collision.') + geom = Cylinder(radius=radius, length=halflength * 2, **kwargs) + geom_groups.setdefault(key_fn(geom, mask), []).append(geom) elif typ == 2: # Sphere geom = Sphere(radius=mj.geom_size[i, 0], **kwargs) geom_groups.setdefault(key_fn(geom, mask), []).append(geom) @@ -386,30 +399,28 @@ def load_model(mj: mujoco.MjModel) -> System: continue geoms = [ - jax.tree_map(lambda *x: jp.stack(x), *g) for g in geom_groups.values() + jax.tree_map(lambda *x: np.stack(x), *g) for g in geom_groups.values() ] geom_masks = [m for _, m in geom_groups.keys()] # create actuators ctrl_range = mj.actuator_ctrlrange ctrl_range[~(mj.actuator_ctrllimited == 1), :] = np.array([-np.inf, np.inf]) - actuator = Actuator( # pytype: disable=wrong-arg-types # jax-ndarray + q_id = np.array([mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]]) + qd_id = np.array([mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]]) + bias_q = mj.actuator_biasprm[:, 1] + bias_qd = mj.actuator_biasprm[:, 2] + + # TODO: might be nice to add actuator names for debugging + actuator = Actuator( # pytype: disable=wrong-arg-types + q_id=q_id, + qd_id=qd_id, gear=mj.actuator_gear[:, 0], ctrl_range=ctrl_range, + bias_q=bias_q, + bias_qd=bias_qd, ) - # create generalized solver params - params_joint = jp.concatenate((mj.jnt_solref, mj.jnt_solimp), axis=1) - params_geom = jp.concatenate((mj.geom_solref, mj.geom_solimp), axis=1) - params_pair = jp.concatenate((mj.pair_solref, mj.pair_solimp), axis=1) - params_contact = jp.concatenate((params_geom, params_pair)) - if (params_joint[0] != params_joint).any(): - raise NotImplementedError('brax only supports one joint solver params') - if (params_contact[0] != params_contact).any(): - raise NotImplementedError('brax only supports one contact solver params') - solver_params_joint = params_joint[0] - solver_params_contact = params_contact[0] - # create non-pytree params. these do not live on device directly, and they # cannot be differentiated, but they do change the emitted control flow link_names = [_get_name(mj, i) for i in mj.name_bodyadr[1:]] @@ -430,21 +441,6 @@ def load_model(mj: mujoco.MjModel) -> System: link_types += typ link_parents = tuple(mj.body_parentid - 1)[1:] - # create non-pytree params for actuators. - actuator_types = ''.join([_ACT_TYPE_STR[bt] for bt in mj.actuator_biastype]) - actuator_link_id = [mj.jnt_bodyid[i] - 1 for i in mj.actuator_trnid[:, 0]] - unsupported_act_links = set(link_types[i] for i in actuator_link_id) - { - '1', - '2', - '3', - } - if unsupported_act_links: - raise NotImplementedError( - f'Link types {unsupported_act_links} are not supported for actuators.' - ) - actuator_qid = [mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]] - actuator_qdid = [mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]] - # mujoco stores free q in world frame, so clear link transform for free links if 'f' in link_types: free_idx = np.array([i for i, typ in enumerate(link_types) if typ == 'f']) @@ -454,13 +450,13 @@ def load_model(mj: mujoco.MjModel) -> System: sys = System( # pytype: disable=wrong-arg-types # jax-ndarray dt=mj.opt.timestep, gravity=mj.opt.gravity, + viscosity=mj.opt.viscosity, + density=mj.opt.density, link=link, dof=dof, geoms=geoms, actuator=actuator, init_q=custom['init_qpos'] if 'init_qpos' in custom else mj.qpos0, - solver_params_joint=solver_params_joint, - solver_params_contact=solver_params_contact, vel_damping=custom['vel_damping'], ang_damping=custom['ang_damping'], baumgarte_erp=custom['baumgarte_erp'], @@ -473,10 +469,6 @@ def load_model(mj: mujoco.MjModel) -> System: link_names=link_names, link_types=link_types, link_parents=link_parents, - actuator_types=actuator_types, - actuator_link_id=actuator_link_id, - actuator_qid=actuator_qid, - actuator_qdid=actuator_qdid, matrix_inv_iterations=int(custom['matrix_inv_iterations']), solver_iterations=mj.opt.iterations, solver_maxls=int(custom['solver_maxls']), diff --git a/brax/io/mjcf_test.py b/brax/io/mjcf_test.py index dde7a0b5..02e2434e 100644 --- a/brax/io/mjcf_test.py +++ b/brax/io/mjcf_test.py @@ -171,6 +171,34 @@ def test_world_body_transform(self): sys.init_q, np.array([1.245, 0.0, 0.0, 0.5, 0.5, 0.5, -0.5]) ) + def test_load_flat_cylinder(self): + sys = test_utils.load_fixture('flat_cylinder.xml') + self.assertEqual(sys.geoms[1].radius, 0.25) + self.assertEqual(sys.geoms[1].length, 0.002) + + def test_load_fat_cylinder(self): + with self.assertRaisesRegex( + NotImplementedError, 'Cylinders of half-length' + ): + test_utils.load_fixture('fat_cylinder.xml') + + def test_load_fluid_box(self): + sys = test_utils.load_fixture('fluid_box.xml') + assert_almost_equal(sys.density, 1.2) + assert_almost_equal(sys.viscosity, 0.15) + + def test_load_fluid_ellipsoid(self): + with self.assertRaisesRegex( + NotImplementedError, 'Ellipsoid fluid model not implemented' + ): + test_utils.load_fixture('fluid_ellipsoid.xml') + + def test_load_wind(self): + with self.assertRaisesRegex( + NotImplementedError, 'option.wind is not implemented' + ): + test_utils.load_fixture('fluid_wind.xml') + if __name__ == '__main__': absltest.main() diff --git a/brax/kinematics.py b/brax/kinematics.py index f52d5e19..a3b89497 100644 --- a/brax/kinematics.py +++ b/brax/kinematics.py @@ -70,6 +70,7 @@ def jcalc(typ, q, qd, motion): for i in range(1, num_dofs): j_i, jd_i = j_stack.take(i, axis=1), jd_stack.take(i, axis=1) j = j.vmap().do(j_i) + # TODO: fix qd->jd calculation for stacked/offset joints jd = jd + Motion( ang=jax.vmap(math.rotate)(jd_i.ang, j_i.rot), vel=jax.vmap(math.rotate)( @@ -88,16 +89,16 @@ def jcalc(typ, q, qd, motion): def world(parent, j, jd): """Convert transform/motion from joint frame to world frame.""" if parent is None: + jd = jd.replace(ang=jax.vmap(math.rotate)(jd.ang, j.rot)) return j, jd - x, xd = parent - # TODO: determine why the motion `do` is inverted - x = x.vmap().do(j) - xd = xd + Motion( - ang=jax.vmap(math.rotate)(jd.ang, x.rot), - vel=jax.vmap(math.rotate)( - jd.vel + jax.vmap(jp.cross)(x.pos, jd.ang), x.rot - ), - ) + x_p, xd_p = parent + x = x_p.vmap().do(j) + # get the linear velocity at the tip of the parent + vel = xd_p.vel + jax.vmap(jp.cross)(xd_p.ang, x.pos - x_p.pos) + # add in the child linear velocity in the world frame + vel += jax.vmap(math.rotate)(jd.vel, x_p.rot) + ang = xd_p.ang + jax.vmap(math.rotate)(jd.ang, x.rot) + xd = Motion(vel=vel, ang=ang) return x, xd x, xd = scan.tree(sys, world, 'll', j, jd) @@ -126,7 +127,7 @@ def world_to_joint( # move into joint coordinates xd_joint = xd - xd_wj - inv_rotate = jax.vmap(lambda x, y: math.rotate(x, math.quat_inv(y))) + inv_rotate = jax.vmap(math.inv_rotate) jd = jax.tree_map(lambda x: inv_rotate(x, a_p.rot), xd_joint) return j, jd, a_p, a_c @@ -158,6 +159,9 @@ def link_to_joint_frame(motion: Motion) -> Tuple[Motion, float]: We also need translational components because the prismatic components of a joint might not be aligned with the rotational components of the joint. """ + if motion.ang.shape[0] > 3 or motion.ang.shape[0] == 0: + raise AssertionError('Motion shape must be in (0, 3], ' + f'got {motion.ang.shape[0]}') # 1-dof if motion.ang.shape[0] == 1: @@ -282,16 +286,14 @@ def axis_angle_ang( child_frame = v_rot(joint_motion.ang, j.rot) line_of_nodes = jp.cross(child_frame[2], joint_motion.ang[0]) - line_of_nodes = line_of_nodes / (1e-10 + math.safe_norm(line_of_nodes)) + line_of_nodes, _ = math.normalize(line_of_nodes) y_n_normal = joint_motion.ang[0] psi = math.signed_angle(y_n_normal, joint_motion.ang[1], line_of_nodes) axis_1_p_in_xz_c = ( jp.dot(joint_motion.ang[0], child_frame[0]) * child_frame[0] + jp.dot(joint_motion.ang[0], child_frame[1]) * child_frame[1] ) - axis_1_p_in_xz_c = axis_1_p_in_xz_c / ( - 1e-10 + math.safe_norm(axis_1_p_in_xz_c) - ) + axis_1_p_in_xz_c, _ = math.normalize(axis_1_p_in_xz_c) ang_between_1_p_xz_c = jp.dot(axis_1_p_in_xz_c, joint_motion.ang[0]) theta = math.safe_arccos(jp.clip(ang_between_1_p_xz_c, -1, 1)) * jp.sign( jp.dot(joint_motion.ang[0], child_frame[2]) @@ -331,10 +333,13 @@ def inverse( ) -> Tuple[jp.ndarray, jp.ndarray]: """Translates maximal coordinates into reduced coordinates.""" - def free(x, xd, _): - return jp.concatenate([x.pos, x.rot]), jp.concatenate([xd.vel, xd.ang]) + def free(x, xd, *_): + ang = math.inv_rotate(xd.ang, x.rot) + return jp.concatenate([x.pos, x.rot]), jp.concatenate([xd.vel, ang]) - def x_dof(j, jd, motion, x): + def x_dof(j, jd, parent_idx, motion, x): + j_rot = jp.where(parent_idx == -1, j.rot, jp.array([1.0, 0.0, 0.0, 0.0])) + jd = jd.replace(ang=math.inv_rotate(jd.ang, j_rot)) joint_frame, parity = link_to_joint_frame(motion) axis, angles, _ = axis_angle_ang(j, joint_frame, parity) angle_vels = jax.tree_map(lambda x: jp.dot(x, jd.ang), axis) @@ -350,7 +355,7 @@ def x_dof(j, jd, motion, x): ) return q, qd - def q_fn(typ, j, jd, motion): + def q_fn(typ, j, jd, parent_idx, motion): motion = jax.tree_map( lambda y: y.reshape((-1, base.QD_WIDTHS[typ], 3)), motion ) @@ -361,10 +366,12 @@ def q_fn(typ, j, jd, motion): '3': functools.partial(x_dof, x=3), } - q, qd = jax.vmap(q_fn_map[typ])(j, jd, motion) + q, qd = jax.vmap(q_fn_map[typ])(j, jd, parent_idx, motion) # transposed to preserve order of outputs return jp.array(q).reshape(-1), jp.array(qd).reshape(-1) - q, qd = scan.link_types(sys, q_fn, 'lld', 'qd', j, jd, sys.dof.motion) + parent_idx = jp.array(sys.link_parents) + q, qd = scan.link_types(sys, q_fn, 'llld', 'qd', j, jd, parent_idx, + sys.dof.motion) return q, qd diff --git a/brax/kinematics_test.py b/brax/kinematics_test.py index 7101074b..e3d39fb8 100644 --- a/brax/kinematics_test.py +++ b/brax/kinematics_test.py @@ -21,6 +21,7 @@ from brax import kinematics from brax import scan from brax import test_utils +from brax.base import Motion import jax import jax.numpy as jp import numpy as np @@ -31,14 +32,33 @@ class KinematicsTest(parameterized.TestCase): @parameterized.parameters( ('ant.xml',), ('humanoid.xml',), ('reacher.xml',), ('half_cheetah.xml',) ) - def test_forward_q(self, xml_file): + def test_forward(self, xml_file): """Test dynamics forward q.""" sys = test_utils.load_fixture(xml_file) - for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file): - x, _ = jax.jit(kinematics.forward)(sys, mj_prev.qpos, mj_prev.qvel) + + for mj_prev, mj_next in test_utils.sample_mujoco_states( + xml_file, random_init=True, vel_to_local=False): + x, xd = jax.jit(kinematics.forward)(sys, mj_prev.qpos, mj_prev.qvel) + np.testing.assert_almost_equal(x.pos, mj_next.xpos[1:], 3) + # handle quat rotations +/- 2pi + quat_sign = np.allclose( + np.sum(mj_next.xquat[1:]) - np.sum(x.rot), 0, atol=1e-2) + quat_sign = 1 if quat_sign else -1 + x = x.replace(rot=x.rot * quat_sign) np.testing.assert_almost_equal(x.rot, mj_next.xquat[1:], 3) + # xd vel/ang were added to linvel/angmom in `sample_mujoco_states` + xd_mj = Motion( + vel=mj_next.subtree_linvel[1:], ang=mj_next.subtree_angmom[1:]) + + if xml_file == 'humanoid.xml': + # TODO: get forward to match MJ for stacked/offset joints + return + + np.testing.assert_array_almost_equal(xd.ang, xd_mj.ang, 3) + np.testing.assert_array_almost_equal(xd.vel, xd_mj.vel, 3) + def test_init_q(self): sys = test_utils.load_fixture('ant.xml') np.testing.assert_almost_equal( @@ -53,10 +73,11 @@ def test_init_q(self): def test_inverse(self, xml_file): np.random.seed(0) sys = test_utils.load_fixture(xml_file) - # # test at random init + # test at random init rand_q = np.random.rand(sys.init_q.shape[0]) - # normalize quaternion part of init_q - rand_q[3:7] = rand_q[3:7] / np.linalg.norm(rand_q[3:7]) + if sys.link_types[0] == 'f': + # normalize quaternion part of init_q + rand_q[3:7] = rand_q[3:7] / np.linalg.norm(rand_q[3:7]) rand_q = jp.array(rand_q) rand_qd = jp.array(np.random.rand(sys.qd_size())) * 0.1 @@ -94,10 +115,11 @@ def _collect_frame(typ, motion): ), ), armature=np.array([0.0, 0.0, 0.0, 0.0]), - invweight=np.array([0.0, 0.0, 0.0, 0.0]), stiffness=np.array([0.0, 0.0, 0.0, 0.0]), damping=np.array([0.0, 0.0, 0.0, 0.0]), limit=None, + invweight=np.array([0.0, 0.0, 0.0, 0.0]), + solver_params=np.zeros((4, 7)), ) ) @@ -129,10 +151,11 @@ def _collect_frame(typ, motion): ]), ), armature=np.array([0.0, 0.0, 0.0, 0.0]), - invweight=np.array([0.0, 0.0, 0.0, 0.0]), stiffness=np.array([0.0, 0.0, 0.0, 0.0]), damping=np.array([0.0, 0.0, 0.0, 0.0]), limit=None, + invweight=np.array([0.0, 0.0, 0.0, 0.0]), + solver_params=np.zeros((4, 7)), ) ) diff --git a/brax/math.py b/brax/math.py index da3f33fa..9725d4e8 100644 --- a/brax/math.py +++ b/brax/math.py @@ -40,6 +40,19 @@ def rotate(vec: jp.ndarray, quat: jp.ndarray): return r +def inv_rotate(vec: jp.ndarray, quat: jp.ndarray): + """Rotates a vector vec by an inverted unit quaternion quat. + + Args: + vec: (3,) a vector + quat: (4,) a quaternion + + Returns: + ndarray(3) containing vec rotated by the inverse of quat. + """ + return rotate(vec, quat_inv(quat)) + + def rotate_np(vec: np.ndarray, quat: np.ndarray): """Rotates a vector vec by a unit quaternion quat. diff --git a/brax/positional/perf_test.py b/brax/positional/perf_test.py index 65d10a0c..ef1b4f47 100644 --- a/brax/positional/perf_test.py +++ b/brax/positional/perf_test.py @@ -34,7 +34,7 @@ def init_fn(rng): return pipeline.init(sys, q, qd) def step_fn(state): - return pipeline.step(sys, state, jp.zeros(sys.qd_size())) + return pipeline.step(sys, state, jp.zeros(sys.act_size())) test_utils.benchmark('pbd pipeline ant', init_fn, step_fn) diff --git a/brax/positional/pipeline.py b/brax/positional/pipeline.py index e00a0648..eb2fcdce 100644 --- a/brax/positional/pipeline.py +++ b/brax/positional/pipeline.py @@ -70,7 +70,7 @@ def step( x_i_prev = state.x_i # calculate acceleration level updates - tau = actuator.to_tau(sys, act, state.q) + tau = actuator.to_tau(sys, act, state.q, state.qd) xdd_i = Motion.create(vel=sys.gravity) xdd_i += joints.acceleration_update(sys, state, tau) diff --git a/brax/positional/pipeline_test.py b/brax/positional/pipeline_test.py index 12a74709..82ed5d6c 100644 --- a/brax/positional/pipeline_test.py +++ b/brax/positional/pipeline_test.py @@ -16,8 +16,10 @@ """Tests for spring physics pipeline.""" from absl.testing import absltest +from brax import com from brax import kinematics from brax import test_utils +from brax.base import Transform from brax.generalized import pipeline as g_pipeline from brax.positional import pipeline import jax @@ -36,7 +38,7 @@ def test_pendulum(self): state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size())) j_pos_step = jax.jit(pipeline.step) for _ in range(2_000): - state = j_pos_step(sys, state, jp.zeros(sys.qd_size())) + state = j_pos_step(sys, state, jp.zeros(sys.act_size())) x = state.x # compare against generalized step @@ -45,7 +47,7 @@ def test_pendulum(self): j_g_step = jax.jit(g_pipeline.step) j_forward = jax.jit(kinematics.forward) for _ in range(2_000): - state = j_g_step(sys, state, jp.zeros(sys.qd_size())) + state = j_g_step(sys, state, jp.zeros(sys.act_size())) x_g, _ = j_forward(sys, state.q, state.qd) # trajectories should be close after .1 second of simulation @@ -66,18 +68,24 @@ def test_spherical_pendulum(self): sys = sys.replace(solver_iterations=500) state = pipeline.init(sys, init_q, init_qd) + # the qd calculation for pbd/spring doesn't match generalized, so we get xd + # from generalized and plug it back into pbd + # TODO: remove this xd override once kinematics.forward is fixed + state_g = g_pipeline.init(sys, init_q, init_qd) + xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd) + state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1]) + j_pos_step = jax.jit(pipeline.step) for _ in range(1000): - state = j_pos_step(sys, state, jp.zeros(sys.qd_size())) + state = j_pos_step(sys, state, jp.zeros(sys.act_size())) x = state.x # compare against generalized step - q, qd = init_q, init_qd - state = g_pipeline.init(sys, q, qd) j_g_step = jax.jit(g_pipeline.step) j_forward = jax.jit(kinematics.forward) + state = state_g for _ in range(1000): - state = j_g_step(sys, state, jp.zeros(sys.qd_size())) + state = j_g_step(sys, state, jp.zeros(sys.act_size())) x_g, _ = j_forward(sys, state.q, state.qd) # trajectories should be close after 1 second of simulation @@ -97,7 +105,7 @@ def test_3d_sliding_joint(self): j_pos_step = jax.jit(pipeline.step) states = [] for _ in range(1000): - state = j_pos_step(sys, state, jp.zeros(sys.qd_size())) + state = j_pos_step(sys, state, jp.zeros(sys.act_size())) states.append(state) x, xd = state.x, state.xd @@ -124,7 +132,7 @@ def test_3d_prismaversal_joint(self): j_pos_step = jax.jit(pipeline.step) states = [] for _ in range(1000): - state = j_pos_step(sys, state, jp.zeros(sys.qd_size())) + state = j_pos_step(sys, state, jp.zeros(sys.act_size())) states.append(state) # reflects off limits and is still traveling close to 2.5 m/s @@ -148,7 +156,7 @@ def test_sliding_capsule(self): state = pipeline.init(sys, sys.init_q, qd) j_pos_step = jax.jit(pipeline.step) for _ in range(1000): - state = j_pos_step(sys, state, jp.zeros(sys.qd_size())) + state = j_pos_step(sys, state, jp.zeros(sys.act_size())) x, xd = state.x, state.xd # capsule slides to a stop diff --git a/brax/scan.py b/brax/scan.py index 98e03cda..96e3837f 100644 --- a/brax/scan.py +++ b/brax/scan.py @@ -191,56 +191,3 @@ def f(typ, *args) -> y y = out_ys[0] if len(out_types) == 1 else out_ys return y - - -def actuator_types( - sys: System, f: Callable[..., Y], in_types: str, out_type: str, *args -) -> Y: - r"""Scan a function over System actuator type ranges. - - Args: - sys: system defining the kinematic tree and other properties - f: a function to be scanned with the following type signature:\ - def f(typ, link, q, qd) -> y - where - ``typ`` is the actuator, link type string - ``*args`` are input arguments with types matching ``in_types`` - ``y`` is an output arguments with types matching ``out_type`` - in_types: string specifying the type of each input arg: - 'a' is an input to be split according to act ranges - 'l' is an input to be split according to link ranges - 'q' is an input to be split according to q ranges - 'd' is an input to be split according to qd ranges - out_type: string specifying the type of the output - *args: the input arguments corresponding to ``in_types`` - - Returns: - The stacked outputs of ``f`` matching the system actuator order. - """ - typ_order = sorted(set(sys.actuator_types), key=sys.actuator_types.find) - - typ_order_idxs = [] - for i, t in enumerate(sys.actuator_types): - order = typ_order.index(t) - while order >= len(typ_order_idxs): - typ_order_idxs.append({'a': [], 'l': [], 'q': [], 'd': []}) - typ_order_idxs[order]['a'].append(i) - typ_order_idxs[order]['l'].append(sys.actuator_link_id[i]) - typ_order_idxs[order]['q'].append(sys.actuator_qid[i]) - typ_order_idxs[order]['d'].append(sys.actuator_qdid[i]) - - ys = [] - for typ, typ_idxs in zip(typ_order, typ_order_idxs): - in_args = [_take(a, typ_idxs[t]) for a, t in zip(args, in_types)] - ys.append(f(typ, *in_args)) - - y = jax.tree_map(lambda *x: jp.concatenate(x), *ys) - - # we concatenated results out of order, so put back in order if needed - - order = sum([t[out_type] for t in typ_order_idxs], []) - - if order != list(range(len(order))): - y = _take(y, [order.index(i) for i in range(len(order))]) - - return y diff --git a/brax/scan_test.py b/brax/scan_test.py index df8aa346..4821276b 100644 --- a/brax/scan_test.py +++ b/brax/scan_test.py @@ -16,7 +16,6 @@ """Tests for scan functions.""" from absl.testing import absltest -from absl.testing import parameterized from brax import scan from brax import test_utils import numpy as np @@ -155,65 +154,5 @@ def f(typ, link, q, qd): np.testing.assert_array_equal(qds[1], np.arange(6, 14)) -class ParametrizedScanTest(parameterized.TestCase): - - @parameterized.parameters( - ( - 'single_spherical_pendulum_position.xml', - ['p'], - [0, 1, 2], # act_id - [0, 0, 0], # act_link_id - [2, 0, 1], # q_id - [2, 0, 1], # qd_id - ), - ( - 'ant.xml', - ['m'], - list(range(8)), - [7, 8, 1, 2, 3, 4, 5, 6], - [13, 14, 7, 8, 9, 10, 11, 12], - [12, 13, 6, 7, 8, 9, 10, 11], - ), - ) - def test_scan_actuator_types( - self, fname, act_typs, act_id, act_link_id, q_id, qd_id - ): - """Test scanning actuators.""" - sys = test_utils.load_fixture(fname) - - typs, links, qs, qds = [], [], [], [] - - def f(typ, act, link, q, qd): - typs.append(typ) - links.append(link) - qs.append(q) - qds.append(qd) - return act - - out = scan.actuator_types( - sys, - f, - 'alqd', - 'a', - np.arange(sys.act_size()), - np.arange(sys.num_links()), - np.arange(sys.q_size()), - np.arange(sys.qd_size()), - ) - - self.assertSequenceEqual(typs, act_typs) - np.testing.assert_array_equal(out, np.array(act_id)) - - self.assertLen(links, 1) - self.assertSequenceEqual(sys.actuator_link_id, act_link_id) - np.testing.assert_array_equal(links[0], np.array(sys.actuator_link_id)) - - self.assertLen(qs, 1) - np.testing.assert_array_equal(qs[0], np.array(q_id)) - - self.assertLen(qds, 1) - np.testing.assert_array_equal(qds[0], np.array(qd_id)) - - if __name__ == '__main__': absltest.main() diff --git a/brax/spring/perf_test.py b/brax/spring/perf_test.py index d7d7ec61..2b33e042 100644 --- a/brax/spring/perf_test.py +++ b/brax/spring/perf_test.py @@ -34,7 +34,7 @@ def init_fn(rng): return pipeline.init(sys, q, qd) def step_fn(state): - return pipeline.step(sys, state, jp.zeros(sys.qd_size())) + return pipeline.step(sys, state, jp.zeros(sys.act_size())) test_utils.benchmark('spring pipeline ant', init_fn, step_fn) diff --git a/brax/spring/pipeline.py b/brax/spring/pipeline.py index 3a4a8456..748f76ca 100644 --- a/brax/spring/pipeline.py +++ b/brax/spring/pipeline.py @@ -87,7 +87,7 @@ def step( state = state.replace(i_inv=com.inv_inertia(sys, state.x)) # calculate acceleration and delta-velocity terms - tau = actuator.to_tau(sys, act, state.q) + tau = actuator.to_tau(sys, act, state.q, state.qd) xdd_i = joints.resolve(sys, state, tau) + Motion.create(vel=sys.gravity) # semi-implicit euler: apply acceleration update before resolving collisions state = state.replace(xd_i=state.xd_i + xdd_i * sys.dt) diff --git a/brax/spring/pipeline_test.py b/brax/spring/pipeline_test.py index 5043d75d..698107eb 100644 --- a/brax/spring/pipeline_test.py +++ b/brax/spring/pipeline_test.py @@ -16,8 +16,10 @@ """Tests for spring physics pipeline.""" from absl.testing import absltest +from brax import com from brax import kinematics from brax import test_utils +from brax.base import Transform from brax.generalized import pipeline as g_pipeline from brax.spring import pipeline import jax @@ -37,7 +39,7 @@ def test_pendulum(self): state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size())) j_spring_step = jax.jit(pipeline.step) for _ in range(10_000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) x = state.x # compare against generalized step @@ -46,7 +48,7 @@ def test_pendulum(self): j_g_step = jax.jit(g_pipeline.step) j_forward = jax.jit(kinematics.forward) for _ in range(10_000): - state = j_g_step(sys, state, jp.zeros(sys.qd_size())) + state = j_g_step(sys, state, jp.zeros(sys.act_size())) x_g, _ = j_forward(sys, state.q, state.qd) # trajectories should be close after 1 second of simulation @@ -71,12 +73,11 @@ def test_universal_pendulum(self): state = pipeline.init(sys, init_q, init_qd) j_spring_step = jax.jit(pipeline.step) for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) x = state.x # compare against generalized step - q, qd = init_q, init_qd - state = g_pipeline.init(sys, q, qd) + state = g_pipeline.init(sys, init_q, init_qd) j_g_step = jax.jit(g_pipeline.step) j_forward = jax.jit(kinematics.forward) for _ in range(1000): @@ -84,7 +85,7 @@ def test_universal_pendulum(self): x_g, _ = j_forward(sys, state.q, state.qd) # trajectories should be close after 1 second of simulation - self.assertLess(jp.linalg.norm(x_g.rot - x.rot), 1.5e-2) + self.assertLess(jp.linalg.norm(x_g.rot - x.rot), 1.51e-2) def test_spherical_pendulum(self): sys = test_utils.load_fixture('single_spherical_pendulum.xml') @@ -103,18 +104,24 @@ def test_spherical_pendulum(self): sys = sys.replace(solver_iterations=500) state = pipeline.init(sys, init_q, init_qd) + # the qd calculation for pbd/spring doesn't match generalized, so we get xd + # from generalized and plug it back into pbd + # TODO: remove this xd override once kinematics.forward is fixed + state_g = g_pipeline.init(sys, init_q, init_qd) + xd = Transform.create(pos=state_g.x.pos - state_g.com).vmap().do(state_g.cd) + state = state.replace(xd=xd, xd_i=com.from_world(sys, state.x, xd)[1]) + j_spring_step = jax.jit(pipeline.step) for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) x = state.x # compare against generalized step - q, qd = init_q, init_qd - state = g_pipeline.init(sys, q, qd) + state = state_g j_g_step = jax.jit(g_pipeline.step) j_forward = jax.jit(kinematics.forward) for _ in range(1000): - state = j_g_step(sys, state, jp.zeros(sys.qd_size())) + state = j_g_step(sys, state, jp.zeros(sys.act_size())) x_g, _ = j_forward(sys, state.q, state.qd) # trajectories should be close after 1 second of simulation @@ -140,7 +147,7 @@ def test_prismatic_joint(self): state = pipeline.init(sys, init_q, init_qd) j_spring_step = jax.jit(pipeline.step) for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) x = state.x # compare against generalized step @@ -172,7 +179,7 @@ def test_2d_sliding_joint(self): j_spring_step = jax.jit(pipeline.step) states = [] for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) states.append(state) x, xd = state.x, state.xd @@ -201,7 +208,7 @@ def test_3d_sliding_joint(self): j_spring_step = jax.jit(pipeline.step) states = [] for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) states.append(state) x, xd = state.x, state.xd @@ -228,7 +235,7 @@ def test_2d_prismaversal_joint(self): j_spring_step = jax.jit(pipeline.step) states = [] for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) states.append(state) # reflects off limits and is still traveling 2.5 m/s @@ -251,7 +258,7 @@ def test_3d_prismaversal_joint(self): j_spring_step = jax.jit(pipeline.step) states = [] for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) states.append(state) # reflects off limits and is still traveling 2.5 m/s @@ -272,7 +279,7 @@ def test_sliding_capsule(self): state = pipeline.init(sys, sys.init_q, qd) j_spring_step = jax.jit(pipeline.step) for _ in range(1000): - state = j_spring_step(sys, state, jp.zeros(sys.qd_size())) + state = j_spring_step(sys, state, jp.zeros(sys.act_size())) x, xd = state.x, state.xd # capsule slides to a stop diff --git a/brax/test_data/fat_cylinder.xml b/brax/test_data/fat_cylinder.xml new file mode 100644 index 00000000..e3fd71dc --- /dev/null +++ b/brax/test_data/fat_cylinder.xml @@ -0,0 +1,11 @@ + + + \ No newline at end of file diff --git a/brax/test_data/flat_cylinder.xml b/brax/test_data/flat_cylinder.xml new file mode 100644 index 00000000..266f5df3 --- /dev/null +++ b/brax/test_data/flat_cylinder.xml @@ -0,0 +1,11 @@ + + + \ No newline at end of file diff --git a/brax/test_data/fluid_box.xml b/brax/test_data/fluid_box.xml new file mode 100644 index 00000000..59fa968f --- /dev/null +++ b/brax/test_data/fluid_box.xml @@ -0,0 +1,14 @@ + + + + + + \ No newline at end of file diff --git a/brax/test_data/fluid_box_offset_com.xml b/brax/test_data/fluid_box_offset_com.xml new file mode 100644 index 00000000..38ef6fcc --- /dev/null +++ b/brax/test_data/fluid_box_offset_com.xml @@ -0,0 +1,23 @@ + + + + + + + + \ No newline at end of file diff --git a/brax/test_data/fluid_ellipsoid.xml b/brax/test_data/fluid_ellipsoid.xml new file mode 100644 index 00000000..c52f6006 --- /dev/null +++ b/brax/test_data/fluid_ellipsoid.xml @@ -0,0 +1,14 @@ + + + + + + \ No newline at end of file diff --git a/brax/test_data/fluid_wind.xml b/brax/test_data/fluid_wind.xml new file mode 100644 index 00000000..27cdf2dc --- /dev/null +++ b/brax/test_data/fluid_wind.xml @@ -0,0 +1,14 @@ + + + + + + \ No newline at end of file diff --git a/brax/test_data/single_pendulum_velocity.xml b/brax/test_data/single_pendulum_velocity.xml new file mode 100644 index 00000000..5eafb9ec --- /dev/null +++ b/brax/test_data/single_pendulum_velocity.xml @@ -0,0 +1,20 @@ + + + diff --git a/brax/test_data/solver_params_v2.xml b/brax/test_data/solver_params_v2.xml new file mode 100644 index 00000000..e52f0300 --- /dev/null +++ b/brax/test_data/solver_params_v2.xml @@ -0,0 +1,13 @@ + + + \ No newline at end of file diff --git a/brax/test_utils.py b/brax/test_utils.py index 05d6cf2c..6d2f3c35 100644 --- a/brax/test_utils.py +++ b/brax/test_utils.py @@ -37,10 +37,26 @@ def load_fixture_mujoco(path: str) -> mujoco.MjModel: return model +def _normalize_q(model: mujoco.MjModel, q: np.ndarray): + """Normalizes the quaternion part of q.""" + q = np.array(q) + q_idx = 0 + for typ in model.jnt_type: + q_dim = 7 if typ == 0 else 1 + if typ == 0: + q[q_idx + 3:q_idx + 7] = ( + q[q_idx + 3:q_idx + 7] / np.linalg.norm(q[q_idx + 3:q_idx + 7])) + q_idx += q_dim + return q + + def sample_mujoco_states( - path: str, count: int = 500, modulo: int = 20, force_pgs: bool = False + path: str, count: int = 500, modulo: int = 20, force_pgs: bool = False, + random_init: bool = False, random_q_scale: float = 1.0, + random_qd_scale: float = 0.1, vel_to_local: bool = True, seed: int = 42 ) -> Iterable[Tuple[mujoco.MjData, mujoco.MjData]]: """Samples count / modulo states from mujoco for comparison.""" + np.random.seed(seed) model = load_fixture_mujoco(path) model.opt.iterations = 50 # return to default for high-precision comparison if force_pgs: @@ -48,6 +64,10 @@ def sample_mujoco_states( data = mujoco.MjData(model) # give a little kick to avoid symmetry data.qvel = np.random.uniform(low=-0.01, high=0.01, size=(model.nv,)) + if random_init: + data.qpos = np.random.uniform(model.nq) * random_q_scale + data.qpos = _normalize_q(model, data.qpos) + data.qvel = np.random.uniform(size=(model.nv,)) * random_qd_scale for i in range(count): before = copy.deepcopy(data) mujoco.mj_step(model, data) @@ -55,7 +75,8 @@ def sample_mujoco_states( # hijack subtree_angmom, subtree_linvel (unused) to store xang, xvel for i in range(model.nbody): vel = np.zeros((6,)) - mujoco.mj_objectVelocity(model, data, 2, i, vel, 1) + mujoco.mj_objectVelocity( + model, data, mujoco.mjtObj.mjOBJ_XBODY.value, i, vel, vel_to_local) data.subtree_angmom[i] = vel[:3] data.subtree_linvel[i] = vel[3:] yield before, data diff --git a/brax/training/agents/ars/train.py b/brax/training/agents/ars/train.py index 1d12ce3e..bee52124 100644 --- a/brax/training/agents/ars/train.py +++ b/brax/training/agents/ars/train.py @@ -141,8 +141,9 @@ def add_noise(params: Params, key: PRNGKey) -> Tuple[Params, Params, Params]: noise = jax.tree_util.tree_map( lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), params, jax.tree_util.tree_unflatten(treedef, all_keys)) - params_with_noise = jax.tree_util.tree_map(lambda g, n: g + n * exploration_noise_std, - params, noise) + params_with_noise = jax.tree_util.tree_map( + lambda g, n: g + n * exploration_noise_std, params, noise + ) params_with_anti_noise = jax.tree_util.tree_map( lambda g, n: g - n * exploration_noise_std, params, noise) return params_with_noise, params_with_anti_noise, noise @@ -154,15 +155,20 @@ def training_epoch(training_state: TrainingState, key: PRNGKey) -> Tuple[TrainingState, Metrics]: params = jax.tree_util.tree_map( lambda x: jnp.repeat( - jnp.expand_dims(x, axis=0), number_of_directions, axis=0), - training_state.policy_params) + jnp.expand_dims(x, axis=0), number_of_directions, axis=0 + ), + training_state.policy_params, + ) key, key_noise, key_es_eval = jax.random.split(key, 3) # generate perturbations params_with_noise, params_with_anti_noise, noise = add_noise( params, key_noise) - pparams = jax.tree_util.tree_map(lambda a, b: jnp.concatenate([a, b], axis=0), - params_with_noise, params_with_anti_noise) + pparams = jax.tree_util.tree_map( + lambda a, b: jnp.concatenate([a, b], axis=0), + params_with_noise, + params_with_anti_noise, + ) pparams = jax.tree_util.tree_map( lambda x: jnp.reshape(x, (local_devices_to_use, -1) + x.shape[1:]), @@ -187,13 +193,17 @@ def training_epoch(training_state: TrainingState, reward_weight_double = jnp.concatenate([reward_weight, reward_weight], axis=0) reward_std = jnp.std(eval_scores, where=reward_weight_double) + reward_std += (reward_std == 0.0) * 1e-6 noise = jax.tree_util.tree_map( lambda x: jnp.sum( jnp.transpose( - jnp.transpose(x) * reward_weight * - (reward_plus - reward_minus)), - axis=0), noise) + jnp.transpose(x) * reward_weight * (reward_plus - reward_minus) + ), + axis=0, + ), + noise, + ) policy_params = jax.tree_util.tree_map( lambda x, y: x + step_size * y / (top_directions * reward_std), diff --git a/brax/training/agents/ars/train_test.py b/brax/training/agents/ars/train_test.py index 750d33e2..229642ad 100644 --- a/brax/training/agents/ars/train_test.py +++ b/brax/training/agents/ars/train_test.py @@ -21,7 +21,6 @@ from brax.training.acme import running_statistics from brax.training.agents.ars import networks as ars_networks from brax.training.agents.ars import train as ars -from brax.v1 import envs as envs_v1 import jax diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index 051dfc3d..b4b4d9f6 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -314,8 +314,7 @@ def training_step( (training_state, training_key), transitions) - metrics['buffer_current_size'] = buffer_state.current_size - metrics['buffer_current_position'] = buffer_state.current_position + metrics['buffer_current_size'] = replay_buffer.size(buffer_state) return training_state, env_state, buffer_state, metrics def prefill_replay_buffer( diff --git a/brax/training/learner.py b/brax/training/learner.py index 9d2e54d4..d85caf6a 100644 --- a/brax/training/learner.py +++ b/brax/training/learner.py @@ -118,8 +118,7 @@ 'Std of a random noise added by ARS.') flags.DEFINE_float('reward_shift', 0., 'A reward shift to get rid of "stay alive" bonus.') -flags.DEFINE_enum('head_type', '', ['', 'clip', 'tanh'], - 'Which policy head to use.') + # ARS hps. flags.DEFINE_integer('truncation_length', None, 'Truncation for gradient propagation in APG.') diff --git a/brax/training/replay_buffers.py b/brax/training/replay_buffers.py index 891efac2..1fe18b6d 100644 --- a/brax/training/replay_buffers.py +++ b/brax/training/replay_buffers.py @@ -71,8 +71,8 @@ class ReplayBufferState: """Contains data related to a replay buffer.""" data: jnp.ndarray - current_position: jnp.ndarray - current_size: jnp.ndarray + insert_position: jnp.ndarray + sample_position: jnp.ndarray key: PRNGKey @@ -109,14 +109,14 @@ def __init__( def init(self, key: PRNGKey) -> ReplayBufferState: return ReplayBufferState( data=jnp.zeros(self._data_shape, self._data_dtype), - current_size=jnp.zeros((), jnp.int32), - current_position=jnp.zeros((), jnp.int32), + sample_position=jnp.zeros((), jnp.int32), + insert_position=jnp.zeros((), jnp.int32), key=key, ) def check_can_insert(self, buffer_state, samples, shards): """Checks whether insert operation can be performed.""" - assert isinstance(shards, int), "This method should not be JITed." + assert isinstance(shards, int), 'This method should not be JITed.' insert_size = jax.tree_flatten(samples)[0][0].shape[0] // shards if self._data_shape[0] < insert_size: raise ValueError( @@ -141,14 +141,15 @@ def insert_internal( if buffer_state.data.shape != self._data_shape: raise ValueError( f'buffer_state.data.shape ({buffer_state.data.shape}) ' - f'doesn\'t match the expected value ({self._data_shape})') + f"doesn't match the expected value ({self._data_shape})" + ) update = self._flatten_fn(samples) data = buffer_state.data # If needed, roll the buffer to make sure there's enough space to fit # `update` after the current position. - position = buffer_state.current_position + position = buffer_state.insert_position roll = jnp.minimum(0, len(data) - position - len(update)) data = jax.lax.cond( roll, lambda: jnp.roll(data, roll, axis=0), lambda: data @@ -157,11 +158,13 @@ def insert_internal( # Update the buffer and the control numbers. data = jax.lax.dynamic_update_slice_in_dim(data, update, position, axis=0) - position = (position + len(update)) % len(data) - size = jnp.minimum(buffer_state.current_size + len(update), len(data)) + position = (position + len(update)) % (len(data) + 1) + sample_position = jnp.maximum(0, buffer_state.sample_position + roll) return buffer_state.replace( - data=data, current_position=position, current_size=size + data=data, + insert_position=position, + sample_position=sample_position, ) def sample_internal( @@ -170,21 +173,46 @@ def sample_internal( raise NotImplementedError(f'{self.__class__}.sample() is not implemented.') def size(self, buffer_state: ReplayBufferState) -> int: - return buffer_state.current_size # pytype: disable=bad-return-type # jax-ndarray + return buffer_state.insert_position - buffer_state.sample_position # pytype: disable=bad-return-type # jax-ndarray class Queue(QueueBase[Sample], Generic[Sample]): """Implements a limited-size queue replay buffer.""" + def __init__( + self, + max_replay_size: int, + dummy_data_sample: Sample, + sample_batch_size: int, + cyclic: bool = False, + ): + """Initializes the queue. + + Args: + max_replay_size: Maximum number of elements queue can have. + dummy_data_sample: Example record to be stored in the queue, it is used to + derive shapes. + sample_batch_size: How many elements sampling from the queue should return + in a batch. + cyclic: Should sampling from the queue behave cyclicly, ie. once recently + inserted element was sampled, sampling starts from the beginning of the + buffer. For example, if the current queue content is [0, 1, 2] and + `sample_batch_size` is 2, then consecutive calls to sample will give: + [0, 1], [2, 0], [1, 2]... + """ + super().__init__(max_replay_size, dummy_data_sample, sample_batch_size) + self._cyclic = cyclic + def check_can_sample(self, buffer_state, shards): """Checks whether sampling can be performed. Do not JIT this method.""" - assert isinstance(shards, int), "This method should not be JITed." + assert isinstance(shards, int), 'This method should not be JITed.' if self._size < self._sample_batch_size: raise ValueError( f'Trying to sample {self._sample_batch_size * shards} elements, but' f' only {self._size * shards} available.' ) - self._size -= self._sample_batch_size + if not self._cyclic: + self._size -= self._sample_batch_size def sample_internal( self, buffer_state: ReplayBufferState @@ -205,34 +233,24 @@ def sample_internal( # Note that this may be out of bound, but the operations below would still # work fine as they take this number modulo the buffer size. - first_element_idx = ( - buffer_state.current_position - buffer_state.current_size - ) - idx = jnp.arange(self._sample_batch_size) + first_element_idx + idx = (jnp.arange(self._sample_batch_size) + buffer_state.sample_position) % buffer_state.insert_position flat_batch = jnp.take(buffer_state.data, idx, axis=0, mode='wrap') - # TODO: Raise an error instead of padding with zeros - # when the buffer does not contain enough elements. - # If the sample batch size is larger than the number of elements in the - # queue, `mask` would contain 0s for all elements that are past the current - # position. Otherwise, `mask` will be only ones. - # mask.shape = (self._sample_batch_size,) - mask = idx < buffer_state.current_position - # mask.shape = (self._sample_batch_size, 1) - mask = jnp.expand_dims(mask, axis=range(1, flat_batch.ndim)) - flat_batch = flat_batch * mask - - # The effective size of the sampled batch. - sample_size = jnp.minimum( - self._sample_batch_size, buffer_state.current_size - ) # Remove the sampled batch from the queue. - new_state = buffer_state.replace( - current_size=buffer_state.current_size - sample_size - ) + sample_position = buffer_state.sample_position + self._sample_batch_size + if self._cyclic: + sample_position = sample_position % buffer_state.insert_position + + new_state = buffer_state.replace(sample_position=sample_position) return new_state, self._unflatten_fn(flat_batch) + def size(self, buffer_state: ReplayBufferState) -> int: + if self._cyclic: + return buffer_state.insert_position # pytype: disable=bad-return-type # jax-ndarray + else: + return buffer_state.insert_position - buffer_state.sample_position # pytype: disable=bad-return-type # jax-ndarray + class UniformSamplingQueue(QueueBase[Sample], Generic[Sample]): """Implements an uniform sampling limited-size replay queue. @@ -257,8 +275,8 @@ def sample_internal( idx = jax.random.randint( sample_key, (self._sample_batch_size,), - minval=buffer_state.current_position - buffer_state.current_size, - maxval=buffer_state.current_position, + minval=buffer_state.sample_position, + maxval=buffer_state.insert_position, ) batch = jnp.take(buffer_state.data, idx, axis=0, mode='wrap') return buffer_state.replace(key=key), self._unflatten_fn(batch) diff --git a/brax/training/replay_buffers_test.py b/brax/training/replay_buffers_test.py index e023a958..9aa11cc0 100644 --- a/brax/training/replay_buffers_test.py +++ b/brax/training/replay_buffers_test.py @@ -299,9 +299,9 @@ def testUniformSamplingQueueCyclicSample(self, wrapper): ) assert_equal(self, replay_buffer.size(buffer_state), 10) if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: - assert_equal(self, buffer_state.current_position, 0) + assert_equal(self, buffer_state.insert_position, 10) else: - assert_equal(self, buffer_state.current_position, [5, 5]) + assert_equal(self, buffer_state.insert_position, [5, 5]) @parameterized.parameters(WRAPPERS) def testQueueSamplePyTree(self, wrapper): @@ -357,14 +357,14 @@ def testQueueSample(self, wrapper): buffer_state.data, [[0], [1], [2], [3], [0], [0], [0], [0], [0], [0]], ) - assert_equal(self, buffer_state.current_position, 4) + assert_equal(self, buffer_state.insert_position, 4) else: assert_equal( self, buffer_state.data, [[[0], [2], [0], [0], [0]], [[1], [3], [0], [0], [0]]], ) - assert_equal(self, buffer_state.current_position, [2, 2]) + assert_equal(self, buffer_state.insert_position, [2, 2]) assert_equal(self, replay_buffer.size(buffer_state), 4) buffer_state = replay_buffer.insert(buffer_state, jnp.arange(4, 10)) @@ -374,14 +374,14 @@ def testQueueSample(self, wrapper): buffer_state.data, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]], ) - assert_equal(self, buffer_state.current_position, [0, 0]) + assert_equal(self, buffer_state.insert_position, 10) else: assert_equal( self, buffer_state.data, [[[0], [2], [4], [6], [8]], [[1], [3], [5], [7], [9]]], ) - assert_equal(self, buffer_state.current_position, 0) + assert_equal(self, buffer_state.insert_position, [5, 5]) assert_equal(self, replay_buffer.size(buffer_state), 10) buffer_state, samples = replay_buffer.sample(buffer_state) @@ -402,6 +402,92 @@ def testQueueSample(self, wrapper): assert_equal(self, samples, [8, 9, 20, 21]) assert_equal(self, replay_buffer.size(buffer_state), 2) + @parameterized.parameters(WRAPPERS) + def testCyclicQueueSample(self, wrapper): + mesh = get_mesh() + size_denominator = ( + 1 if wrapper in [no_wrap, jit_wrap] else mesh.shape[AXIS_NAME] + ) + mesh = get_mesh() + replay_buffer = wrapper( + replay_buffers.Queue( + max_replay_size=10 // size_denominator, + dummy_data_sample=0, + sample_batch_size=4 // size_denominator, + cyclic=True, + ) + ) + rng = jax.random.PRNGKey(0) + + buffer_state = replay_buffer.init(rng) + + buffer_state = replay_buffer.insert(buffer_state, jnp.arange(6)) + if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: + assert_equal( + self, + buffer_state.data, + [[0], [1], [2], [3], [4], [5], [0], [0], [0], [0]], + ) + assert_equal(self, buffer_state.insert_position, 6) + else: + assert_equal( + self, + buffer_state.data, + [[[0], [2], [4], [0], [0]], [[1], [3], [5], [0], [0]]], + ) + assert_equal(self, buffer_state.insert_position, [3, 3]) + assert_equal(self, replay_buffer.size(buffer_state), 6) + + buffer_state, samples = replay_buffer.sample(buffer_state) + assert_equal(self, samples, [0, 1, 2, 3]) + assert_equal(self, replay_buffer.size(buffer_state), 6) + + buffer_state, samples = replay_buffer.sample(buffer_state) + assert_equal(self, samples, [4, 5, 0, 1]) + assert_equal(self, replay_buffer.size(buffer_state), 6) + + buffer_state = replay_buffer.insert(buffer_state, jnp.arange(6, 10)) + if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: + assert_equal( + self, + buffer_state.data, + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]], + ) + assert_equal(self, buffer_state.insert_position, 10) + assert_equal(self, buffer_state.sample_position, 2) + else: + assert_equal( + self, + buffer_state.data, + [[[0], [2], [4], [6], [8]], [[1], [3], [5], [7], [9]]], + ) + assert_equal(self, buffer_state.insert_position, [5, 5]) + assert_equal(self, buffer_state.sample_position, [1, 1]) + assert_equal(self, replay_buffer.size(buffer_state), 10) + + buffer_state, samples = replay_buffer.sample(buffer_state) + assert_equal(self, samples, [2, 3, 4, 5]) + assert_equal(self, replay_buffer.size(buffer_state), 10) + + buffer_state, samples = replay_buffer.sample(buffer_state) + assert_equal(self, samples, [6, 7, 8, 9]) + assert_equal(self, replay_buffer.size(buffer_state), 10) + + buffer_state = replay_buffer.insert(buffer_state, jnp.arange(20, 24)) + assert_equal(self, replay_buffer.size(buffer_state), 10) + + buffer_state, samples = replay_buffer.sample(buffer_state) + assert_equal(self, samples, [4, 5, 6, 7]) + assert_equal(self, replay_buffer.size(buffer_state), 10) + + buffer_state, samples = replay_buffer.sample(buffer_state) + assert_equal(self, samples, [8, 9, 20, 21]) + assert_equal(self, replay_buffer.size(buffer_state), 10) + + buffer_state, samples = replay_buffer.sample(buffer_state) + assert_equal(self, samples, [22, 23, 4, 5]) + assert_equal(self, replay_buffer.size(buffer_state), 10) + @parameterized.parameters(WRAPPERS) def testQueueInsertWhenFull(self, wrapper): mesh = get_mesh() @@ -426,9 +512,9 @@ def testQueueInsertWhenFull(self, wrapper): assert_equal( self, buffer_state.data, - [[10], [11], [2], [3], [4], [5], [6], [7], [8], [9]], + [[2], [3], [4], [5], [6], [7], [8], [9], [10], [11]], ) - assert_equal(self, buffer_state.current_position, 2) + assert_equal(self, buffer_state.insert_position, 10) assert_equal(self, replay_buffer.size(buffer_state), 10) @parameterized.parameters(WRAPPERS) @@ -453,13 +539,13 @@ def testQueueWrappedSample(self, wrapper): assert_equal( self, buffer_state.data, - [[10], [11], [12], [13], [14], [15], [6], [7], [8], [9]], + [[6], [7], [8], [9], [10], [11], [12], [13], [14], [15]], ) assert_equal(self, replay_buffer.size(buffer_state), 10) if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: - assert_equal(self, buffer_state.current_position, 6) + assert_equal(self, buffer_state.insert_position, 10) else: - assert_equal(self, buffer_state.current_position, [3, 3]) + assert_equal(self, buffer_state.insert_position, [5, 5]) # This sample contains elements from both the beggining and the end of # the buffer. @@ -467,9 +553,9 @@ def testQueueWrappedSample(self, wrapper): assert_equal(self, samples, jnp.array([6, 7, 8, 9, 10, 11, 12, 13])) assert_equal(self, replay_buffer.size(buffer_state), 2) if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: - assert_equal(self, buffer_state.current_position, 6) + assert_equal(self, buffer_state.insert_position, 10) else: - assert_equal(self, buffer_state.current_position, [3, 3]) + assert_equal(self, buffer_state.insert_position, [5, 5]) @parameterized.parameters(WRAPPERS) def testQueueBatchSizeEqualsMaxSize(self, wrapper): @@ -493,7 +579,10 @@ def testQueueBatchSizeEqualsMaxSize(self, wrapper): buffer_state = replay_buffer.insert(buffer_state, jnp.arange(batch_size)) buffer_state, samples = replay_buffer.sample(buffer_state) assert_equal(self, samples, range(batch_size)) - assert_equal(self, buffer_state.current_size, 0) + if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: + assert_equal(self, buffer_state.sample_position, 8) + else: + assert_equal(self, buffer_state.sample_position, [4, 4]) buffer_state = replay_buffer.insert( buffer_state, jnp.zeros(batch_size, dtype=jnp.int32) @@ -503,7 +592,10 @@ def testQueueBatchSizeEqualsMaxSize(self, wrapper): ) buffer_state, samples = replay_buffer.sample(buffer_state) assert_equal(self, samples, [1] * batch_size) - assert_equal(self, buffer_state.current_size, 0) + if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: + assert_equal(self, buffer_state.sample_position, 8) + else: + assert_equal(self, buffer_state.sample_position, [4, 4]) @parameterized.parameters(WRAPPERS) def testQueueSampleFromEmpty(self, wrapper) -> None: @@ -526,19 +618,25 @@ def testQueueSampleFromEmpty(self, wrapper) -> None: buffer_state = replay_buffer.insert(buffer_state, jnp.arange(batch_size)) buffer_state, samples = replay_buffer.sample(buffer_state) assert_equal(self, samples, range(batch_size)) - assert_equal(self, buffer_state.current_size, 0) + if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: + assert_equal(self, buffer_state.sample_position, 10) + else: + assert_equal(self, buffer_state.sample_position, [5, 5]) with self.assertRaisesRegex( ValueError, 'Trying to sample 10 elements, but only 0 available.' ): replay_buffer.sample(buffer_state) - assert_equal(self, buffer_state.current_size, 0) + if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: + assert_equal(self, buffer_state.sample_position, 10) + else: + assert_equal(self, buffer_state.sample_position, [5, 5]) buffer_state = replay_buffer.insert(buffer_state, jnp.arange(10, 14)) if jax.device_count() == 1 or wrapper in [no_wrap, jit_wrap]: - assert_equal(self, buffer_state.current_size, 4) + assert_equal(self, buffer_state.sample_position, 6) else: - assert_equal(self, buffer_state.current_size, [2, 2]) + assert_equal(self, buffer_state.sample_position, [3, 3]) with self.assertRaisesRegex( ValueError, 'Trying to sample 10 elements, but only 4 available.' ): diff --git a/brax/visualizer/js/system.js b/brax/visualizer/js/system.js index 713231d1..15280c64 100644 --- a/brax/visualizer/js/system.js +++ b/brax/visualizer/js/system.js @@ -51,6 +51,16 @@ function getMeshAxisSize(geom) { return size * 2; } +function createCylinder(radius, height, mat) { + const geometry = new THREE.CylinderGeometry(radius, radius, height, 32); + mat.side = THREE.DoubleSide; + const cyl = new THREE.Mesh(geometry, mat); + cyl.baseMaterial = cyl.material; + cyl.castShadow = true; + cyl.layers.enable(1); + return cyl; +} + function createCapsule(capsule, mat) { const sphere_geom = new THREE.SphereGeometry(capsule.radius, 16, 16); const cylinder_geom = new THREE.CylinderGeometry( @@ -80,7 +90,7 @@ function createCapsule(capsule, mat) { } function createBox(box, mat) { - const geom = new THREE.BoxBufferGeometry( + const geom = new THREE.BoxGeometry( 2 * box.halfsize[0], 2 * box.halfsize[1], 2 * box.halfsize[2]); const mesh = new THREE.Mesh(geom, mat); mesh.castShadow = true; @@ -141,6 +151,7 @@ function createScene(system) { // Add a world axis for debugging. const worldAxis = new THREE.AxesHelper(100); + const qRotx90 = new THREE.Quaternion(0.70710677, 0.0, 0.0, 0.7071067); worldAxis.visible = false; scene.add(worldAxis); @@ -176,6 +187,9 @@ function createScene(system) { } else if (collider.name == 'Mesh') { child = createMesh(collider, mat); axisSize = getMeshAxisSize(collider); + } else if (collider.name == 'Cylinder') { + child = createCylinder(collider.radius, collider.length, mat); + axisSize = 2 * Math.max(collider.radius, collider.length); } else if ('clippedPlane' in collider) { console.log('clippedPlane not implemented'); return; @@ -184,9 +198,13 @@ function createScene(system) { return; } if (collider.transform.rot) { - child.quaternion.set( - collider.transform.rot[1], collider.transform.rot[2], - collider.transform.rot[3], collider.transform.rot[0]); + const quat = new THREE.Quaternion( + collider.transform.rot[1], collider.transform.rot[2], + collider.transform.rot[3], collider.transform.rot[0]); + if (collider.name == 'Cylinder') { + quat.multiply(qRotx90) + } + child.quaternion.fromArray(quat.toArray()); } if (collider.transform.pos) { child.position.set(