Add Flax activation functions and corresponding tests #35191

Open
wants to merge 9 commits into main
src/transformers/modeling_flax_utils.py (24 changes: 16 additions & 8 deletions)
@@ -67,18 +67,26 @@
 logger = logging.get_logger(__name__)
 
 
-def quick_gelu(x):
-    return x * jax.nn.sigmoid(1.702 * x)
-
-
 ACT2FN = {
     "gelu": partial(nn.gelu, approximate=False),
+    "gelu_10": lambda x: jnp.clip(nn.gelu(x, approximate=False), a_min=-10.0, a_max=10.0),
+    "gelu_fast": lambda x: 0.5 * x * (1.0 + nn.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))),
+    "gelu_new": lambda x: 0.5 * x * (1.0 + nn.tanh((2.0 / jnp.pi) ** 0.5 * (x + 0.044715 * jnp.power(x, 3.0)))),
+    "gelu_python": lambda x: x * 0.5 * (1.0 + jax.scipy.special.erf(x / 2.0**0.5)),
+    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
+    "gelu_accurate": lambda x: 0.5 * x * (1 + nn.tanh((2 / jnp.pi) ** 0.5 * (x + 0.044715 * jnp.power(x, 3)))),
+    "laplace": lambda x: 0.5 * (1.0 + jax.scipy.special.erf(jnp.divide(x - 0.707107, 0.282095 * 2.0**0.5))),
+    "leaky_relu": nn.leaky_relu,
+    "linear": lambda x: x,
+    "mish": lambda x: x * nn.tanh(nn.softplus(x)),
+    "quick_gelu": lambda x: x * nn.sigmoid(1.702 * x),
     "relu": nn.relu,
-    "silu": nn.swish,
+    "relu2": lambda x: jnp.square(nn.relu(x)),
+    "relu6": nn.relu6,
+    "sigmoid": nn.sigmoid,
+    "silu": nn.silu,
     "swish": nn.swish,
-    "gelu_new": partial(nn.gelu, approximate=True),
-    "quick_gelu": quick_gelu,
-    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
     "tanh": nn.tanh,
 }


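For context only (not part of the diff): a minimal sketch of how entries in this table are typically consumed, namely by looking up an activation by its config name. The key used below is just one of the entries added above.

    import jax.numpy as jnp
    from transformers.modeling_flax_utils import ACT2FN

    act_fn = ACT2FN["gelu_new"]       # select an activation by its config name
    x = jnp.linspace(-3.0, 3.0, 7)    # small sample input
    y = act_fn(x)                     # element-wise activation, same shape as x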
tests/test_modeling_flax_common.py (31 changes: 31 additions & 0 deletions)
@@ -1149,3 +1149,34 @@ def test_gradient_checkpointing(self):
         # ensure that the outputs remain precisely equal
         for output, remat_output in zip(outputs, remat_outputs):
             self.assertTrue((output == remat_output).all())
+
+    def test_activation_fns(self):
+        # Using the Torch activation functions in `activations.ACT2FN` as the reference, check that
+        # the corresponding Flax implementations produce equal (or numerically close) results.
+
+        import jax.numpy as jnp
+        import torch
+
+        from transformers.activations import ACT2FN as TORCH_ACT2FN
+        from transformers.modeling_flax_utils import ACT2FN as FLAX_ACT2FN
+
+        limit_left = -10.0
+        limit_right = 10.0
+        x = np.linspace(limit_left, limit_right, 500)
+
+        for fn_name in TORCH_ACT2FN.keys():
+            if fn_name in FLAX_ACT2FN:
+                flax_act_fn = FLAX_ACT2FN[fn_name]
+                torch_act_fn = TORCH_ACT2FN[fn_name]
+                torch_x = torch.Tensor(x)
+                jax_x: jnp.ndarray = jnp.float32(x)
+                torch_y = np.float32(torch_act_fn(torch_x))
+                flax_y = np.float32(flax_act_fn(jax_x))
+                self.assertEqual(torch_y.shape, flax_y.shape)
+                np.testing.assert_allclose(
+                    torch_y,
+                    flax_y,
+                    atol=1e-6,
+                    rtol=1e-6,
+                    err_msg=f"ACT2FN '{fn_name}' of torch and flax are not close enough.",
+                )