Simplify FSDP Gradient Clipping (#2586)
* simplify gradclip

* remove use 2

* Update composer/algorithms/gradient_clipping/gradient_clipping.py

Co-authored-by: Vitaliy Chiley <[email protected]>

---------

Co-authored-by: Vitaliy Chiley <[email protected]>
mvpatel2000 and vchiley authored Oct 2, 2023
1 parent 6375740 commit 67c7819
Showing 1 changed file with 16 additions and 31 deletions.
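
In short, the commit drops the using_torch_2 version check and the try/except loop that probed every FSDP submodule and matched version-specific error strings, and instead asks each FSDP module directly whether it is the root instance. A minimal sketch of the adopted pattern, not the Composer function itself (the helper name clip_fsdp_grad_norm is illustrative, and model is assumed to already be wrapped with FSDP on torch>=1.13):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel

def clip_fsdp_grad_norm(model: torch.nn.Module, clipping_threshold: float) -> None:
    # clip_grad_norm_ may only be called on the root (parent) FSDP instance,
    # so checking check_is_root() up front replaces the old probe-and-catch loop.
    for module in model.modules():
        if isinstance(module, FullyShardedDataParallel) and module.check_is_root():
            module.clip_grad_norm_(max_norm=clipping_threshold)

Because Composer already requires torch>=1.13.0 for FSDP (see the RuntimeError in the diff below), check_is_root() is always available, which is what lets the torch 1.13 vs. 2.0 error-string matching go away.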
47 changes: 16 additions & 31 deletions composer/algorithms/gradient_clipping/gradient_clipping.py
@@ -14,7 +14,6 @@
 from composer.core import Algorithm, Event, State
 from composer.loggers import Logger
 from composer.models import ComposerModel
-from composer.utils import using_torch_2
 
 log = logging.getLogger(__name__)
 
@@ -44,38 +43,24 @@ def apply_gradient_clipping(model: Union[ComposerModel, torch.nn.Module], clippi
             raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
         from torch.distributed.fsdp import FullyShardedDataParallel
 
-        is_torch_2_0 = using_torch_2()
-
         for module in model.modules():
-            if isinstance(module, FullyShardedDataParallel):
-                # We can only call grad clip on the parent instance, so we iterate through all
-                # modules and try grad clipping and FSDP will throw an exception if we
-                # clip any gradients that aren't a parent module
-                try:
-                    if clipping_type == 'norm':
-                        module.clip_grad_norm_(max_norm=clipping_threshold)
-                    elif clipping_type == 'value':
-                        module.clip_grad_norm_(max_norm=clipping_threshold, norm_type=float('inf'))
-                    else:
-                        raise ValueError(f"clipping type must be 'norm' or 'value' with FSDP not {clipping_type}")
-                except (AssertionError, RuntimeError) as e:
-                    if (('clip_grad_norm should only be called on the root (parent) instance' == str(e) and
-                         not is_torch_2_0) or
-                        ('`clip_grad_norm_()` should only be called on the root FSDP instance' == str(e) and
-                         is_torch_2_0)):
-                        continue
-                    else:
-                        raise
-        return
-    parameters = model.parameters()
-    if clipping_type == 'adaptive':
-        _apply_agc(parameters, clipping_threshold=clipping_threshold)
-    elif clipping_type == 'norm':
-        torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
-    elif clipping_type == 'value':
-        torch.nn.utils.clip_grad_value_(parameters, clip_value=clipping_threshold)
+            if isinstance(module, FullyShardedDataParallel) and module.check_is_root():
+                if clipping_type == 'norm':
+                    module.clip_grad_norm_(max_norm=clipping_threshold)
+                elif clipping_type == 'value':
+                    module.clip_grad_norm_(max_norm=clipping_threshold, norm_type=float('inf'))
+                else:
+                    raise ValueError(f"clipping type must be 'norm' or 'value' with FSDP not {clipping_type}")
     else:
-        raise ValueError(f"clipping_type must be 'adaptive', 'norm', or 'value' not {clipping_type} ")
+        parameters = model.parameters()
+        if clipping_type == 'adaptive':
+            _apply_agc(parameters, clipping_threshold=clipping_threshold)
+        elif clipping_type == 'norm':
+            torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
+        elif clipping_type == 'value':
+            torch.nn.utils.clip_grad_value_(parameters, clip_value=clipping_threshold)
+        else:
+            raise ValueError(f"clipping_type must be 'adaptive', 'norm', or 'value' not {clipping_type} ")
 
 
 def _apply_agc(
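
One detail visible on both sides of the diff, and unchanged by this commit: under FSDP, 'value' clipping is implemented as clip_grad_norm_ with norm_type=float('inf'), which rescales every gradient so the largest magnitude does not exceed the threshold, while the non-FSDP path's clip_grad_value_ clamps each element independently. A small standalone illustration in plain PyTorch (not Composer code), with made-up numbers:

import torch

p = torch.nn.Parameter(torch.zeros(3))

# Inf-norm clipping: total_norm = 4.0, so every gradient is scaled by 2.0 / 4.0.
p.grad = torch.tensor([0.5, 1.0, 4.0])
torch.nn.utils.clip_grad_norm_([p], max_norm=2.0, norm_type=float('inf'))
print(p.grad)  # ~tensor([0.2500, 0.5000, 2.0000])

# Element-wise value clipping: only entries outside [-2, 2] are clamped.
p.grad = torch.tensor([0.5, 1.0, 4.0])
torch.nn.utils.clip_grad_value_([p], clip_value=2.0)
print(p.grad)  # tensor([0.5000, 1.0000, 2.0000])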

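For context on how apply_gradient_clipping is reached: the GradientClipping algorithm defined in this same file calls it during training, and users enable it through the Trainer. A hedged usage sketch with a tiny synthetic classifier (the dataset, module, and hyperparameters below are placeholders, not taken from this repository):

import torch
from torch.utils.data import DataLoader, TensorDataset
from composer import Trainer
from composer.algorithms import GradientClipping
from composer.models import ComposerClassifier

# Tiny synthetic classification task so the sketch is self-contained.
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=16)
model = ComposerClassifier(torch.nn.Linear(8, 2), num_classes=2)

trainer = Trainer(
    model=model,
    train_dataloader=dataloader,
    max_duration='1ep',
    # If an fsdp_config were passed to the Trainer, the FSDP branch in the diff
    # above would run; without it, the torch.nn.utils clipping path is used.
    algorithms=[GradientClipping(clipping_type='norm', clipping_threshold=1.0)],
)
trainer.fit()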