Enable QuickGelu Function for CLIP models (#1408)
* enabling quick_gelu fn

* better docformat

* test for act_fn

* fix comments

* changes for pre-commit
gupta-abhay authored Jul 29, 2024
1 parent 6f4aa8c commit 7a7f6df
Showing 2 changed files with 93 additions and 4 deletions.
24 changes: 20 additions & 4 deletions llmfoundry/models/layers/ffn.py
@@ -53,6 +53,19 @@
}


def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
    """Applies GELU approximation that is fast but somewhat inaccurate.

    Args:
        input (torch.Tensor): Input tensor of shape(*), where * means any
            number of dimensions

    Returns:
        torch.Tensor: Tensor with same shape as input tensor
    """
    return input * torch.sigmoid(1.702 * input)


def resolve_ffn_act_fn(
    config: Optional[dict] = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
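
An aside on the docstring's "fast but somewhat inaccurate" wording: the new activation is the sigmoid-based quick GELU, input * torch.sigmoid(1.702 * input). A minimal sketch (not part of this commit) comparing it against PyTorch's exact erf-based GELU:

import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=101)
quick = x * torch.sigmoid(1.702 * x)  # same formula as quickgelu_activation
exact = F.gelu(x)                     # exact (erf-based) GELU
print((quick - exact).abs().max())    # roughly 0.02, with the largest gap near |x| ~ 2

The sigmoid form trades the erf call for an exponential, which is the source of both the speed and the small inaccuracy noted in the docstring.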
@@ -70,10 +83,13 @@ def resolve_ffn_act_fn(
        config = _FFN_ACT_FN_DEFAULT
    config = deepcopy(config)
    name = config.pop('name')
-    if not hasattr(torch.nn.functional, name):
-        raise ValueError(f'Unrecognized activation function name ({name}).')
-    act = getattr(torch.nn.functional, name)
-    return partial(act, **config)
+    if name == 'quick_gelu':
+        return quickgelu_activation
+    else:
+        if not hasattr(torch.nn.functional, name):
+            raise ValueError(f'Unrecognized activation function name ({name}).')
+        act = getattr(torch.nn.functional, name)
+        return partial(act, **config)


_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
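The effect of the dispatch change above, sketched as a usage example (illustrative only; it assumes resolve_ffn_act_fn is importable from the same module as quickgelu_activation, as the diff suggests): a config naming 'quick_gelu' returns the new function directly, while any other name is still looked up on torch.nn.functional and partially applied with the remaining config kwargs.

import torch
from llmfoundry.models.layers.ffn import quickgelu_activation, resolve_ffn_act_fn

# 'quick_gelu' short-circuits to the new activation; no extra kwargs are forwarded.
act = resolve_ffn_act_fn({'name': 'quick_gelu'})
assert act is quickgelu_activation

# Any other name still resolves through torch.nn.functional, e.g. gelu's tanh variant.
tanh_gelu = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})

x = torch.randn(2, 8)
print((act(x) - tanh_gelu(x)).abs().max())  # small but non-zero: two different GELU approximations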
73 changes: 73 additions & 0 deletions tests/models/layers/test_ffn.py
@@ -0,0 +1,73 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn

from llmfoundry.models.layers.ffn import quickgelu_activation
from llmfoundry.models.layers.layer_builders import build_ffn


@pytest.mark.gpu
def test_quickgelu_activation():
    d_model = 32
    expansion_ratio = 1
    no_bias = True
    ffn_config = {
        'ffn_act_fn': {
            'name': 'quick_gelu',
        },
        'ffn_type': 'mptmlp',
    }
    rank: int = dist.get_rank()
    device_str = f'cuda:{rank}'
    device: torch.device = torch.device(device_str)

    ffn1 = build_ffn(
        name=ffn_config['ffn_type'],
        d_model=d_model,
        expansion_ratio=expansion_ratio,
        device=device_str,
        bias=not no_bias,
        ffn_kwargs=ffn_config,
    )
    assert (
        ffn1.act == quickgelu_activation
    ), f'Expected quick_gelu activation function, got {ffn1.act}'

    ffn_config = {
        'ffn_act_fn': {
            'name': 'gelu',
        },
        'ffn_type': 'mptmlp',
    }
    ffn2 = build_ffn(
        name=ffn_config['ffn_type'],
        d_model=d_model,
        expansion_ratio=expansion_ratio,
        device=device_str,
        bias=not no_bias,
        ffn_kwargs=ffn_config,
    )

    def num_params(model: nn.Module) -> int:
        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
        return sum([p.numel() for p in model_parameters])

    ffn1_numparams = num_params(ffn1)
    ffn2_numparams = num_params(ffn2)
    assert (
        ffn1_numparams == ffn2_numparams
    ), 'Only activation paths should have changed, re-check modeling!'

    input_ = torch.rand(1, d_model, device=device)
    output1 = ffn1(input_)
    output2 = ffn2(input_)
    assert (
        output1.numel() == output2.numel()
    ), 'Only activation paths should have changed, re-check modeling!'
    assert (
        not torch.allclose(output1, output2)
    ), 'Functions are different, outputs should not match!'
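
A hedged note on running the new test (not part of the commit): it is GPU-marked and derives its device from dist.get_rank(), so it assumes the repository's GPU/distributed pytest harness is available; with that in place, a marker-based invocation along the lines of pytest -m gpu tests/models/layers/test_ffn.py -k test_quickgelu_activation should select it.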
