
Commit

only change Llama
ArthurZucker committed Dec 11, 2024
1 parent f14637a commit f446bd4
Showing 4 changed files with 267 additions and 292 deletions.
46 changes: 45 additions & 1 deletion src/transformers/modeling_flash_attention_utils.py
@@ -231,7 +231,7 @@ def _flash_attention_forward(
if not use_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.mistral.modeling_mistral.MistralFlashAttention2.__init__.
causal = is_causal and query_length != 1

# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
@@ -336,3 +336,47 @@ class FlashAttentionKwargs(TypedDict, total=False):
cu_seq_lens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]


class TransformersKwargs(TypedDict, total=False):
output_attentions: Optional[bool]
output_hidden_states: Optional[bool]
use_cache: Optional[bool]
return_dict: Optional[bool]



from functools import wraps
from typing import Callable, TypedDict, Optional



def validate_config_kwargs(config):
"""
A decorator to validate and initialize kwargs based on a config object.
"""
def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
# Default values from the config
default_kwargs = {
"output_attentions": config.output_attentions,
"output_hidden_states": config.output_hidden_states,
"use_cache": config.use_cache,
"return_dict": config.use_return_dict,
}

# Merge provided kwargs with defaults
validated_kwargs = {**default_kwargs, **kwargs}

# Validate kwargs against TypedDict
for key in validated_kwargs:
if key not in TransformersKwargs.__annotations__:
raise ValueError(f"Invalid keyword argument: {key}")

# Pass the validated kwargs to the function
return func(*args, **validated_kwargs)

return wrapper

return decorator
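

# Illustrative usage sketch (not part of this diff): `DummyConfig` and the
# decorated `forward` below are made-up names, shown only to demonstrate how
# the decorator fills in defaults from the config and validates the merged
# kwargs against `TransformersKwargs`.
class DummyConfig:
    output_attentions = False
    output_hidden_states = False
    use_cache = True
    use_return_dict = True


@validate_config_kwargs(DummyConfig())
def forward(**kwargs):
    return kwargs


# Unset flags fall back to the config values; unknown keys raise a ValueError.
assert forward(output_attentions=True) == {
    "output_attentions": True,
    "output_hidden_states": False,
    "use_cache": True,
    "return_dict": True,
}
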
206 changes: 120 additions & 86 deletions src/transformers/modeling_utils.py
@@ -2534,92 +2534,6 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

self.base_model._prune_heads(heads_to_prune)

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}

gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)

if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()

def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
is_gradient_checkpointing_set = False

# Apply it on the top-level module in case the top-level modules supports it
# for example, LongT5Stack inherits from `PreTrainedModel`.
if hasattr(self, "gradient_checkpointing"):
self._gradient_checkpointing_func = gradient_checkpointing_func
self.gradient_checkpointing = enable
is_gradient_checkpointing_set = True

for module in self.modules():
if hasattr(module, "gradient_checkpointing"):
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True

if not is_gradient_checkpointing_set:
raise ValueError(
f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
" `gradient_checkpointing` to modules of the model that uses checkpointing."
)

def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self.supports_gradient_checkpointing:
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=False)
else:
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
self.apply(partial(self._set_gradient_checkpointing, value=False))

if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads()

@property
def is_gradient_checkpointing(self) -> bool:
@@ -5568,3 +5482,123 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
files_content[filename].append(device_map[weight_name])

return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]


class GradientCheckpointLayer(torch.nn.Module):
def __call__(self, *args, **kwargs):
"""
Adjust the behavior of the inherited class by overriding `__call__`.
Automatically handles gradient checkpointing based on flags in the provided arguments.
"""
# Extract necessary flags and arguments
gradient_checkpointing = kwargs.pop("gradient_checkpointing", False)
training = self.training

if gradient_checkpointing and training:
# Use gradient checkpointing
return self._apply_gradient_checkpointing(*args, **kwargs)
else:
# Default behavior: call the original `forward` method
return super().__call__(*args, **kwargs)

def _apply_gradient_checkpointing(self, *args, **kwargs):
"""
Apply gradient checkpointing using the appropriate function.
By default, uses `torch.utils.checkpoint.checkpoint`.
"""
# Checkpoint `self.__call__` rather than `forward` so that module hooks are still attached during recomputation
return checkpoint(self.__call__, *args, **kwargs)
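

# Illustrative sketch (not part of this diff): `ToyLayer`, its hidden size and
# the tensor shapes are assumptions, used only to show how the
# `gradient_checkpointing` kwarg routes a training-time call through
# torch.utils.checkpoint while inference keeps the plain forward path.
class ToyLayer(GradientCheckpointLayer):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return torch.nn.functional.gelu(self.proj(hidden_states))


layer = ToyLayer(16).train()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)
checkpointed = layer(hidden_states, gradient_checkpointing=True)  # activations recomputed in backward
plain = layer(hidden_states)  # regular __call__ -> forward
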



def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}

gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)

if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()

def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
is_gradient_checkpointing_set = False

# Apply it on the top-level module in case the top-level modules supports it
# for example, LongT5Stack inherits from `PreTrainedModel`.
if hasattr(self, "gradient_checkpointing"):
self._gradient_checkpointing_func = gradient_checkpointing_func
self.gradient_checkpointing = enable
is_gradient_checkpointing_set = True

for module in self.modules():
if hasattr(module, "gradient_checkpointing"):
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True

if not is_gradient_checkpointing_set:
raise ValueError(
f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
" `gradient_checkpointing` to modules of the model that uses checkpointing."
)

def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self.supports_gradient_checkpointing:
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=False)
else:
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
self.apply(partial(self._set_gradient_checkpointing, value=False))

if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads()
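

# Usage sketch for the enable/disable pair above; the model name is only an
# example, and `gradient_checkpointing_kwargs` is forwarded to
# `torch.utils.checkpoint.checkpoint` as described in the docstring:
#
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("bert-base-uncased")
#     model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
#     # ... training steps ...
#     model.gradient_checkpointing_disable()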


ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {}
