From c08c0a6574761370a50cb7c22ff1bc6d97c59dd8 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 29 Jul 2024 19:21:08 +0000
Subject: [PATCH 1/5] enabling quick_gelu fn

---
 llmfoundry/models/layers/ffn.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index a28725ee0f..5b7ae86f4d 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -53,6 +53,14 @@
 }
 
 
+def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate.
+    See: https://github.com/hendrycks/GELUs
+    """
+    return input * torch.sigmoid(1.702 * input)
+
+
 def resolve_ffn_act_fn(
     config: Optional[dict] = None,
 ) -> Callable[[torch.Tensor], torch.Tensor]:
@@ -70,10 +78,13 @@ def resolve_ffn_act_fn(
         config = _FFN_ACT_FN_DEFAULT
     config = deepcopy(config)
     name = config.pop('name')
-    if not hasattr(torch.nn.functional, name):
-        raise ValueError(f'Unrecognized activation function name ({name}).')
-    act = getattr(torch.nn.functional, name)
-    return partial(act, **config)
+    if name == 'quick_gelu':
+        return quickgelu_activation
+    else:
+        if not hasattr(torch.nn.functional, name):
+            raise ValueError(f'Unrecognized activation function name ({name}).')
+        act = getattr(torch.nn.functional, name)
+        return partial(act, **config)
 
 
 _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)

From 2de0ffc699aba8f811767cca41a88af5a7a600f5 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 29 Jul 2024 19:46:57 +0000
Subject: [PATCH 2/5] better docformat

---
 llmfoundry/models/layers/ffn.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index 5b7ae86f4d..8028a65a8b 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -54,9 +54,14 @@
 
 
 def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
-    """
-    Applies GELU approximation that is fast but somewhat inaccurate.
-    See: https://github.com/hendrycks/GELUs
+    """Applies GELU approximation that is fast but somewhat inaccurate.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape(*), where * means any
+            number of dimensions
+
+    Returns:
+        torch.Tensor: Tensor with same shape as input tensor
     """
     return input * torch.sigmoid(1.702 * input)
 

From d142aa73bf5f00ce0294fd1024a4e4dc10731630 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 29 Jul 2024 22:17:08 +0000
Subject: [PATCH 3/5] test for act_fn

---
 tests/models/layers/test_ffn.py | 62 +++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 tests/models/layers/test_ffn.py

diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py
new file mode 100644
index 0000000000..4bd28de169
--- /dev/null
+++ b/tests/models/layers/test_ffn.py
@@ -0,0 +1,62 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+
+from llmfoundry.models.layers.layer_builders import build_ffn
+
+
+@pytest.mark.gpu
+def test_quickgelu_activation():
+    d_model = 32
+    expansion_ratio = 1
+    no_bias = True
+    ffn_config={
+        'ffn_act_fn': {
+            'name': 'quick_gelu',
+        },
+        'ffn_type': 'mptmlp',
+    }
+    rank: int = dist.get_rank()
+    device: torch.device = torch.device(f'cuda:{rank}')
+
+    ffn1 = build_ffn(
+        name=ffn_config['ffn_type'],
+        d_model=d_model,
+        expansion_ratio=expansion_ratio,
+        device=device,
+        bias=not no_bias,
+        ffn_kwargs=ffn_config,
+    )
+
+    ffn_config={
+        'ffn_act_fn': {
+            'name': 'gelu',
+        },
+        'ffn_type': 'mptmlp',
+    }
+    ffn2 = build_ffn(
+        name=ffn_config['ffn_type'],
+        d_model=d_model,
+        expansion_ratio=expansion_ratio,
+        device=device,
+        bias=not no_bias,
+        ffn_kwargs=ffn_config,
+    )
+
+    def num_params(model: nn.Module) -> int:
+        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+        return sum([p.numel() for p in model_parameters])
+
+    ffn1_numparams = num_params(ffn1)
+    ffn2_numparams = num_params(ffn2)
+    assert ffn1_numparams == ffn2_numparams, "Only activation paths should have changed, re-check modeling!"
+
+    input_ = torch.rand(1, d_model, device=device)
+    output1 = ffn1(input_)
+    output2 = ffn2(input_)
+    assert output1.numel() == output2.numel(), "Only activation paths should have changed, re-check modeling!"
+    assert not torch.allclose(output1, output2)

From 3c8471d5eaab81cac12e287bc5259a503c125625 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 29 Jul 2024 22:29:56 +0000
Subject: [PATCH 4/5] fix comments

---
 tests/models/layers/test_ffn.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py
index 4bd28de169..d6098bc80c 100644
--- a/tests/models/layers/test_ffn.py
+++ b/tests/models/layers/test_ffn.py
@@ -7,6 +7,7 @@
 import torch.distributed as dist
 
 from llmfoundry.models.layers.layer_builders import build_ffn
+from llmfoundry.models.layers.ffn import quickgelu_activation
 
 
 @pytest.mark.gpu
@@ -31,6 +32,9 @@ def test_quickgelu_activation():
         bias=not no_bias,
         ffn_kwargs=ffn_config,
     )
+    assert (
+        ffn1.act == quickgelu_activation
+    ), f"Expected quick_gelu activation function, got {ffn1.act}"
 
     ffn_config={
         'ffn_act_fn': {
             'name': 'gelu',
         },
         'ffn_type': 'mptmlp',
     }
     ffn2 = build_ffn(
@@ -53,10 +57,16 @@ def num_params(model: nn.Module) -> int:
 
     ffn1_numparams = num_params(ffn1)
     ffn2_numparams = num_params(ffn2)
-    assert ffn1_numparams == ffn2_numparams, "Only activation paths should have changed, re-check modeling!"
+    assert (
+        ffn1_numparams == ffn2_numparams
+    ), "Only activation paths should have changed, re-check modeling!"
 
     input_ = torch.rand(1, d_model, device=device)
     output1 = ffn1(input_)
     output2 = ffn2(input_)
-    assert output1.numel() == output2.numel(), "Only activation paths should have changed, re-check modeling!"
-    assert not torch.allclose(output1, output2)
+    assert (
+        output1.numel() == output2.numel()
+    ), "Only activation paths should have changed, re-check modeling!"
+    assert (
+        not torch.allclose(output1, output2)
+    ), "Functions are different, outputs should not match!"

From 150016484ac4be3b34e873f6bc1e49ce2638c58f Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 29 Jul 2024 23:20:47 +0000
Subject: [PATCH 5/5] changes for pre-commit

---
 tests/models/layers/test_ffn.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py
index d6098bc80c..bb78763f58 100644
--- a/tests/models/layers/test_ffn.py
+++ b/tests/models/layers/test_ffn.py
@@ -3,11 +3,11 @@
 
 import pytest
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 
-from llmfoundry.models.layers.layer_builders import build_ffn
 from llmfoundry.models.layers.ffn import quickgelu_activation
+from llmfoundry.models.layers.layer_builders import build_ffn
 
 
 @pytest.mark.gpu
@@ -15,28 +15,29 @@ def test_quickgelu_activation():
     d_model = 32
     expansion_ratio = 1
     no_bias = True
-    ffn_config={
+    ffn_config = {
         'ffn_act_fn': {
             'name': 'quick_gelu',
         },
         'ffn_type': 'mptmlp',
     }
     rank: int = dist.get_rank()
-    device: torch.device = torch.device(f'cuda:{rank}')
+    device_str = f'cuda:{rank}'
+    device: torch.device = torch.device(device_str)
 
     ffn1 = build_ffn(
         name=ffn_config['ffn_type'],
         d_model=d_model,
         expansion_ratio=expansion_ratio,
-        device=device,
+        device=device_str,
         bias=not no_bias,
         ffn_kwargs=ffn_config,
     )
     assert (
         ffn1.act == quickgelu_activation
-    ), f"Expected quick_gelu activation function, got {ffn1.act}"
+    ), f'Expected quick_gelu activation function, got {ffn1.act}'
 
-    ffn_config={
+    ffn_config = {
         'ffn_act_fn': {
             'name': 'gelu',
         },
         'ffn_type': 'mptmlp',
     }
@@ -46,7 +47,7 @@ def test_quickgelu_activation():
         name=ffn_config['ffn_type'],
         d_model=d_model,
         expansion_ratio=expansion_ratio,
-        device=device,
+        device=device_str,
         bias=not no_bias,
         ffn_kwargs=ffn_config,
     )
@@ -59,14 +60,14 @@ def num_params(model: nn.Module) -> int:
     ffn2_numparams = num_params(ffn2)
     assert (
         ffn1_numparams == ffn2_numparams
-    ), "Only activation paths should have changed, re-check modeling!"
+    ), 'Only activation paths should have changed, re-check modeling!'
 
     input_ = torch.rand(1, d_model, device=device)
     output1 = ffn1(input_)
     output2 = ffn2(input_)
     assert (
         output1.numel() == output2.numel()
-    ), "Only activation paths should have changed, re-check modeling!"
+    ), 'Only activation paths should have changed, re-check modeling!'
     assert (
         not torch.allclose(output1, output2)
-    ), "Functions are different, outputs should not match!"
+    ), 'Functions are different, outputs should not match!'
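
A small self-contained sketch follows (not part of the patch series; it uses only public PyTorch APIs and simply repeats the one-line formula added in patch 1). It shows what quick_gelu computes and how far the sigmoid-based approximation sits from PyTorch's exact GELU:

    import torch
    import torch.nn.functional as F


    def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
        # Same sigmoid-based approximation that patch 1 adds to
        # llmfoundry/models/layers/ffn.py.
        return input * torch.sigmoid(1.702 * input)


    x = torch.linspace(-6.0, 6.0, steps=1001)
    max_abs_err = (quickgelu_activation(x) - F.gelu(x)).abs().max().item()
    print(f'max |quick_gelu - exact gelu| on [-6, 6]: {max_abs_err:.4f}')
    # The gap is small but nonzero (on the order of 0.02 near |x| ~ 2), which is
    # why the new test expects the quick_gelu and gelu FFN outputs not to be allclose.

With llm-foundry installed, the same activation is selected by name: after patch 1, resolve_ffn_act_fn({'name': 'quick_gelu'}) returns quickgelu_activation, and the new test reaches it by passing 'ffn_act_fn': {'name': 'quick_gelu'} inside an 'mptmlp' ffn_config to build_ffn.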