Fix formatting.
visheratin committed Apr 13, 2024
1 parent f1c5a1e commit fbc7877
Showing 1 changed file with 10 additions and 2 deletions.
src/transformers/optimization.py (12 changes: 10 additions & 2 deletions)
@@ -386,8 +386,15 @@ def get_cosine_with_min_lr_schedule_with_warmup(
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
+
 def _get_wsd_scheduler_lambda(
-    current_step: int, *, num_warmup_steps: int, num_stable_steps: int, num_decay_steps: int, num_cycles: float, min_lr_ratio: float
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+    num_cycles: float,
+    min_lr_ratio: float,
 ):
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))
@@ -399,14 +406,15 @@ def _get_wsd_scheduler_lambda(
         return (1.0 - min_lr_ratio) * value + min_lr_ratio
     return min_lr_ratio
 
+
 def get_wsd_schedule(
     optimizer: Optimizer,
     num_warmup_steps: int,
     num_stable_steps: int,
     num_decay_steps: int,
     min_lr_ratio: float = 0,
     num_cycles: float = 0.5,
-    last_epoch: int = -1
+    last_epoch: int = -1,
 ):
     """
     Create a schedule with a learning rate that has three stages:
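For readers landing on this commit, here is a minimal usage sketch of the warmup-stable-decay scheduler whose formatting is fixed above. It relies only on the signature visible in this diff and assumes get_wsd_schedule is importable from transformers.optimization (the module this file defines); the model, optimizer, and step counts are illustrative placeholders, not part of the commit.

    # Minimal sketch, not part of this commit: exercise the reformatted scheduler.
    import torch
    from transformers.optimization import get_wsd_schedule

    model = torch.nn.Linear(10, 2)  # toy model, for illustration only
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Warm up for 100 steps, hold the peak learning rate for 800 steps,
    # then decay over 100 steps toward min_lr_ratio * peak learning rate.
    scheduler = get_wsd_schedule(
        optimizer,
        num_warmup_steps=100,
        num_stable_steps=800,
        num_decay_steps=100,
        min_lr_ratio=0.1,
    )

    for _ in range(1000):
        optimizer.step()   # real forward/backward logic would go here
        scheduler.step()   # advance the warmup-stable-decay schedule

As in the other schedulers in this module, the returned object is a LambdaLR, so it is stepped once per optimizer update.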
