Add Flax activation functions and corresponding tests
m-momeni committed Dec 10, 2024
1 parent 10feacd commit 66223b1
Showing 2 changed files with 50 additions and 9 deletions.
24 changes: 16 additions & 8 deletions src/transformers/modeling_flax_utils.py
@@ -67,18 +67,26 @@
 logger = logging.get_logger(__name__)


-def quick_gelu(x):
-    return x * jax.nn.sigmoid(1.702 * x)
-
-
 ACT2FN = {
     "gelu": partial(nn.gelu, approximate=False),
+    "gelu_10": lambda x: jnp.clip(nn.gelu(x, approximate=False), min=-10, max=10),
+    "gelu_fast": lambda x: 0.5 * x * (1.0 + nn.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))),
+    "gelu_new": lambda x: 0.5 * x * (1.0 + nn.tanh((2.0 / jnp.pi) ** 0.5 * (x + 0.044715 * jnp.pow(x, 3.0)))),
+    "gelu_python": lambda x: x * 0.5 * (1.0 + jax.scipy.special.erf(x / 2.0**0.5)),
+    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
+    "gelu_accurate": lambda x: 0.5 * x * (1 + nn.tanh((2 / jnp.pi) ** 0.5 * (x + 0.044715 * jnp.pow(x, 3)))),
+    "laplace": lambda x: 0.5 * (1.0 + jax.scipy.special.erf(jnp.divide(x - 0.707107, 0.282095 * 2.0**0.5))),
+    "leaky_relu": nn.leaky_relu,
+    "linear": lambda x: x,
+    "mish": lambda x: x * nn.tanh(nn.softplus(x)),
+    "quick_gelu": lambda x: x * nn.sigmoid(1.702 * x),
     "relu": nn.relu,
-    "silu": nn.swish,
+    "relu2": lambda x: jnp.square(nn.relu(x)),
+    "relu6": nn.relu6,
+    "sigmoid": nn.sigmoid,
+    "silu": nn.silu,
     "swish": nn.swish,
-    "gelu_new": partial(nn.gelu, approximate=True),
-    "quick_gelu": quick_gelu,
-    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
     "tanh": nn.tanh,
 }


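For context, a minimal sketch of how entries in this registry are typically consumed by Flax model code, resolving an activation by its configured name. The `FlaxToyMLP` module and its `hidden_act` field below are hypothetical and not part of this commit:

import jax
import jax.numpy as jnp
import flax.linen as nn

from transformers.modeling_flax_utils import ACT2FN


class FlaxToyMLP(nn.Module):
    # Hypothetical module, for illustration only: the activation is looked up
    # by name in ACT2FN, mirroring how transformers Flax models use the registry.
    hidden_size: int = 8
    hidden_act: str = "gelu_new"  # any key present in ACT2FN

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(self.hidden_size)(x)
        return ACT2FN[self.hidden_act](h)


model = FlaxToyMLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
out = model.apply(params, jnp.ones((1, 4)))
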
35 changes: 34 additions & 1 deletion tests/test_modeling_flax_common.py
@@ -24,7 +24,7 @@
 import transformers
 from transformers import is_flax_available, is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
+from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, require_torch, torch_device
 from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
 from transformers.utils.generic import ModelOutput

@@ -1147,3 +1147,36 @@ def test_gradient_checkpointing(self):
             # ensure that the outputs remain precisely equal
             for output, remat_output in zip(outputs, remat_outputs):
                 self.assertTrue((output == remat_output).all())
+
+
+@require_torch
+@require_flax
+def test_activation_fns():
+    # Using the Torch activation functions in `activations.ACT2FN` as the reference,
+    # check that the Flax implementations produce equal/close results.
+
+    import jax.numpy as jnp
+    import torch
+
+    from transformers.activations import ACT2FN as TORCH_ACT2FN
+    from transformers.modeling_flax_utils import ACT2FN as FLAX_ACT2FN
+
+    limit_left = -10.0
+    limit_right = 10.0
+    x = np.linspace(limit_left, limit_right, 500)
+
+    for fn_name in TORCH_ACT2FN.keys():
+        if fn_name in FLAX_ACT2FN:
+            flax_act_fn = FLAX_ACT2FN[fn_name]
+            torch_act_fn = TORCH_ACT2FN[fn_name]
+            torch_x = torch.Tensor(x)
+            jax_x: jnp.ndarray = jnp.float32(x)
+            torch_y = np.float32(torch_act_fn(torch_x))
+            flax_y = np.float32(flax_act_fn(jax_x))
+            np.testing.assert_allclose(
+                torch_y,
+                flax_y,
+                atol=1e-6,
+                rtol=1e-6,
+                err_msg=f"ACT2FN '{fn_name}' of torch and flax are not close enough.",
+            )
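
The new test can be selected on its own with `pytest tests/test_modeling_flax_common.py -k test_activation_fns`; it is skipped unless both torch and flax are installed. Below is a standalone sketch, not part of the commit, that mirrors the test's comparison for a single shared key, "quick_gelu":

# Standalone spot check (illustrative only): compare one torch/flax pair
# directly, the same way the test loop does for every shared key.
import numpy as np
import torch
import jax.numpy as jnp

from transformers.activations import ACT2FN as TORCH_ACT2FN
from transformers.modeling_flax_utils import ACT2FN as FLAX_ACT2FN

x = np.linspace(-10.0, 10.0, 500).astype(np.float32)
torch_y = TORCH_ACT2FN["quick_gelu"](torch.from_numpy(x)).numpy()
flax_y = np.asarray(FLAX_ACT2FN["quick_gelu"](jnp.asarray(x)))
np.testing.assert_allclose(torch_y, flax_y, atol=1e-6, rtol=1e-6)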
