Skip to content

Commit

Permalink
add an exp2 primitive and lax.exp2
Browse files Browse the repository at this point in the history
part of fixing jax-ml/jax-triton#204
  • Loading branch information
mattjj committed Jul 28, 2023
1 parent 640ee1e commit 560ede0
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 4 deletions.
1 change: 1 addition & 0 deletions jax/_src/internal_test_util/lax_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def lax_ops():
),
op_record("is_finite", 1, float_dtypes, test_util.rand_small),
op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
op_record("exp2", 1, float_dtypes + complex_dtypes, test_util.rand_small),
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
# precision.
op_record(
Expand Down
15 changes: 13 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`."""
return exp_p.bind(x)

def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`."""
return exp2_p.bind(x)

def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`."""
return expm1_p.bind(x)
Expand Down Expand Up @@ -1757,10 +1761,17 @@ def _round_lower(ctx, x, *, rounding_method):

exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
# For exp_p it is more efficient to use the reconstructed output for the vjp
# rule instead of computing it again from the input.
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))

exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
def _exp2_lower(ctx, x):
x_aval, = ctx.avals_in
log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
return hlo.ExpOp(hlo.MulOp(log2, x).result).results
mlir.register_lowering(exp2_p, _exp2_lower)

log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lax_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def round(x):
is_finite = np.isfinite

exp = np.exp
exp2 = np.exp2
expm1 = np.expm1
log = np.log
log1p = np.log1p
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,9 @@ def log10(x: ArrayLike, /) -> Array:
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
assert False
x, = promote_args_inexact("exp2", x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
return lax.exp2(x)


@_wraps(np.signbit, module='numpy')
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,8 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],

tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax_internal.exp2_p] = lambda x: \
tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x))
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
Expand Down
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
eq_p as eq_p,
exp as exp,
exp_p as exp_p,
exp2 as exp2,
expand_dims as expand_dims,
expm1 as expm1,
expm1_p as expm1_p,
Expand Down
6 changes: 5 additions & 1 deletion tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):

grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.exp2, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive,
Expand All @@ -79,7 +81,7 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol=1e-5),
grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol=1e-4),
dtypes=grad_inexact_dtypes, tol=2e-4),
grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol={np.float32: 5e-1}),
grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default,
Expand Down Expand Up @@ -213,6 +215,8 @@ def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
raise SkipTest("pow grad imprecise on tpu")
if op is lax.cos:
order = 1 # 2nd-order gradient is imprecise on TPU.
if op is lax.log:
order = 1 # 2nd-order gradient is imprecise on TPU.

tol = jtu.join_tolerance(1.5e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
args = tuple(rng(shape, dtype) for shape in shapes)
Expand Down

0 comments on commit 560ede0

Please sign in to comment.