From 560ede0ff117d7d9c71b22e576cea96f414508c0 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 28 Jul 2023 09:58:25 -0700
Subject: [PATCH] add an exp2 primitive and lax.exp2

part of fixing https://github.com/jax-ml/jax-triton/issues/204
---
 jax/_src/internal_test_util/lax_test_util.py | 1 +
 jax/_src/lax/lax.py | 15 +++++++++++++--
 jax/_src/lax_reference.py | 1 +
 jax/_src/numpy/ufuncs.py | 2 +-
 jax/experimental/jax2tf/jax2tf.py | 2 ++
 jax/lax/__init__.py | 1 +
 tests/lax_autodiff_test.py | 6 +++++-
 7 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py
index 8ef8c12ba917..560b56e02cdc 100644
--- a/jax/_src/internal_test_util/lax_test_util.py
+++ b/jax/_src/internal_test_util/lax_test_util.py
@@ -167,6 +167,7 @@ def lax_ops():
     ),
     op_record("is_finite", 1, float_dtypes, test_util.rand_small),
     op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
+    op_record("exp2", 1, float_dtypes + complex_dtypes, test_util.rand_small),
     # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
     # precision.
     op_record(
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 2f98b6daae6f..d13ef24afe60 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -302,6 +302,10 @@ def exp(x: ArrayLike) -> Array:
   r"""Elementwise exponential: :math:`e^x`."""
   return exp_p.bind(x)
 
+def exp2(x: ArrayLike) -> Array:
+  r"""Elementwise base-2 exponential: :math:`2^x`."""
+  return exp2_p.bind(x)
+
 def expm1(x: ArrayLike) -> Array:
   r"""Elementwise :math:`e^{x} - 1`."""
   return expm1_p.bind(x)
@@ -1757,10 +1761,17 @@ def _round_lower(ctx, x, *, rounding_method):
 
 exp_p = standard_unop(_float | _complex, 'exp')
 ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
-# For exp_p it is more efficient to use the reconstructed output for the vjp
-# rule instead of computing it again from the input.
 mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))
 
+exp2_p = standard_unop(_float | _complex, 'exp2')
+ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
+def _exp2_lower(ctx, x):
+  x_aval, = ctx.avals_in
+  log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
+  log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
+  return hlo.ExpOp(hlo.MulOp(log2, x).result).results
+mlir.register_lowering(exp2_p, _exp2_lower)
+
 log_p = standard_unop(_float | _complex, 'log')
 ad.defjvp(log_p, lambda g, x: div(g, x))
 mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py
index 96d3b3dc2564..0fee154d03c1 100644
--- a/jax/_src/lax_reference.py
+++ b/jax/_src/lax_reference.py
@@ -45,6 +45,7 @@ def round(x):
 
 is_finite = np.isfinite
 exp = np.exp
+exp2 = np.exp2
 expm1 = np.expm1
 log = np.log
 log1p = np.log1p
diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py
index 65d09ebc7e26..ac350cb1f699 100644
--- a/jax/_src/numpy/ufuncs.py
+++ b/jax/_src/numpy/ufuncs.py
@@ -429,8 +429,8 @@ def log10(x: ArrayLike, /) -> Array:
 @_wraps(np.exp2, module='numpy')
 @partial(jit, inline=True)
 def exp2(x: ArrayLike, /) -> Array:
   x, = promote_args_inexact("exp2", x)
-  return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
+  return lax.exp2(x)
 
 
 @_wraps(np.signbit, module='numpy')
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index d546e98730b1..ad69b4dea4ca 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1575,6 +1575,8 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
 
 tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
 tf_impl[lax.exp_p] = tf.math.exp
+tf_impl[lax_internal.exp2_p] = lambda x: \
+  tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x))
 tf_impl[lax.expm1_p] = tf.math.expm1
 tf_impl[lax.log_p] = tf.math.log
 tf_impl[lax.log1p_p] = tf.math.log1p
diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py
index 712d628c5ee2..c7591f0fa748 100644
--- a/jax/lax/__init__.py
+++ b/jax/lax/__init__.py
@@ -93,6 +93,7 @@
   eq_p as eq_p,
   exp as exp,
   exp_p as exp_p,
+  exp2 as exp2,
   expand_dims as expand_dims,
   expm1 as expm1,
   expm1_p as expm1_p,
diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py
index 3ece195b0535..569e106d744c 100644
--- a/tests/lax_autodiff_test.py
+++ b/tests/lax_autodiff_test.py
@@ -68,6 +68,8 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
     grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
                    dtypes=grad_inexact_dtypes),
+    grad_test_spec(lax.exp2, nargs=1, order=2, rng_factory=jtu.rand_small,
+                   dtypes=grad_inexact_dtypes),
     grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes),
     grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive,
                    dtypes=grad_inexact_dtypes),
@@ -79,7 +81,7 @@
     grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes, tol=1e-5),
     grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default,
-                   dtypes=grad_inexact_dtypes, tol=1e-4),
+                   dtypes=grad_inexact_dtypes, tol=2e-4),
     grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes, tol={np.float32: 5e-1}),
     grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default,
@@ -213,6 +215,8 @@
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): raise SkipTest("pow grad imprecise on tpu") if op is lax.cos: order = 1 # 2nd-order gradient is imprecise on TPU. + if op is lax.log: + order = 1 # 2nd-order gradient is imprecise on TPU. tol = jtu.join_tolerance(1.5e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol args = tuple(rng(shape, dtype) for shape in shapes)
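
Usage sketch (illustrative, not part of the patch): the snippet below assumes a JAX build that includes this change. It calls the new lax.exp2 and the rerouted jnp.exp2, checks the JVP rule registered for exp2_p (d/dx 2**x = ln(2) * 2**x) via jax.grad, and prints a jaxpr in which exp2 should now appear as a single primitive rather than the earlier mul/log/exp decomposition; keeping the primitive intact at the jaxpr level is what lets downstream lowerings such as jax-triton handle it directly. The input value and tolerances here are arbitrary.

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

x = jnp.float32(1.5)

# Forward values: both entry points should agree with 2**x.
np.testing.assert_allclose(lax.exp2(x), 2.0 ** 1.5, rtol=1e-5)
np.testing.assert_allclose(jnp.exp2(x), 2.0 ** 1.5, rtol=1e-5)

# Gradient should match the defjvp2 rule above: ln(2) * 2**x.
g = jax.grad(lax.exp2)(x)
np.testing.assert_allclose(g, np.log(2.0) * 2.0 ** 1.5, rtol=1e-5)

# With this patch the traced computation contains a single `exp2` equation
# instead of the previous mul(log(2), x) followed by exp.
print(jax.make_jaxpr(jnp.exp2)(x))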