Skip to content

Commit

Permalink
update lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Jan 17, 2024
1 parent efb1579 commit 1f80c6a
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import numpy as np
import torch
import pytorch_lightning as pl
import lightning as L
import torchmetrics
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, Timer
from ruamel.yaml import YAML
from dvclive import Live
from dvclive.lightning import DVCLiveLogger
Expand Down Expand Up @@ -76,7 +78,7 @@


# Define the model
class LSTMSeqToSeq(pl.LightningModule):
class LSTMSeqToSeq(L.LightningModule):
def __init__(self, latent_dim, optim_params):
super().__init__()
# Log parameters (saves them to self.hparams)
Expand Down Expand Up @@ -163,14 +165,14 @@ def __getitem__(self, idx):

exp = Live("results", save_dvc_exp=True)
live = DVCLiveLogger(report=None, experiment=exp, log_model=True)
checkpoint = pl.callbacks.ModelCheckpoint(
checkpoint = ModelCheckpoint(
dirpath="model",
monitor="val_acc",
mode="max",
save_weights_only=True, every_n_epochs=1)
timer = pl.callbacks.Timer(duration=params["model"]["duration"])
timer = Timer(duration=params["model"]["duration"])

trainer = pl.Trainer(max_epochs=params["model"]["max_epochs"], logger=[live],
trainer = Trainer(max_epochs=params["model"]["max_epochs"], logger=[live],
callbacks=[timer, checkpoint])
trainer.fit(model=arch, train_dataloaders=train_loader,
val_dataloaders=val_loader)

0 comments on commit 1f80c6a

Please sign in to comment.