# training_functions.py -- forked from rajatvd/NeuralODE
"""
Train on batch and other functions for training a ODEnet on MNIST
"""
import torch
from torch import nn
from torch import optim
from tqdm import tqdm
import logging
import pytorch_utils.sacred_trainer as st
def train_on_batch(model, batch, optimizer):
    """One training step on a batch of MNIST data. Uses CrossEntropyLoss.

    Parameters
    ----------
    model : nn.Module
        Model for MNIST classification.
    batch : tuple
        Tuple of images and labels.
    optimizer : torch Optimizer
        Optimizer used to update the model's parameters.

    Returns
    -------
    tuple: loss, accuracy, nfe_forward, nfe_backward
        Loss and accuracy are numpy values; nfe_forward and nfe_backward are
        the number of function evaluations the ODE solver made in the forward
        and backward passes.
    """
    # Unwrap DataParallel to reach the underlying ODE model.
    if isinstance(model, nn.DataParallel):
        ode_model = model.module
    else:
        ode_model = model

    criterion = nn.CrossEntropyLoss()
    images, labels = batch

    # Forward pass, counting the solver's function evaluations (NFE).
    ode_model.odefunc.nfe.fill_(0)
    outputs = model(images)
    nfe_forward = ode_model.odefunc.nfe.item()

    loss = criterion(outputs, labels)

    # Backward pass and optimizer step, counting NFE again.
    ode_model.odefunc.nfe.fill_(0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    nfe_backward = ode_model.odefunc.nfe.item()

    loss = loss.cpu().detach().numpy()
    acc = st.accuracy(outputs.cpu(), labels.cpu())
    return loss, acc, nfe_forward, nfe_backward
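

# train_on_batch above assumes the model exposes `odefunc.nfe`: a scalar
# tensor that the ODE function increments on every solver evaluation, so it
# supports .fill_(0) and .item(). A minimal sketch of that interface
# (hypothetical -- the real ODE function is defined elsewhere in this repo,
# and the layer sizes here are made up):
class _ExampleODEFunc(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        # A buffer, so .fill_(0) and .item() work as used in train_on_batch.
        self.register_buffer('nfe', torch.tensor(0.))
        self.net = nn.Sequential(nn.Linear(dim, dim),
                                 nn.Tanh(),
                                 nn.Linear(dim, dim))

    def forward(self, t, x):
        self.nfe += 1  # one increment per solver call
        return self.net(x)

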
def validate(model, val_loader, _log=logging.getLogger('validate')):
    """Find loss and accuracy on the given dataloader using the model.

    Parameters
    ----------
    model : nn.Module
        Model for MNIST classification.
    val_loader : DataLoader
        Data over which to validate.

    Returns
    -------
    tuple: val_loss, accuracy
        Both are numpy values.
    """
    model = model.eval()
    criterion = nn.CrossEntropyLoss()
    val_loss = 0
    accuracy = 0
    total = 0
    _log.info(f"Running validate with {len(val_loader)} steps")
    for images, labels in tqdm(val_loader):
        with torch.no_grad():
            batch_size = images.shape[0]

            # Forward pass only; weight the running sums by batch size so the
            # final averages are correct even if the last batch is smaller.
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss * batch_size
            accuracy += st.accuracy(outputs, labels) * batch_size
            total += batch_size

    val_loss /= total
    accuracy /= total
    model = model.train()
    return val_loss.cpu().numpy(), accuracy
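

# Illustrative composition of the two functions above into one epoch (a
# sketch only -- this repo drives training through pytorch_utils.sacred_trainer,
# which supplies the loaders, optimizer and metric handling):
#
#     for batch in train_loader:
#         loss, acc, nfe_f, nfe_b = train_on_batch(model, batch, optimizer)
#     val_loss, val_acc = validate(model, val_loader)

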
def scheduler_generator(optimizer, milestones, gamma):
    """A generator which performs lr scheduling on the given optimizer using
    a MultiStepLR scheduler with the given milestones and gamma."""
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones,
                                               gamma)
    while True:
        scheduler.step()
        yield (optimizer.param_groups[0]['lr'],)  # yield the current lr
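

# A quick trace of the generator (example values, not the repo's actual
# config): with an initial lr of 1e-3, milestones=[2, 4] and gamma=0.1,
# pulling one value per epoch decays the lr after the 2nd and 4th steps:
#
#     opt = optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
#     gen = scheduler_generator(opt, milestones=[2, 4], gamma=0.1)
#     [next(gen)[0] for _ in range(5)]
#     # -> [0.001, 0.0001, 0.0001, 1e-05, 1e-05]

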
def create_scheduler_callback(optimizer, milestones, gamma):
    """Returns a function which can be used as a callback for lr scheduling
    on the given optimizer using a MultiStepLR scheduler with the given
    milestones and gamma."""
    g = scheduler_generator(optimizer, milestones, gamma)

    def scheduler_callback(model, val_loader, batch_metrics_dict):
        """LR scheduler callback using the next value of a
        scheduler_generator."""
        return next(g)
    return scheduler_callback
def create_val_scheduler_callback(optimizer, milestones, gamma):
    """Returns a function which can be used as a callback for lr scheduling
    on the given optimizer using a MultiStepLR scheduler with the given
    milestones and gamma.

    It also computes loss and accuracy on the validation data loader.
    """
    g = scheduler_generator(optimizer, milestones, gamma)

    def scheduler_callback(model, val_loader, batch_metrics_dict):
        """LR scheduler callback using the next value of a
        scheduler_generator. Also runs validation on val_loader."""
        val_loss, val_accuracy = validate(model, val_loader)
        lr = next(g)
        return val_loss, val_accuracy, lr[0]
    return scheduler_callback
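

# Sketch of how the callback is wired in (hypothetical names -- the actual
# driver lives in pytorch_utils.sacred_trainer, and model/loader construction
# lives in the experiment config):
#
#     optimizer = optim.Adam(model.parameters(), lr=1e-3)
#     callback = create_val_scheduler_callback(optimizer,
#                                              milestones=[60, 100],
#                                              gamma=0.1)
#     # once per epoch, after the training loop:
#     val_loss, val_acc, lr = callback(model, val_loader, batch_metrics_dict={})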