From 66223b1e9b94c5a9924985e9c6e21aecc874a8b2 Mon Sep 17 00:00:00 2001
From: m-momeni
Date: Wed, 11 Dec 2024 00:30:18 +0330
Subject: [PATCH] Add Flax activation functions and corresponding tests

---
 src/transformers/modeling_flax_utils.py | 24 ++++++++++++++++--------
 tests/test_modeling_flax_common.py      | 35 ++++++++++++++++++++++++++++++++++-
 2 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index dc4a3be732a4f9..2cf7790fbcd436 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -67,18 +67,26 @@
 logger = logging.get_logger(__name__)
 
 
-def quick_gelu(x):
-    return x * jax.nn.sigmoid(1.702 * x)
-
-
 ACT2FN = {
     "gelu": partial(nn.gelu, approximate=False),
+    "gelu_10": lambda x: jnp.clip(nn.gelu(x, approximate=False), -10, 10),
+    "gelu_fast": lambda x: 0.5 * x * (1.0 + nn.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))),
+    "gelu_new": lambda x: 0.5 * x * (1.0 + nn.tanh((2.0 / jnp.pi) ** 0.5 * (x + 0.044715 * jnp.power(x, 3.0)))),
+    "gelu_python": lambda x: x * 0.5 * (1.0 + jax.scipy.special.erf(x / 2.0**0.5)),
+    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
+    "gelu_accurate": lambda x: 0.5 * x * (1 + nn.tanh((2 / jnp.pi) ** 0.5 * (x + 0.044715 * jnp.power(x, 3)))),
+    "laplace": lambda x: 0.5 * (1.0 + jax.scipy.special.erf(jnp.divide(x - 0.707107, 0.282095 * 2.0**0.5))),
+    "leaky_relu": nn.leaky_relu,
+    "linear": lambda x: x,
+    "mish": lambda x: x * nn.tanh(nn.softplus(x)),
+    "quick_gelu": lambda x: x * nn.sigmoid(1.702 * x),
     "relu": nn.relu,
-    "silu": nn.swish,
+    "relu2": lambda x: jnp.square(nn.relu(x)),
+    "relu6": nn.relu6,
+    "sigmoid": nn.sigmoid,
+    "silu": nn.silu,
     "swish": nn.swish,
-    "gelu_new": partial(nn.gelu, approximate=True),
-    "quick_gelu": quick_gelu,
-    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
+    "tanh": nn.tanh,
 }
 
 
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index c7d098be3ea8f2..74ad8654de27b1 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -24,7 +24,7 @@
 import transformers
 from transformers import is_flax_available, is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
+from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, require_torch, torch_device
 from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
 from transformers.utils.generic import ModelOutput
 
@@ -1147,3 +1147,36 @@ def test_gradient_checkpointing(self):
         # ensure that the outputs remain precisely equal
         for output, remat_output in zip(outputs, remat_outputs):
             self.assertTrue((output == remat_output).all())
+
+
+@require_torch
+@require_flax
+def test_activation_fns():
+    # Using the Torch activation functions in `activations.ACT2FN` as the reference, check that the
+    # corresponding Flax implementations produce equal/close results.
+
+    import jax.numpy as jnp
+    import torch
+
+    from transformers.activations import ACT2FN as TORCH_ACT2FN
+    from transformers.modeling_flax_utils import ACT2FN as FLAX_ACT2FN
+
+    limit_left = -10.0
+    limit_right = 10.0
+    x = np.linspace(limit_left, limit_right, 500)
+
+    for fn_name in TORCH_ACT2FN.keys():
+        if fn_name in FLAX_ACT2FN:
+            flax_act_fn = FLAX_ACT2FN[fn_name]
+            torch_act_fn = TORCH_ACT2FN[fn_name]
+            torch_x = torch.tensor(x, dtype=torch.float32)
+            jax_x = jnp.asarray(x, dtype=jnp.float32)
+            torch_y = np.float32(torch_act_fn(torch_x))
+            flax_y = np.float32(flax_act_fn(jax_x))
+            np.testing.assert_allclose(
+                torch_y,
+                flax_y,
+                atol=1e-6,
+                rtol=1e-6,
+                err_msg=f"ACT2FN '{fn_name}' of torch and flax are not close enough.",
+            )
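
For context, a minimal usage sketch (not part of the patch) of how the expanded Flax ACT2FN mapping is typically consumed from a flax.linen module. The FeedForward module, its hidden_size/hidden_act fields, and the input shape below are illustrative assumptions; ACT2FN and the flax.linen APIs are the real ones touched by this change.

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    from transformers.modeling_flax_utils import ACT2FN


    class FeedForward(nn.Module):  # illustrative module, not part of transformers
        hidden_size: int = 64
        hidden_act: str = "gelu_new"  # any key present in ACT2FN

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(self.hidden_size)(x)
            x = ACT2FN[self.hidden_act](x)  # activation looked up by its config-style name
            return nn.Dense(self.hidden_size)(x)


    # quick check that the lookup and forward pass work end to end
    x = jnp.ones((2, 8))
    params = FeedForward().init(jax.random.PRNGKey(0), x)
    y = FeedForward().apply(params, x)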