From 7a7f6df33f9e9938c4b0e82a753ae93d3b43d8c9 Mon Sep 17 00:00:00 2001 From: Abhay Gupta Date: Mon, 29 Jul 2024 16:40:09 -0700 Subject: [PATCH] Enable QuickGelu Function for CLIP models (#1408) * enabling quick_gelu fn * better docformat * test for act_fn * fix comments * changes for pre-commit --- llmfoundry/models/layers/ffn.py | 24 +++++++++-- tests/models/layers/test_ffn.py | 73 +++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 tests/models/layers/test_ffn.py diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index a28725ee0f..8028a65a8b 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -53,6 +53,19 @@ } +def quickgelu_activation(input: torch.Tensor) -> torch.Tensor: + """Applies GELU approximation that is fast but somewhat inaccurate. + + Args: + input (torch.Tensor): Input tensor of shape(*), where * means any + number of dimensions + + Returns: + torch.Tensor: Tensor with same shape as input tensor + """ + return input * torch.sigmoid(1.702 * input) + + def resolve_ffn_act_fn( config: Optional[dict] = None, ) -> Callable[[torch.Tensor], torch.Tensor]: @@ -70,10 +83,13 @@ def resolve_ffn_act_fn( config = _FFN_ACT_FN_DEFAULT config = deepcopy(config) name = config.pop('name') - if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognized activation function name ({name}).') - act = getattr(torch.nn.functional, name) - return partial(act, **config) + if name == 'quick_gelu': + return quickgelu_activation + else: + if not hasattr(torch.nn.functional, name): + raise ValueError(f'Unrecognized activation function name ({name}).') + act = getattr(torch.nn.functional, name) + return partial(act, **config) _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT) diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py new file mode 100644 index 0000000000..bb78763f58 --- /dev/null +++ b/tests/models/layers/test_ffn.py @@ -0,0 +1,73 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn + +from llmfoundry.models.layers.ffn import quickgelu_activation +from llmfoundry.models.layers.layer_builders import build_ffn + + +@pytest.mark.gpu +def test_quickgelu_activation(): + d_model = 32 + expansion_ratio = 1 + no_bias = True + ffn_config = { + 'ffn_act_fn': { + 'name': 'quick_gelu', + }, + 'ffn_type': 'mptmlp', + } + rank: int = dist.get_rank() + device_str = f'cuda:{rank}' + device: torch.device = torch.device(device_str) + + ffn1 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + assert ( + ffn1.act == quickgelu_activation + ), f'Expected quick_gelu activation function, got {ffn1.act}' + + ffn_config = { + 'ffn_act_fn': { + 'name': 'gelu', + }, + 'ffn_type': 'mptmlp', + } + ffn2 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + + def num_params(model: nn.Module) -> int: + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([p.numel() for p in model_parameters]) + + ffn1_numparams = num_params(ffn1) + ffn2_numparams = num_params(ffn2) + assert ( + ffn1_numparams == ffn2_numparams + ), 'Only activation paths should have changed, re-check modeling!' + + input_ = torch.rand(1, d_model, device=device) + output1 = ffn1(input_) + output2 = ffn2(input_) + assert ( + output1.numel() == output2.numel() + ), 'Only activation paths should have changed, re-check modeling!' + assert ( + not torch.allclose(output1, output2) + ), 'Functions are different, outputs should not match!'