Add WSD scheduler #30231

Merged · 5 commits · Apr 25, 2024
2 changes: 2 additions & 0 deletions docs/source/en/main_classes/optimizer_schedules.md
@@ -66,6 +66,8 @@ The `.optimization` module provides:

[[autodoc]] get_inverse_sqrt_schedule

[[autodoc]] get_wsd_schedule

### Warmup (TensorFlow)

[[autodoc]] WarmUp
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -3911,6 +3911,7 @@
"get_linear_schedule_with_warmup",
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
"get_wsd_schedule",
]
_import_structure["pytorch_utils"] = [
"Conv1D",
@@ -8414,6 +8415,7 @@
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
get_wsd_schedule,
)
from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer

68 changes: 68 additions & 0 deletions src/transformers/optimization.py
@@ -387,6 +387,73 @@ def get_cosine_with_min_lr_schedule_with_warmup(
return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_wsd_scheduler_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_stable_steps: int,
num_decay_steps: int,
num_cycles: float,
min_lr_ratio: float,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
if current_step < num_warmup_steps + num_stable_steps:
return 1.0
if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return (1.0 - min_lr_ratio) * value + min_lr_ratio
return min_lr_ratio
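As an aside, the lambda is a plain piecewise function of the step index, so the three phases are easy to inspect in isolation. A minimal sketch (importing the private helper directly is only for illustration; it is an implementation detail and may change):

from functools import partial

from transformers.optimization import _get_wsd_scheduler_lambda

# 2 warmup steps, 3 stable steps, 4 decay steps, floor at 10% of the peak lr.
wsd = partial(
    _get_wsd_scheduler_lambda,
    num_warmup_steps=2,
    num_stable_steps=3,
    num_decay_steps=4,
    num_cycles=0.5,
    min_lr_ratio=0.1,
)

# Warmup ramps 0.0 -> 0.5, the stable phase holds 1.0, the cosine decay
# falls toward the 0.1 floor, and everything afterwards stays at 0.1.
print([round(wsd(step), 3) for step in range(12)])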


def get_wsd_schedule(
optimizer: Optimizer,
num_warmup_steps: int,
num_stable_steps: int,
num_decay_steps: int,
min_lr_ratio: float = 0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that has three stages:
1. linear increase from 0 to initial lr.
2. constant lr (equal to initial lr).
3. decrease following the values of the cosine function from the initial lr set in the optimizer down to
    a fraction of the initial lr (given by `min_lr_ratio`).

Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_stable_steps (`int`):
The number of steps for the stable phase.
num_decay_steps (`int`):
The number of steps for the cosine annealing phase.
min_lr_ratio (`float`, *optional*, defaults to 0):
The minimum learning rate as a ratio of the initial learning rate.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.

Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_wsd_scheduler_lambda,
num_warmup_steps=num_warmup_steps,
num_stable_steps=num_stable_steps,
num_decay_steps=num_decay_steps,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
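For reference, a short end-to-end usage sketch (the toy model, step counts and learning rates below are made up for illustration):

import torch

from transformers import get_wsd_schedule

model = torch.nn.Linear(4, 4)  # toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 100 warmup steps, 800 steps held at the peak lr, then a 100-step cosine decay down to 1e-4.
scheduler = get_wsd_schedule(
    optimizer,
    num_warmup_steps=100,
    num_stable_steps=800,
    num_decay_steps=100,
    min_lr_ratio=0.1,
)

for _ in range(1000):
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]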


TYPE_TO_SCHEDULER_FUNCTION = {
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
@@ -397,6 +464,7 @@ def get_cosine_with_min_lr_schedule_with_warmup(
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
}
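Because the new entry goes through the shared lookup table (together with the `SchedulerType` member added in trainer_utils.py below), the scheduler can also be resolved from its string name. A quick sketch of that lookup:

from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION
from transformers.trainer_utils import SchedulerType

schedule_func = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType("warmup_stable_decay")]
print(schedule_func.__name__)  # get_wsd_schedule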


1 change: 1 addition & 0 deletions src/transformers/trainer_utils.py
@@ -412,6 +412,7 @@ class SchedulerType(ExplicitEnum):
INVERSE_SQRT = "inverse_sqrt"
REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
COSINE_WITH_MIN_LR = "cosine_with_min_lr"
WARMUP_STABLE_DECAY = "warmup_stable_decay"


class TrainerMemoryTracker:
4 changes: 4 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -10023,6 +10023,10 @@ def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch"])


def get_wsd_schedule(*args, **kwargs):
requires_backends(get_wsd_schedule, ["torch"])


class Conv1D(metaclass=DummyObject):
_backends = ["torch"]

5 changes: 5 additions & 0 deletions tests/optimization/test_optimization.py
@@ -36,6 +36,7 @@
get_inverse_sqrt_schedule,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_wsd_schedule,
)


@@ -150,6 +151,10 @@ def test_schedulers(self):
{"num_warmup_steps": 2},
[0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
),
get_wsd_schedule: (
{"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
[0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
),
}

for scheduler_func, data in scheds.items():
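For reference, the expected values in the new test entry can be reproduced by hand from the schedule's lambda, assuming the table's peak learning rate of 10.0 (which the neighbouring entries suggest). A standalone check:

import math

peak_lr = 10.0  # assumed peak lr used by the test table
warmup, stable, decay = 2, 2, 3
cycles, floor = 0.5, 0.1

def multiplier(step):
    if step < warmup:
        return step / max(1, warmup)
    if step < warmup + stable:
        return 1.0
    if step < warmup + stable + decay:
        progress = (step - warmup - stable) / max(1, decay)
        cosine = max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))
        return (1.0 - floor) * cosine + floor
    return floor

# Prints [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0], matching the expected list above.
print([round(peak_lr * multiplier(step), 3) for step in range(10)])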