# ================================================================ #
# Optimization #
# ================================================================ #
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.onnx as onnx
import torchvision.models as models
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
train_loader = DataLoader(training_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
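# A quick sanity check (illustrative, not part of the original script):
# FashionMNIST images are 1x28x28, so with batch_size=64 this should print
# torch.Size([64, 1, 28, 28]) for images and torch.Size([64]) for labels.
sample_images, sample_labels = next(iter(train_loader))
print(f"Batch images: {sample_images.shape}, labels: {sample_labels.shape}")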
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
            # No ReLU on the output layer: nn.CrossEntropyLoss expects raw logits
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = NeuralNetwork()
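# Illustrative shape check (not part of the original script): a dummy batch
# through the untrained model should yield one 10-way logit vector per sample.
with torch.no_grad():
    dummy_logits = model(torch.rand(1, 1, 28, 28))
print(dummy_logits.shape)  # torch.Size([1, 10])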
# Hyperparameters
learning_rate = 1e-3
batch_size = 64
epochs = 5
# Initialize the loss function
# nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.
loss_fn = nn.CrossEntropyLoss()
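# Illustrative only (not part of the original script): verify that
# CrossEntropyLoss on raw logits matches LogSoftmax followed by NLLLoss.
_logits = torch.randn(3, 10)
_targets = torch.tensor([0, 4, 9])
_combined = loss_fn(_logits, _targets)
_two_step = nn.NLLLoss()(nn.LogSoftmax(dim=1)(_logits), _targets)
assert torch.isclose(_combined, _two_step)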
# ================================================================ #
# Optimizer #
# ================================================================ #
# Stochastic gradient descent over all trainable parameters
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
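# Other optimizers (Adam, RMSprop, etc.) share the same interface; an
# illustrative alternative, not part of the original script:
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)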
# ================================================================ #
# Full Implementation #
# ================================================================ #
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (images, labels) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(images)
        loss = loss_fn(pred, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(images)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            pred = model(images)
            test_loss += loss_fn(pred, labels).item()
            correct += (pred.argmax(1) == labels).type(torch.float).sum().item()
    # loss_fn already averages within each batch, so divide by the number of
    # batches; accuracy averages over individual samples
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# ================================================================ #
# Training #
# ================================================================ #
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train_loop(train_loader, model, loss_fn, optimizer)
    test_loop(test_loader, model, loss_fn)
print("Done!")
# ================================================================ #
# Save and Load the Model #
# ================================================================ #
# Note: recent torchvision versions replace pretrained=True with
# weights=models.VGG16_Weights.IMAGENET1K_V1
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')  # save only the weights
''' To load model weights, you need to create an instance of the same model
first, and then load the parameters using the load_state_dict() method. '''
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.vgg16()  # create the model structure; weights are not loaded yet
# map_location controls which device (cpu or gpu) the weights are loaded onto
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
model.eval()
''' Be sure to call the model.eval() method before inference to set the dropout
and batch normalization layers to evaluation mode. Failing to do this will
yield inconsistent inference results.
'''
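# Illustrative inference sketch (the random tensor below is a stand-in for a
# real preprocessed 224x224 RGB image; not part of the original script):
with torch.no_grad():
    dummy_input = torch.rand(1, 3, 224, 224)
    probabilities = torch.softmax(model(dummy_input), dim=1)
print(probabilities.argmax(1))  # index of the most likely ImageNet class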
# ================================================================ #
# Save and Load the Entire Model #
# ================================================================ #
# Save the entire model (structure and weights) using pickle
torch.save(model, 'model.pth')
# Load it back; the class does not need to be re-instantiated first, but its
# definition must be available when the file is loaded
model = torch.load('model.pth')
# ================================================================ #
# Export Model to ONNX #
# ================================================================ #
# Export traces the model, so it needs a dummy input of the expected shape
input_image = torch.zeros((1, 3, 224, 224))
onnx.export(model, input_image, 'model.onnx')
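# Optional verification sketch (assumes the separate onnxruntime package is
# installed; illustrative, not part of the original script):
# import onnxruntime
# session = onnxruntime.InferenceSession('model.onnx')
# input_name = session.get_inputs()[0].name
# outputs = session.run(None, {input_name: input_image.numpy()})
# print(outputs[0].shape)  # (1, 1000) class scores from VGG16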