Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 533251755
Change-Id: I981f46e82e04f6d0afbb6cc88ca3982613ff8ef5
  • Loading branch information
Brax Team authored and erikfrey committed May 18, 2023
1 parent c2cd14c commit aebd8b8
Show file tree
Hide file tree
Showing 50 changed files with 982 additions and 431 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
include brax/envs/assets/*.xml
recursive-include brax/test_data *.xml *.stl *.obj *.urdf
recursive-include brax/visualizer *
recursive-include brax/visualizer *
29 changes: 11 additions & 18 deletions brax/actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,33 @@
# pylint:disable=g-multiple-import
"""Functions for applying actuators to a physics pipeline."""

from brax import scan
from brax.base import System
from jax import numpy as jp


def to_tau(sys: System, act: jp.ndarray, q: jp.ndarray) -> jp.ndarray:
def to_tau(
sys: System, act: jp.ndarray, q: jp.ndarray, qd: jp.ndarray
) -> jp.ndarray:
"""Convert actuator to a joint force tau.
Args:
sys: system defining the kinematic tree and other properties
act: (act_size,) actuator force input vector
q: joint position vector
qd: joint velocity vector
Returns:
tau: (qd_size,) vector of joint forces
"""
if sys.act_size() == 0:
return jp.zeros(sys.qd_size())

def act_fn(act_type, act, actuator, q, qd_idx):
if act_type not in ('p', 'm'):
raise RuntimeError(f'unrecognized act type: {act_type}')

force = jp.clip(act, actuator.ctrl_range[:, 0], actuator.ctrl_range[:, 1])
if act_type == 'p':
force -= q # positional actuators have a bias
tau = actuator.gear * force

return tau, qd_idx

qd_idx = jp.arange(sys.qd_size())
tau, qd_idx = scan.actuator_types(
sys, act_fn, 'aaqd', 'a', act, sys.actuator, q, qd_idx
)
tau = jp.zeros(sys.qd_size()).at[qd_idx].add(tau)
q, qd = q[sys.actuator.q_id], qd[sys.actuator.qd_id]
ctrl_range = sys.actuator.ctrl_range
act = jp.clip(act, ctrl_range[:, 0], ctrl_range[:, 1])
# TODO: incorporate gain
act = act + q * sys.actuator.bias_q + qd * sys.actuator.bias_qd
act_force = sys.actuator.gear * act
tau = jp.zeros(sys.qd_size()).at[sys.actuator.qd_id].add(act_force)

return tau
22 changes: 20 additions & 2 deletions brax/actuator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_motor(self, pipeline, dt, n, decimal):

tau = jp.array([0.5 * 9.81]) # -mgl sin(theta)
act = jp.array([1.0 / 150.0 * 0.5 * 9.81])
tau2 = actuator.to_tau(sys, act, q)
tau2 = actuator.to_tau(sys, act, q, qd)
np.testing.assert_array_almost_equal(tau, tau2, 5)

q2, qd2 = _actuator_step(pipeline, sys, q, qd, act=act, dt=dt, n=n)
Expand All @@ -73,7 +73,7 @@ def test_position(self):

# a position actuator at the bottom should not move the pendulum
act = jp.array([theta])
tau = actuator.to_tau(sys, act, q)
tau = actuator.to_tau(sys, act, q, qd)
np.testing.assert_array_almost_equal(tau, jp.array([0]), 5)

# put the pendulum into the horizontal position with the positional actuator
Expand All @@ -83,6 +83,24 @@ def test_position(self):
q2, _ = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.01, n=1)
np.testing.assert_array_almost_equal(q2, jp.array([0]), 1)

def test_velocity(self):
"""Tests a single pendulum with velocity actuator."""
sys = test_utils.load_fixture('single_pendulum_velocity.xml')
mj_model = test_utils.load_fixture_mujoco('single_pendulum_velocity.xml')
mj_data = mujoco.MjData(mj_model)
q, qd = jp.array(mj_data.qpos), jp.array(mj_data.qvel)
theta = jp.pi / 2.0 # pendulum is vertical
q = jp.array([theta])

act = jp.array([0])
tau = actuator.to_tau(sys, act, q, qd)
np.testing.assert_array_almost_equal(tau, jp.array([0]), 5)

# set the act to rotate at 1/s
act = jp.array([1])
_, qd = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=0.001, n=200)
np.testing.assert_array_almost_equal(qd, jp.array([1]), 3)

@parameterized.parameters((g_pipeline,), (s_pipeline,), (p_pipeline,))
def test_three_link_pendulum(self, pipeline):
"""Tests a three link pendulum with a motor actuator."""
Expand Down
65 changes: 41 additions & 24 deletions brax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class Link(Base):
constraint_stiffness: jp.ndarray
constraint_vel_damping: jp.ndarray
constraint_limit_stiffness: jp.ndarray
# only used by `brax.physics.spring` and `brax.physics.pbd`:
# only used by `brax.physics.spring` and `brax.physics.positional`:
constraint_ang_damping: jp.ndarray


Expand All @@ -284,6 +284,7 @@ class DoF(Base):
damping: restorative force back to zero velocity
limit: tuple of min, max angle limits
invweight: diagonal inverse inertia at init_qpos
solver_params: (7,) limit constraint solver parameters
"""

motion: Motion
Expand All @@ -293,6 +294,7 @@ class DoF(Base):
limit: Tuple[jp.ndarray, jp.ndarray]
# only used by `brax.physics.generalized`:
invweight: jp.ndarray
solver_params: jp.ndarray


@struct.dataclass
Expand All @@ -305,12 +307,14 @@ class Geometry(Base):
relative to the world frame in the case of unparented geometry
friction: resistance encountered when sliding against another geometry
elasticity: bounce/restitution encountered when hitting another geometry
solver_params: (7,) solver parameters (reference, impedance)
"""

link_idx: Optional[jp.ndarray]
transform: Transform
friction: jp.ndarray
elasticity: jp.ndarray
solver_params: jp.ndarray


@struct.dataclass
Expand Down Expand Up @@ -354,6 +358,21 @@ class Box(Geometry):
rgba: Optional[jp.ndarray] = None


@struct.dataclass
class Cylinder(Geometry):
"""A cylinder.
Attributes:
radius: (1,) radius of the top and bottom of the cylinder
length: (1,) length of the cylinder
rgba: (4,) the rgba to display in the renderer
"""

radius: jp.ndarray
length: jp.ndarray
rgba: Optional[jp.ndarray] = None


@struct.dataclass
class Plane(Geometry):
"""An infinite plane whose normal points at +z in its coordinate space.
Expand Down Expand Up @@ -410,6 +429,7 @@ class Contact(Base):
two geometries are interpenetrating, negative means they are not
friction: resistance encountered when sliding against another geometry
elasticity: bounce/restitution encountered when hitting another geometry
solver_params: (7,) collision constraint solver parameters
link_idx: Tuple of link indices participating in contact. The second part
of the tuple can be None if the second geometry is static.
"""
Expand All @@ -418,8 +438,9 @@ class Contact(Base):
normal: jp.ndarray
penetration: jp.ndarray
friction: jp.ndarray
# only used by `brax.physics.spring` and `brax.physics.pbd`:
# only used by `brax.physics.spring` and `brax.physics.positional`:
elasticity: jp.ndarray
solver_params: jp.ndarray

link_idx: Tuple[jp.ndarray, Optional[jp.ndarray]]

Expand All @@ -429,13 +450,20 @@ class Actuator(Base):
"""Actuator, transforms an input signal into a force (motor or thruster).
Attributes:
ctrl_range: (num_actuators, 2) control range for each actuator
gear: (num_actuators,) a list of floats used as a scaling factor for each
actuator torque output
q_id: (num_actuators,) q index associated with an actuator
qd_id: (num_actuators,) qd index associated with an actuator
ctrl_range: (num_actuators, 2) actuator control range
gear: (num_actuators,) scaling factor for each actuator torque output
bias_q: (num_actuators,) bias applied by q (e.g. position actuators)
bias_qd: (num_actuators,) bias applied by qd (e.g. velocity actuators)
"""

q_id: jp.ndarray
qd_id: jp.ndarray
ctrl_range: jp.ndarray
gear: jp.ndarray
bias_q: jp.ndarray
bias_qd: jp.ndarray


@struct.dataclass
Expand Down Expand Up @@ -464,13 +492,13 @@ class System:
Attributes:
dt: timestep used for the simulation
gravity: (3,) linear universal force applied during forward dynamics
viscosity: (1,) viscosity of the medium applied to all links
density: (1,) density of the medium applied to all links
link: (num_link,) the links in the system
dof: (qd_size,) every degree of freedom for the system
geoms: list of batched geoms grouped by type
actuator: actuators that can be applied to links
init_q: (q_size,) initial q position for the system
solver_params_joint: (7,) joint limit constraint solver parameters
solver_params_contact: (7,) collision constraint solver parameters
vel_damping: (1,) linear vel damping applied to each body.
ang_damping: (1,) angular vel damping applied to each body.
baumgarte_erp: how aggressively interpenetrating bodies should push away\
Expand All @@ -492,46 +520,35 @@ class System:
* '3': spherical, 3 dof, like a ball joint
link_parents: (num_link,) int list specifying the index of each link's
parent link, or -1 if the link has no parent
actuator_types: (num_actuators,) string specifying the actuator types:
* 't': torque
* 'p': position
actuator_link_id: (num_actuators,) the link id associated with each actuator
actuator_qid: (num_actuators,) the q index associated with each actuator
actuator_qdid: (num_actuators,) the qd index associated with each actuator
matrix_inv_iterations: maximum number of iterations of the matrix inverse
solver_iterations: maximum number of iterations of the constraint solver
solver_maxls: maximum number of line searches of the constraint solver
"""

dt: jp.ndarray
gravity: jp.ndarray
viscosity: jp.float32
density: jp.float32
link: Link
dof: DoF
geoms: List[Geometry]
actuator: Actuator
init_q: jp.ndarray
# only used in `brax.physics.generalized`
solver_params_joint: jp.ndarray
solver_params_contact: jp.ndarray
# only used in `brax.physics.spring` and `brax.physics.pbd`:
# only used in `brax.physics.spring` and `brax.physics.positional`:
vel_damping: jp.float32
ang_damping: jp.float32
baumgarte_erp: jp.float32
spring_mass_scale: jp.float32
spring_inertia_scale: jp.float32
# only used in `brax.physics.positional`
# only used in `brax.physics.positional`:
joint_scale_ang: jp.float32
joint_scale_pos: jp.float32
collide_scale: jp.float32

# non-pytree nodes
geom_masks: List[int] = struct.field(pytree_node=False)
link_names: List[str] = struct.field(pytree_node=False)
link_types: str = struct.field(pytree_node=False)
link_parents: Tuple[int, ...] = struct.field(pytree_node=False)
actuator_types: str = struct.field(pytree_node=False)
actuator_link_id: List[int] = struct.field(pytree_node=False)
actuator_qid: List[int] = struct.field(pytree_node=False)
actuator_qdid: List[int] = struct.field(pytree_node=False)
# only used in `brax.physics.generalized`:
matrix_inv_iterations: int = struct.field(pytree_node=False)
solver_iterations: int = struct.field(pytree_node=False)
Expand Down Expand Up @@ -595,7 +612,7 @@ def qd_size(self) -> int:

def act_size(self) -> int:
"""Returns the act dimension for the system."""
return sum({'m': 1, 'p': 1}[act_typ] for act_typ in self.actuator_types)
return self.actuator.q_id.shape[0]


# below are some operation dispatch derivations
Expand Down
3 changes: 2 additions & 1 deletion brax/envs/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ def _get_obs(
com_ang = xd_i.ang
com_velocity = jp.hstack([com_vel, com_ang])

qfrc_actuator = actuator.to_tau(self.sys, action, pipeline_state.q)
qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd)

# external_contact_forces are excluded
return jp.concatenate([
Expand Down
3 changes: 2 additions & 1 deletion brax/envs/humanoidstandup.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ def _get_obs(
com_ang = xd_i.ang
com_velocity = jp.hstack([com_vel, com_ang])

qfrc_actuator = actuator.to_tau(self.sys, action, pipeline_state.q)
qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd)

# external_contact_forces are excluded
return jp.concatenate([
Expand Down
83 changes: 83 additions & 0 deletions brax/fluid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2023 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
"""Functions for forces/torques through fluids."""

from typing import Tuple, Union
from brax.base import Force, Motion, System, Transform
import jax
import jax.numpy as jp


def _box_viscosity(
box: jp.ndarray, xd_i: Motion, viscosity: jp.ndarray
) -> Force:
"""Gets force due to motion through a viscous fluid."""
diam = jp.mean(box, axis=-1)
ang_scale = -jp.pi * diam**3 * viscosity
vel_scale = -3.0 * jp.pi * diam * viscosity
frc = Force(
ang=ang_scale[:, None] * xd_i.ang, vel=vel_scale[:, None] * xd_i.vel
)
return frc


def _box_density(box: jp.ndarray, xd_i: Motion, density: jp.ndarray) -> Force:
"""Gets force due to motion through dense fluid."""

@jax.vmap
def apply(b: jp.ndarray, xd: Motion) -> Force:
box_mult_vel = jp.array([b[1] * b[2], b[0] * b[2], b[0] * b[1]])
vel = -0.5 * density * box_mult_vel * jp.abs(xd.vel) * xd.vel
box_mult_ang = jp.array([
b[0] * (b[1] ** 4 + b[2] ** 4),
b[1] * (b[0] ** 4 + b[2] ** 4),
b[2] * (b[0] ** 4 + b[1] ** 4),
])
ang = -1.0 * density * box_mult_ang * jp.abs(xd.ang) * xd.ang / 64.0
return Force(vel=vel, ang=ang)

return apply(box, xd_i)


def force(
sys: System,
x: Transform,
xd: Motion,
mass: jp.ndarray,
inertia: jp.ndarray,
subtree_com: Union[jp.ndarray, None] = None,
) -> Tuple[Force, jp.ndarray]:
"""Returns force due to motion through a fluid."""
# get the velocity at the com position/orientation
x_i = x.vmap().do(sys.link.inertia.transform)
# TODO: remove subtree_com when xd is fixed for stacked joints
offset = x_i.pos - x.pos if subtree_com is None else x_i.pos - subtree_com
xd_i = x_i.replace(pos=offset).vmap().do(xd)

# TODO: add ellipsoid fluid model from mujoco
# TODO: consider adding wind from mj.opt.wind
diag_inertia = jax.vmap(jp.diag)(inertia)
diag_inertia_v = jp.repeat(diag_inertia, 3, axis=-2).reshape((-1, 3, 3))
diag_inertia_v *= jp.ones((3, 3)) - 2 * jp.eye(3)
box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), a_min=1e-12)
box = jp.sqrt(box / mass[:, None])

frc = _box_viscosity(box, xd_i, sys.viscosity)
frc += _box_density(box, xd_i, sys.density)

# rotate back to the world orientation
frc = Transform.create(rot=x_i.rot).vmap().do(frc)
return frc, x_i.pos
Loading

0 comments on commit aebd8b8

Please sign in to comment.