add warmup for soundstream as well as all discriminators
lucidrains committed Dec 11, 2023
1 parent 31b0b7e commit 01f0008
Showing 3 changed files with 86 additions and 6 deletions.
89 changes: 84 additions & 5 deletions audiolm_pytorch/trainer.py
@@ -9,7 +9,7 @@
 from collections import Counter
 from contextlib import contextmanager, nullcontext

-from beartype.typing import Union, List, Optional, Tuple
+from beartype.typing import Union, List, Optional, Tuple, Type
 from typing_extensions import Annotated

 from beartype import beartype
@@ -19,8 +19,12 @@
 import torch
 import torchaudio
 from torch import nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
 from torch.utils.data import Dataset, DataLoader, random_split

+import pytorch_warmup as warmup
+
 from einops import rearrange

 from audiolm_pytorch.optimizer import get_optimizer
@@ -55,6 +59,8 @@

 DEFAULT_SAMPLE_RATE = 16000

+ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)
+
 # make sure only one trainer is instantiated

 ONE_TRAINER_INSTANTIATED = False
@@ -152,6 +158,53 @@ def checkpoint_num_steps(checkpoint_path):

     return int(results[-1])

+# optimizer with scheduler + warmup
+
+class OptimizerWithWarmupSchedule(nn.Module):
+    @beartype
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        optimizer: Optimizer,
+        scheduler: Optional[Type[_LRScheduler]] = None,
+        scheduler_kwargs: dict = dict(),
+        warmup_steps: int = 0
+    ):
+        super().__init__()
+        self.warmup = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps)
+
+        if exists(scheduler):
+            self.scheduler = scheduler(optimizer, **scheduler_kwargs)
+        else:
+            self.scheduler = ConstantLRScheduler(optimizer)
+
+        self.optimizer = optimizer
+
+        self.optimizer, self.scheduler = accelerator.prepare(self.optimizer, self.scheduler)
+        self.accelerator = accelerator
+
+    def state_dict(self):
+        return dict(
+            optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
+            warmup = self.warmup.state_dict()
+        )
+
+    def load_state_dict(self, pkg):
+        self.optimizer.load_state_dict(pkg['optimizer'])
+        self.scheduler.load_state_dict(pkg['scheduler'])
+        self.warmup.load_state_dict(pkg['warmup'])
+
+    def zero_grad(self):
+        self.optimizer.zero_grad()
+
+    def step(self):
+        self.optimizer.step()
+
+        if not self.accelerator.optimizer_step_was_skipped:
+            with self.warmup.dampening():
+                self.scheduler.step()
+
 # main trainer class

 class SoundStreamTrainer(nn.Module):
@@ -172,6 +225,12 @@ def __init__(
         lr: float = 2e-4,
         grad_accum_every: int = 4,
         wd: float = 0.,
+        warmup_steps: int = 1000,
+        scheduler: Optional[Type[_LRScheduler]] = None,
+        scheduler_kwargs: dict = dict(),
+        discr_warmup_steps: Optional[int] = None,
+        discr_scheduler: Optional[Type[_LRScheduler]] = None,
+        discr_scheduler_kwargs: dict = dict(),
         max_grad_norm: float = 0.5,
         discr_max_grad_norm: float = None,
         save_results_every: int = 100,
@@ -240,13 +299,33 @@ def __init__(


         # optimizers

-        self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)
+        self.optim = OptimizerWithWarmupSchedule(
+            self.accelerator,
+            get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd),
+            scheduler = scheduler,
+            scheduler_kwargs = scheduler_kwargs,
+            warmup_steps = warmup_steps
+        )

+        discr_warmup_steps = default(discr_warmup_steps, warmup_steps)
+
         for discr_optimizer_key, discr in self.multiscale_discriminator_iter():
-            one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd)
+            one_multiscale_discr_optimizer = OptimizerWithWarmupSchedule(
+                self.accelerator,
+                get_optimizer(discr.parameters(), lr = lr, wd = wd),
+                scheduler = discr_scheduler,
+                scheduler_kwargs = discr_scheduler_kwargs,
+                warmup_steps = discr_warmup_steps
+            )
             setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer)

-        self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd)
+        self.discr_optim = OptimizerWithWarmupSchedule(
+            self.accelerator,
+            get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd),
+            scheduler = discr_scheduler,
+            scheduler_kwargs = discr_scheduler_kwargs,
+            warmup_steps = discr_warmup_steps
+        )

         # max grad norm

@@ -596,6 +675,7 @@ def train_step(self):

             for model, label in models:
                 model.eval()
+                model = model.to(device)

                 with torch.inference_mode():
                     recons = model(wave, return_recons_only = True)
Expand Down Expand Up @@ -1064,7 +1144,6 @@ def load(self, path):
# + 1 to start from the next step and avoid overwriting the last checkpoint
self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)


def print(self, msg):
self.accelerator.print(msg)

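Usage note: with these changes, SoundStreamTrainer accepts warmup and scheduler settings for the soundstream (generator) optimizer and, through the discr_* arguments, for the STFT and multiscale discriminator optimizers as well. A minimal sketch of how the new keyword arguments might be passed; the SoundStream hyperparameters, dataset folder, batch size and the choice of CosineAnnealingLR are illustrative assumptions, not part of this commit:

from torch.optim.lr_scheduler import CosineAnnealingLR

from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/audio',                         # assumed dataset folder
    batch_size = 4,
    grad_accum_every = 8,
    data_max_length_seconds = 2,
    num_train_steps = 1_000_000,
    warmup_steps = 1000,                               # new: linear warmup for the soundstream optimizer
    scheduler = CosineAnnealingLR,                     # new: optional scheduler class, stepped after warmup
    scheduler_kwargs = dict(T_max = 1_000_000),
    discr_warmup_steps = 1000,                         # new: falls back to warmup_steps when omitted
    discr_scheduler = CosineAnnealingLR,               # new: shared by the STFT and multiscale discriminators
    discr_scheduler_kwargs = dict(T_max = 1_000_000)
)

trainer.train()

If scheduler is left as None, the new ConstantLRScheduler keeps the learning rate constant once the linear warmup (warmup_steps, default 1000) has completed.
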
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.8.7'
+__version__ = '1.9.0'
1 change: 1 addition & 0 deletions setup.py
@@ -28,6 +28,7 @@
     'gateloop-transformer>=0.0.24',
     'joblib',
     'local-attention>=1.9.0',
+    'pytorch-warmup',
     'scikit-learn',
     'sentencepiece',
     'torch>=1.12',
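The new pytorch-warmup dependency provides the LinearWarmup helper that OptimizerWithWarmupSchedule wraps. Stripped of the Accelerate plumbing, the underlying pattern is roughly the sketch below (the linear model, random loss and step counts are placeholders):

import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_warmup as warmup

model = nn.Linear(16, 1)                                       # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-4)
scheduler = CosineAnnealingLR(optimizer, T_max = 10_000)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 1000)

for step in range(10_000):
    loss = model(torch.randn(4, 16)).mean()                    # placeholder loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # dampening() scales the learning rate down during the first warmup_period steps,
    # after which the wrapped scheduler takes over, matching the ordering used in
    # OptimizerWithWarmupSchedule.step
    with warmup_scheduler.dampening():
        scheduler.step()

In OptimizerWithWarmupSchedule.step above, the scheduler update is additionally skipped whenever Accelerate reports that the optimizer step was skipped.
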
