
Commit: Add LR scheduler
gwilczynski95 committed Dec 30, 2023
1 parent 7ef8071 commit 5ef024c
Showing 3 changed files with 44 additions and 10 deletions.
12 changes: 5 additions & 7 deletions config.py
@@ -15,14 +15,12 @@

 training_hyperparams = {
     "epochs": int(105000 / (_multi_30k_train_samples / batch_size)),
-    # "optimizer_params": {
-    # "warmup_steps": 4000,
-    # "beta_1": 0.9,
-    # "beta_2": 0.98,
-    # "eps": 1e-9
-    # },
     "optimizer_params": {
-        "lr": 1e-3
+        "lr": 1,
+        "betas": (0.9, 0.98),
+        "eps": 1e-9,
+        "warmup_steps": 4000,
+        "d_model": model_hyperparams["model_dim"],
     },
     "checkpoint_path": checkpoint_path
 }
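
Note on the new values: "lr" is set to 1 because the LambdaLR scheduler added in transformer/utils.py supplies the effective learning rate as a multiplicative factor on this base value. Below is a standalone sketch of the warmup-then-decay curve these parameters produce, assuming model_hyperparams["model_dim"] is 512; the real value is defined elsewhere in config.py and is not shown in this diff.

# Standalone sketch of the schedule configured above.
# d_model = 512 is an assumption; the repository reads it from model_hyperparams.
d_model, warmup_steps = 512, 4000

def lr_at(step):
    step += 1  # the scheduler shifts steps by one to avoid 0 ** -0.5
    return d_model ** (-1 / 2) * min(step ** (-1 / 2), step * warmup_steps ** (-3 / 2))

for step in (0, 1000, 4000, 20000, 100000):
    print(f"step {step:>6}: lr = {lr_at(step):.2e}")
# Rises roughly linearly to about 7e-4 at step 4000, then decays as step ** -0.5.

With a base lr of 1, the printed values are exactly the learning rates the optimizer uses at those steps.
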
17 changes: 14 additions & 3 deletions trainer.py
@@ -11,6 +11,7 @@

import config
from data import create_masks
+from transformer.utils import get_optimizer_and_scheduler
from translate import parse_tokens


@@ -36,8 +37,11 @@ def __init__(self, model_dir, model, set_loader, device=None):
)

def train(self, epochs, optimizer_params, checkpoint_path=None):
-# todo: add proper optimizer
criterion = nn.CrossEntropyLoss(ignore_index=self.set_loader.pad_idx)
+optimizer, scheduler = get_optimizer_and_scheduler(**{
+    "model": self.model,
+    **config.training_hyperparams["optimizer_params"]
+})
optimizer = optim.Adam(self.model.parameters(), lr=optimizer_params["lr"], betas=(0.9, 0.98), eps=1e-9)
print("Start training")

@@ -47,6 +51,7 @@ def train(self, epochs, optimizer_params, checkpoint_path=None):
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch']

self.model_dir = Path(checkpoint_path).parent.parent
@@ -75,19 +80,21 @@ def train(self, epochs, optimizer_params, checkpoint_path=None):
loss = criterion(outputs.reshape(-1, outputs.shape[-1]), tgt.reshape(-1))
loss.backward()
optimizer.step()
+scheduler.step()

running_loss += loss.item()
_iters += 1

avg_loss = running_loss / _iters
train_status.set_description_str(f"Epoch {epoch}, train loss: {avg_loss}")

+self.log_lr(epoch, "Train", optimizer)
self.log_metrics(epoch, "Train", avg_loss)

# Save checkpoint
save_path = Path(self.model_dir, "checkpoints", f"step_{epoch + 1}")
save_path.parent.mkdir(parents=True, exist_ok=True)
-self.save_checkpoint(epoch, save_path, optimizer)
+self.save_checkpoint(epoch, save_path, optimizer, scheduler)

# Evaluate on validation set if available
if self.val_loader is not None:
@@ -133,12 +140,13 @@ def calculate_metrics(self, data_loader, loss_fn, temperature=1.):

return running_loss, bleu_res, meteor_average

-def save_checkpoint(self, epoch, path, optimizer):
+def save_checkpoint(self, epoch, path, optimizer, scheduler):
if path:
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
+'scheduler_state_dict': scheduler.state_dict(),
}, f'{path}_epoch_{epoch}.pt')

@staticmethod
@@ -157,6 +165,9 @@ def log_metrics(self, epoch, phase, loss, bleu=None, meteor=None):
if meteor is not None:
self.writer.add_scalar(f'{phase}/METEOR', meteor, epoch)

+def log_lr(self, epoch, phase, optimizer):
+    self.writer.add_scalar(f"{phase}/LR", optimizer.param_groups[0]["lr"], epoch)

def save_model(self, path):
torch.save(self.model.state_dict(), path)

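The new 'scheduler_state_dict' checkpoint entry matters when resuming: a freshly built LambdaLR restarts its internal step counter at zero and would replay the warmup. A minimal, self-contained sketch of the save/resume round trip follows; the toy linear model, the lr_lambda, and the file name are placeholders for illustration, not code from this repository. Only the checkpoint keys mirror trainer.py.

import torch
import torch.nn as nn

# Placeholder setup; only the checkpoint keys match trainer.py.
model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: min(1.0, (s + 1) / 4000))

# ... training steps would call optimizer.step() and scheduler.step() here ...

torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, 'toy_checkpoint.pt')

# Resuming restores the scheduler's step counter along with everything else.
checkpoint = torch.load('toy_checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch']
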
25 changes: 25 additions & 0 deletions transformer/utils.py
@@ -0,0 +1,25 @@
from functools import partial

import torch


def transformer_scheduler(step, d_model, warmup_steps):
    step += 1
    min_val = min(
        step ** (-1 / 2),
        step * warmup_steps ** (-3 / 2)
    )
    return d_model ** (-1 / 2) * min_val


def get_optimizer_and_scheduler(model, d_model, warmup_steps, lr, betas, eps):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=partial(
            transformer_scheduler,
            d_model=d_model,
            warmup_steps=warmup_steps
        )
    )
    return optimizer, scheduler
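
transformer_scheduler is the warmup schedule from "Attention Is All You Need", d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5), with steps shifted by one so step 0 is well defined; LambdaLR re-evaluates it on every scheduler.step(). A hedged usage sketch follows: the linear stand-in model and the literal hyperparameter values are illustrative, while in the repository Trainer.train supplies them from config.training_hyperparams.

import torch.nn as nn
from transformer.utils import get_optimizer_and_scheduler

model = nn.Linear(512, 512)  # stand-in for the actual Transformer
optimizer, scheduler = get_optimizer_and_scheduler(
    model=model,
    d_model=512,        # illustrative; the repo passes model_hyperparams["model_dim"]
    warmup_steps=4000,
    lr=1,               # base lr of 1, so the LambdaLR factor is the effective lr
    betas=(0.9, 0.98),
    eps=1e-9,
)

for step in range(3):
    optimizer.step()    # a real loop would backpropagate a loss first
    scheduler.step()    # advance the schedule once per optimizer step
    print(step, scheduler.get_last_lr())
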
