From 7b1170b0faa49a33322415a2f6f8b398fd8fbcc7 Mon Sep 17 00:00:00 2001 From: Alexander Visheratin Date: Thu, 25 Apr 2024 07:07:21 -0400 Subject: [PATCH] Add WSD scheduler (#30231) * Added WSD scheduler. * Added tests. * Fixed errors. * Fix formatting. * CI fixes. --- .../en/main_classes/optimizer_schedules.md | 2 + src/transformers/__init__.py | 2 + src/transformers/optimization.py | 68 +++++++++++++++++++ src/transformers/trainer_utils.py | 1 + src/transformers/utils/dummy_pt_objects.py | 4 ++ tests/optimization/test_optimization.py | 5 ++ 6 files changed, 82 insertions(+) diff --git a/docs/source/en/main_classes/optimizer_schedules.md b/docs/source/en/main_classes/optimizer_schedules.md index dfcab9e91465a3..e75306408f8665 100644 --- a/docs/source/en/main_classes/optimizer_schedules.md +++ b/docs/source/en/main_classes/optimizer_schedules.md @@ -66,6 +66,8 @@ The `.optimization` module provides: [[autodoc]] get_inverse_sqrt_schedule +[[autodoc]] get_wsd_schedule + ### Warmup (TensorFlow) [[autodoc]] WarmUp diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a65ed489d9506b..083c7f031ac6cc 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3911,6 +3911,7 @@ "get_linear_schedule_with_warmup", "get_polynomial_decay_schedule_with_warmup", "get_scheduler", + "get_wsd_schedule", ] _import_structure["pytorch_utils"] = [ "Conv1D", @@ -8414,6 +8415,7 @@ get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, get_scheduler, + get_wsd_schedule, ) from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 1ee2f41d2f91d1..79a2c71c384f30 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -387,6 +387,73 @@ def get_cosine_with_min_lr_schedule_with_warmup( return LambdaLR(optimizer, lr_lambda, last_epoch) +def _get_wsd_scheduler_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_stable_steps: int, + num_decay_steps: int, + num_cycles: float, + min_lr_ratio: float, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + if current_step < num_warmup_steps + num_stable_steps: + return 1.0 + if current_step < num_warmup_steps + num_stable_steps + num_decay_steps: + progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps)) + value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + return (1.0 - min_lr_ratio) * value + min_lr_ratio + return min_lr_ratio + + +def get_wsd_schedule( + optimizer: Optimizer, + num_warmup_steps: int, + num_stable_steps: int, + num_decay_steps: int, + min_lr_ratio: float = 0, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that has three stages: + 1. linear increase from 0 to initial lr. + 2. constant lr (equal to initial lr). + 3. decrease following the values of the cosine function between the initial lr set in the optimizer to + a fraction of initial lr. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_stable_steps (`int`): + The number of steps for the stable phase. + num_decay_steps (`int`): + The number of steps for the cosine annealing phase. + min_lr_ratio (`float`, *optional*, defaults to 0): + The minimum learning rate as a ratio of the initial learning rate. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + lr_lambda = partial( + _get_wsd_scheduler_lambda, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + TYPE_TO_SCHEDULER_FUNCTION = { SchedulerType.LINEAR: get_linear_schedule_with_warmup, SchedulerType.COSINE: get_cosine_schedule_with_warmup, @@ -397,6 +464,7 @@ def get_cosine_with_min_lr_schedule_with_warmup( SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule, SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule, SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup, + SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule, } diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 5c57ce0696f634..d17113091777be 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -412,6 +412,7 @@ class SchedulerType(ExplicitEnum): INVERSE_SQRT = "inverse_sqrt" REDUCE_ON_PLATEAU = "reduce_lr_on_plateau" COSINE_WITH_MIN_LR = "cosine_with_min_lr" + WARMUP_STABLE_DECAY = "warmup_stable_decay" class TrainerMemoryTracker: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 8166c9d24297aa..f91bbbe4fc44c0 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -10023,6 +10023,10 @@ def get_scheduler(*args, **kwargs): requires_backends(get_scheduler, ["torch"]) +def get_wsd_schedule(*args, **kwargs): + requires_backends(get_wsd_schedule, ["torch"]) + + class Conv1D(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/optimization/test_optimization.py b/tests/optimization/test_optimization.py index 0ee8513dacde6a..6d6707db5a4b67 100644 --- a/tests/optimization/test_optimization.py +++ b/tests/optimization/test_optimization.py @@ -36,6 +36,7 @@ get_inverse_sqrt_schedule, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, + get_wsd_schedule, ) @@ -150,6 +151,10 @@ def test_schedulers(self): {"num_warmup_steps": 2}, [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714], ), + get_wsd_schedule: ( + {"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1}, + [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0], + ), } for scheduler_func, data in scheds.items():