-
Notifications
You must be signed in to change notification settings - Fork 118
/
train_mnist.py
102 lines (81 loc) · 2.81 KB
/
train_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
This example shows training of a simple Conv model with MNIST dataset using Auto Training mode of 👟.
"""
import os
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from trainer import TrainerConfig, TrainerModel, Trainer, TrainerArgs
@dataclass
class MnistModelConfig(TrainerConfig):
    """Configuration for the MNIST example run.

    Sets the defaults below on top of ``TrainerConfig``; any setting not
    listed here keeps whatever ``TrainerConfig`` provides.
    """

    # Optimizer selected by name, together with its learning rate.
    optimizer: str = "Adam"
    lr: float = 0.001

    # One epoch only, which keeps the example short.
    epochs: int = 1

    # Step intervals -- presumably for console logging, checkpoint saving and
    # plotting respectively; confirm against the TrainerConfig documentation.
    print_step: int = 1
    save_step: int = 5
    plot_step: int = 5

    # Dashboard/logging backend selected by name.
    dashboard_logger: str = "tensorboard"
class MnistModel(TrainerModel):
    """Small fully-connected MNIST classifier exposing the trainer's model hooks.

    The network is three linear layers (784 -> 128 -> 256 -> 10) with ReLU
    between them and a log-softmax on the output, so its outputs are
    log-probabilities that pair with the ``NLLLoss`` from ``get_criterion``.
    """

    # Number of leading samples kept from each MNIST split so the example
    # finishes quickly.
    _SUBSET_SIZE = 256

    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        """Compute per-class log-probabilities for a batch of images.

        Args:
            x: Batch tensor whose non-batch dimensions multiply to 784,
                e.g. ``(b, 1, 28, 28)``, ``(b, 28, 28)`` or ``(b, 784)``.

        Returns:
            Tensor of shape ``(b, 10)`` holding log-softmax outputs.
        """
        # (b, 1, 28, 28) -> (b, 1*28*28)
        # ``flatten`` keeps the batch dimension and collapses the rest, so the
        # input does not have to be exactly 4-D or contiguous.
        x = x.flatten(start_dim=1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)

    def _outputs_and_loss(self, batch, criterion):
        """Run the forward pass on ``batch`` and score it with ``criterion``."""
        inputs, targets = batch
        log_probs = self(inputs)
        loss = criterion(log_probs, targets)
        return {"model_outputs": log_probs}, {"loss": loss}

    def train_step(self, batch, criterion):
        """Run one training step.

        Args:
            batch: Pair ``(inputs, targets)``.
            criterion: Callable computing the loss from ``(outputs, targets)``.

        Returns:
            Tuple of two dicts: ``{"model_outputs": ...}`` and ``{"loss": ...}``.
        """
        return self._outputs_and_loss(batch, criterion)

    def eval_step(self, batch, criterion):
        """Run one evaluation step; same computation and return value as ``train_step``."""
        return self._outputs_and_loss(batch, criterion)

    @staticmethod
    def get_criterion():
        """Return the loss module: negative log-likelihood over log-probabilities."""
        return torch.nn.NLLLoss()

    def get_data_loader(
        self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
    ):  # pylint: disable=unused-argument
        """Build a ``DataLoader`` over a small slice of MNIST.

        The dataset is downloaded into the current working directory if needed
        and truncated to its first ``_SUBSET_SIZE`` samples.

        Args:
            config: Object providing ``batch_size``.
            assets: Unused.
            is_eval: ``MNIST`` is created with ``train=not is_eval``.
            samples: Unused.
            verbose: Unused.
            num_gpus: Unused.
            rank: Unused.

        Returns:
            ``DataLoader`` yielding ``(image, label)`` batches of ``config.batch_size``.
        """
        # Normalization constants are presumably the MNIST mean/std -- confirm.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)

        # Keep only a small prefix of the split; this is a quick example run.
        dataset.data = dataset.data[: self._SUBSET_SIZE]
        dataset.targets = dataset.targets[: self._SUBSET_SIZE]

        # NOTE(review): no shuffling is requested here, for either split.
        return DataLoader(dataset, batch_size=config.batch_size)
def main():
    """Train the `MNIST` model, starting fresh or continuing from an earlier checkpoint."""
    # Trainer arguments and the example's configuration.
    args = TrainerArgs()
    cfg = MnistModelConfig()

    # The network to be trained.
    mnist_model = MnistModel()

    # Build the training loader first, then the evaluation loader.
    train_loader = mnist_model.get_data_loader(
        config=cfg, assets=None, is_eval=False, samples=None, verbose=None, num_gpus=None
    )
    eval_loader = mnist_model.get_data_loader(
        config=cfg, assets=None, is_eval=True, samples=None, verbose=None, num_gpus=None
    )

    # Hand everything to the trainer and 🚀
    trainer = Trainer(
        args,
        cfg,
        cfg.output_path,
        model=mnist_model,
        train_samples=train_loader,
        eval_samples=eval_loader,
        parse_command_line_args=True,
    )
    trainer.fit()
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()