Commit 078f1f3: fix GC use_reentrant

younesbelkada committed Oct 18, 2023
1 parent 5a73316
Showing 3 changed files with 36 additions and 3 deletions.
18 changes: 16 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1819,7 +1819,7 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

self.base_model._prune_heads(heads_to_prune)

def gradient_checkpointing_enable(self):
def gradient_checkpointing_enable(self, use_reentrant: bool = True) -> None:
"""
Activates gradient checkpointing for the current model.
@@ -1828,7 +1828,21 @@ def gradient_checkpointing_enable(self):
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))

_supports_use_reentrant = "use_reentrant" in list(
inspect.signature(self._set_gradient_checkpointing).parameters
)
gc_kwargs = {}

if not _supports_use_reentrant and not use_reentrant:
logger.warn(
f"{self.__class__.__name__} does not support the use_reentrant argument. The argument will be ignored."
" Please raise an issue on GitHub to support this argument if needed."
)
elif _supports_use_reentrant and not use_reentrant:
gc_kwargs["use_reentrant"] = use_reentrant

self.apply(partial(self._set_gradient_checkpointing, value=True, **gc_kwargs))

if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
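For context, a minimal usage sketch of the keyword added above. The `facebook/opt-350m` checkpoint is only an illustrative choice; any model whose `_set_gradient_checkpointing` accepts `use_reentrant` (such as OPT after this commit) would behave the same way, while other models fall into the warning branch.

```python
# Minimal sketch: opting into non-reentrant gradient checkpointing
# via the new `use_reentrant` keyword on gradient_checkpointing_enable.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# Default behaviour is unchanged: reentrant checkpointing.
model.gradient_checkpointing_enable()

# Request the non-reentrant variant. If the model's
# _set_gradient_checkpointing does not accept `use_reentrant`,
# the keyword is ignored and a warning is logged instead.
model.gradient_checkpointing_enable(use_reentrant=False)
```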
12 changes: 11 additions & 1 deletion src/transformers/models/opt/modeling_opt.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import inspect
from typing import List, Optional, Tuple, Union

import torch
@@ -411,9 +412,10 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, value=False, use_reentrant=True):
if isinstance(module, (OPTDecoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_use_reentrant = use_reentrant


OPT_INPUTS_DOCSTRING = r"""
@@ -520,6 +522,8 @@ def __init__(self, config: OPTConfig):
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

self.gradient_checkpointing = False
# Use the default value
self.gradient_checkpointing_use_reentrant = True
# Initialize weights and apply final processing
self.post_init()

@@ -699,12 +703,18 @@ def custom_forward(*inputs):

return custom_forward

kwargs = {}

if "use_reentrant" in list(inspect.signature(torch.utils.checkpoint.checkpoint).parameters):
kwargs["use_reentrant"] = self.gradient_checkpointing_use_reentrant

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
**kwargs,
)
else:
layer_outputs = decoder_layer(
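The `inspect.signature` guard in the OPT decoder above keeps the call compatible with PyTorch versions whose `torch.utils.checkpoint.checkpoint` does not yet expose `use_reentrant`. A standalone sketch of that pattern follows; the helper name `checkpoint_compat` is hypothetical and not part of the commit.

```python
# Standalone sketch of the version guard used in the OPT decoder diff:
# forward `use_reentrant` to torch.utils.checkpoint.checkpoint only when
# the installed PyTorch actually accepts that parameter.
import inspect

import torch.utils.checkpoint


def checkpoint_compat(fn, *args, use_reentrant=True):
    kwargs = {}
    if "use_reentrant" in inspect.signature(torch.utils.checkpoint.checkpoint).parameters:
        kwargs["use_reentrant"] = use_reentrant
    return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
```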
9 changes: 9 additions & 0 deletions src/transformers/training_args.py
@@ -572,6 +572,9 @@ class TrainingArguments:
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `True`):
If `False`, use `use_reentrant=False` when calling gradient checkpointing, as recommended by the PyTorch
documentation (this can fix some bugs and unexpected behaviour in distributed training).
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
@@ -1119,6 +1122,10 @@ class TrainingArguments:
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
gradient_checkpointing_use_reentrant: bool = field(
default=True,
metadata={"help": "If False, use `use_reentrant=False` as recommended by the PyTorch documentation."},
)
include_inputs_for_metrics: bool = field(
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
)
@@ -2102,6 +2109,7 @@ def set_training(
gradient_accumulation_steps: int = 1,
seed: int = 42,
gradient_checkpointing: bool = False,
gradient_checkpointing_use_reentrant: bool = True,
):
"""
A method that regroups all basic arguments linked to the training.
@@ -2165,6 +2173,7 @@
self.gradient_accumulation_steps = gradient_accumulation_steps
self.seed = seed
self.gradient_checkpointing = gradient_checkpointing
self.gradient_checkpointing_use_reentrant = gradient_checkpointing_use_reentrant
return self

def set_evaluate(
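Finally, a brief sketch of how the new training argument would be set. The `output_dir` value is a placeholder, and whether the `Trainer` forwards the flag to the model is outside the diff shown here.

```python
# Illustrative sketch: exposing the new flag through TrainingArguments.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="opt-checkpointing-demo",  # placeholder path
    gradient_checkpointing=True,
    # New in this commit: request use_reentrant=False for checkpointing.
    gradient_checkpointing_use_reentrant=False,
)

# The same pair of options via the set_training() helper extended above.
args = args.set_training(
    gradient_checkpointing=True,
    gradient_checkpointing_use_reentrant=False,
)
```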
