[NEFTune] Make use of forward hooks instead (#889)
* make use of forward hooks

* correctly delete attributes

* address suggestions
younesbelkada authored Oct 24, 2023
1 parent 1f3314f commit 5b2aeca
Showing 3 changed files with 64 additions and 64 deletions.
60 changes: 30 additions & 30 deletions tests/test_sft_trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import tempfile
import unittest
@@ -581,25 +580,27 @@ def test_sft_trainer_with_model_neftune(self):
packing=True,
)

# inspect input embeddings forward code source
input_embeddings_forward_code_source = inspect.getsource(trainer.model.get_input_embeddings().forward)
trainer.model = trainer._activate_neftune(trainer.model)

self.assertTrue(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()

# training should work fine
trainer.train()
torch.random.manual_seed(42)
embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

# inspect input embeddings forward code source - this time it should not contain any code from NEFTune.
input_embeddings_forward_code_source = inspect.getsource(trainer.model.get_input_embeddings().forward)
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

self.assertFalse(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)

trainer.neftune_hook_handle.remove()

trainer.train()

# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]))
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)

@require_peft
def test_peft_sft_trainer(self):
@@ -673,27 +674,25 @@ def test_peft_sft_trainer_neftune(self):
packing=True,
)

trainer.model = trainer._activate_neftune(trainer.model)

self.assertTrue(isinstance(trainer.model, PeftModel))

# inspect input embeddings forward code source
input_embeddings_forward_code_source = inspect.getsource(
trainer.model.base_model.get_input_embeddings().forward
)
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()

self.assertTrue(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
torch.random.manual_seed(42)
embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

trainer.train()
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

# inspect input embeddings forward code source - this time it should not contain any code from NEFTune.
input_embeddings_forward_code_source = inspect.getsource(
trainer.model.base_model.get_input_embeddings().forward
)
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)

self.assertFalse(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
trainer.neftune_hook_handle.remove()

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
@@ -703,4 +702,5 @@ def test_peft_sft_trainer_neftune(self):
self.assertTrue("pytorch_model.bin" not in os.listdir(tmp_dir + "/checkpoint-2"))

# Make sure forward pass works fine to check if embeddings forward is not broken.
_ = trainer.model(torch.LongTensor([[1, 0, 1]]))
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
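As a side note, a minimal standalone sketch of the property these updated tests rely on: with the hook registered, forward passes under different seeds produce different embeddings in training mode, and removing the handle makes the layer deterministic again. The `Embedding(32, 8)` module and the alpha value are illustrative placeholders, assuming a `trl` install that contains this commit's `neftune_post_forward_hook`.

```python
import torch

from trl.trainer.utils import neftune_post_forward_hook

embed = torch.nn.Embedding(32, 8)      # stand-in for the model's input embeddings
embed.neftune_noise_alpha = 5.0        # arbitrary noise scale, for illustration only
handle = embed.register_forward_hook(neftune_post_forward_hook)

ids = torch.LongTensor([[1, 0, 1]])
embed.train()

torch.random.manual_seed(42)
embeds_a = embed(ids)
torch.random.manual_seed(24)
embeds_b = embed(ids)

assert not torch.allclose(embeds_a, embeds_b)   # noise differs across seeds
assert len(embed._forward_hooks) > 0            # hook is attached

handle.remove()
assert torch.allclose(embed(ids), embed(ids))   # deterministic once the hook is gone
assert len(embed._forward_hooks) == 0
```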
35 changes: 15 additions & 20 deletions trl/trainer/sft_trainer.py
@@ -33,7 +33,12 @@
from transformers.trainer_utils import EvalPrediction

from ..import_utils import is_peft_available
from .utils import ConstantLengthDataset, DataCollatorForCompletionOnlyLM, PeftSavingCallback, neftune_forward
from .utils import (
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
PeftSavingCallback,
neftune_post_forward_hook,
)


if is_peft_available():
@@ -223,9 +228,6 @@ def __init__(
"overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
)

if self.neftune_noise_alpha is not None:
model = self._activate_neftune(model)

super().__init__(
model=model,
args=args,
@@ -250,21 +252,22 @@ def __init__(

@wraps(Trainer.train)
def train(self, *args, **kwargs):
# Activate neftune right before training.
if self.neftune_noise_alpha is not None:
self.model = self._activate_neftune(self.model)

output = super().train(*args, **kwargs)

# After training we make sure to retrieve back the original forward pass method
# for the embedding layer
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:

if isinstance(self.model, PreTrainedModel):
embeddings = self.model.get_input_embeddings()
elif isinstance(self.model, PeftModel):
embeddings = self.model.base_model.get_input_embeddings()

if hasattr(embeddings, "_trl_old_forward"):
embeddings.forward = embeddings._trl_old_forward
del embeddings._trl_old_forward
del embeddings.neftune_noise_alpha
self.neftune_hook_handle.remove()
del embeddings.neftune_noise_alpha

return output

@@ -361,14 +364,6 @@ def _activate_neftune(self, model):
embeddings = model.base_model.get_input_embeddings()

embeddings.neftune_noise_alpha = self.neftune_noise_alpha
old_forward = embeddings.forward

# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
setattr(embeddings, "forward", bound_method)

# embeddings.forward = neftune_forward
embeddings._trl_old_forward = old_forward

hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
self.neftune_hook_handle = hook_handle
return model
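For context, a hedged sketch of what the hook-based approach buys over the previous monkey-patching: `register_forward_hook` returns a `RemovableHandle`, so deactivating NEFTune after training becomes a single `remove()` call instead of restoring a stashed `forward`. The plain `nn.Embedding` and alpha value below are placeholders, not the trainer's actual model.

```python
import torch

from trl.trainer.utils import neftune_post_forward_hook

embed = torch.nn.Embedding(8, 4)

# Previous approach (removed in this commit): rebind forward and remember the original.
#   embed._trl_old_forward = embed.forward
#   embed.forward = neftune_forward.__get__(embed, embed.__class__)
#   ... and later restore: embed.forward = embed._trl_old_forward

# Hook-based approach: leave forward untouched and post-process its output.
embed.neftune_noise_alpha = 0.1
hook_handle = embed.register_forward_hook(neftune_post_forward_hook)

embed.train()
_ = embed(torch.LongTensor([[1, 0, 1]]))   # noisy embeddings while training

hook_handle.remove()                       # one call restores the original behaviour
del embed.neftune_noise_alpha              # mirrors the cleanup in SFTTrainer.train()
```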
33 changes: 19 additions & 14 deletions trl/trainer/utils.py
@@ -739,25 +739,30 @@ def get_stats(self):
return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}


def neftune_forward(self, input: torch.Tensor):
def neftune_post_forward_hook(module, input, output):
"""
Implements the NEFTune forward pass for the model. Note this works only for
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
torch.nn.Embedding layers. This method is slightly adapted from the original source code
that can be found here: https://github.com/neelsjain/NEFTune
Simply add it to your model as follows:
```python
model = ...
model.embed_tokens.neftune_noise_alpha = 0.1
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
```
Args:
module (`torch.nn.Module`):
The embedding module where the hook is attached. Note that you need to set
`module.neftune_noise_alpha` to the desired noise alpha value.
input (`torch.Tensor`):
The input tensor to the model.
noise_alpha (`float`):
The noise alpha value to use for the NEFTune forward pass.
output (`torch.Tensor`):
The output tensor of the model (i.e. the embeddings).
"""
embeddings = torch.nn.functional.embedding(
input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
)

if self.training:
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)

return embeddings
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
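Complementing the docstring above, a small hedged example of the two behaviours the hook implements: it passes the output through unchanged in eval mode, and in train mode it adds uniform noise bounded by `neftune_noise_alpha / sqrt(seq_len * hidden_dim)`. Module sizes and the alpha value are made up.

```python
import torch

from trl.trainer.utils import neftune_post_forward_hook

embed = torch.nn.Embedding(16, 4)
embed.neftune_noise_alpha = 0.1
embed.register_forward_hook(neftune_post_forward_hook)

ids = torch.LongTensor([[1, 0, 1]])        # seq_len = 3, hidden_dim = 4

embed.eval()
clean = embed(ids)                         # eval mode: output is returned unchanged

embed.train()
noisy = embed(ids)                         # train mode: uniform noise is injected

# Same bound as computed inside the hook: alpha / sqrt(seq_len * hidden_dim)
mag_norm = embed.neftune_noise_alpha / torch.sqrt(torch.tensor(3 * 4))
assert not torch.equal(clean, noisy)
assert (noisy - clean).abs().max() <= mag_norm
```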
