diff --git a/config.py b/config.py
index 05820db..25ede73 100644
--- a/config.py
+++ b/config.py
@@ -15,14 +15,12 @@
 training_hyperparams = {
     "epochs": int(105000 / (_multi_30k_train_samples / batch_size)),
-    # "optimizer_params": {
-    #     "warmup_steps": 4000,
-    #     "beta_1": 0.9,
-    #     "beta_2": 0.98,
-    #     "eps": 1e-9
-    # },
     "optimizer_params": {
-        "lr": 1e-3
+        "lr": 1,
+        "betas": (0.9, 0.98),
+        "eps": 1e-9,
+        "warmup_steps": 4000,
+        "d_model": model_hyperparams["model_dim"],
     },
     "checkpoint_path": checkpoint_path
 }
 
diff --git a/trainer.py b/trainer.py
index 376516c..a5f584c 100644
--- a/trainer.py
+++ b/trainer.py
@@ -11,6 +11,7 @@
 import config
 
 from data import create_masks
+from transformer.utils import get_optimizer_and_scheduler
 from translate import parse_tokens
 
 
@@ -36,8 +37,10 @@ def __init__(self, model_dir, model, set_loader, device=None):
         )
 
     def train(self, epochs, optimizer_params, checkpoint_path=None):
-        # todo: add proper optimizer
         criterion = nn.CrossEntropyLoss(ignore_index=self.set_loader.pad_idx)
-        optimizer = optim.Adam(self.model.parameters(), lr=optimizer_params["lr"], betas=(0.9, 0.98), eps=1e-9)
+        optimizer, scheduler = get_optimizer_and_scheduler(
+            model=self.model,
+            **config.training_hyperparams["optimizer_params"]
+        )
 
         print("Start training")
@@ -47,6 +50,7 @@ def train(self, epochs, optimizer_params, checkpoint_path=None):
             checkpoint = torch.load(checkpoint_path)
             self.model.load_state_dict(checkpoint['model_state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
             start_epoch = checkpoint['epoch']
             self.model_dir = Path(checkpoint_path).parent.parent
 
@@ -75,6 +79,7 @@ def train(self, epochs, optimizer_params, checkpoint_path=None):
                 loss = criterion(outputs.reshape(-1, outputs.shape[-1]), tgt.reshape(-1))
                 loss.backward()
                 optimizer.step()
+                scheduler.step()
 
                 running_loss += loss.item()
                 _iters += 1
@@ -82,12 +87,13 @@ def train(self, epochs, optimizer_params, checkpoint_path=None):
 
             avg_loss = running_loss / _iters
             train_status.set_description_str(f"Epoch {epoch}, train loss: {avg_loss}")
+            self.log_lr(epoch, "Train", optimizer)
             self.log_metrics(epoch, "Train", avg_loss)
 
             # Save checkpoint
             save_path = Path(self.model_dir, "checkpoints", f"step_{epoch + 1}")
             save_path.parent.mkdir(parents=True, exist_ok=True)
-            self.save_checkpoint(epoch, save_path, optimizer)
+            self.save_checkpoint(epoch, save_path, optimizer, scheduler)
 
             # Evaluate on validation set if available
             if self.val_loader is not None:
@@ -133,12 +139,13 @@ def calculate_metrics(self, data_loader, loss_fn, temperature=1.):
 
         return running_loss, bleu_res, meteor_average
 
-    def save_checkpoint(self, epoch, path, optimizer):
+    def save_checkpoint(self, epoch, path, optimizer, scheduler):
         if path:
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': self.model.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
+                'scheduler_state_dict': scheduler.state_dict(),
             }, f'{path}_epoch_{epoch}.pt')
 
     @staticmethod
@@ -157,6 +164,9 @@ def log_metrics(self, epoch, phase, loss, bleu=None, meteor=None):
         if meteor is not None:
             self.writer.add_scalar(f'{phase}/METEOR', meteor, epoch)
 
+    def log_lr(self, epoch, phase, optimizer):
+        self.writer.add_scalar(f"{phase}/LR", optimizer.param_groups[0]["lr"], epoch)
+
     def save_model(self, path):
         torch.save(self.model.state_dict(), path)
 
diff --git a/transformer/utils.py b/transformer/utils.py
new file mode 100644
index 0000000..76ab66d
--- /dev/null
+++ b/transformer/utils.py
@@ -0,0 +1,25 @@
+from functools import partial
+
+import torch
+
+
+def transformer_scheduler(step, d_model, warmup_steps):
+    step += 1
+    min_val = min(
+        step ** (-1 / 2),
+        step * warmup_steps ** (-3 / 2)
+    )
+    return d_model ** (-1 / 2) * min_val
+
+
+def get_optimizer_and_scheduler(model, d_model, warmup_steps, lr, betas, eps):
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(
+        optimizer,
+        lr_lambda=partial(
+            transformer_scheduler,
+            d_model=d_model,
+            warmup_steps=warmup_steps
+        )
+    )
+    return optimizer, scheduler
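
For reference, transformer_scheduler implements the warmup schedule from "Attention Is All You Need": lr(step) = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5)). Because LambdaLR multiplies the optimizer's base learning rate by this factor, config.py now sets "lr": 1 so that the factor alone is the effective learning rate. Below is a minimal sketch of how the new helper behaves; the toy Linear layer and d_model=512 are assumptions for illustration only and are not part of the patch.

    import torch
    from transformer.utils import get_optimizer_and_scheduler

    toy_model = torch.nn.Linear(512, 512)  # stand-in for the real Transformer
    optimizer, scheduler = get_optimizer_and_scheduler(
        model=toy_model, d_model=512, warmup_steps=4000,
        lr=1, betas=(0.9, 0.98), eps=1e-9,
    )

    for step in range(1, 8001):
        optimizer.step()   # no gradients needed just to inspect the LR curve
        scheduler.step()
        if step in (100, 4000, 8000):
            print(step, optimizer.param_groups[0]["lr"])
    # LR rises roughly linearly to about 7e-4 at step 4000
    # (512 ** -0.5 * 4000 ** -0.5), then decays proportionally to step ** -0.5.

Saving and restoring scheduler_state_dict alongside the optimizer state keeps the warmup position consistent when training resumes from a checkpoint.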