-
Notifications
You must be signed in to change notification settings - Fork 0
/
Main.py
114 lines (95 loc) · 3.48 KB
/
Main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
torch.manual_seed(0)
from matplotlib.pyplot import imshow
import matplotlib.pylab as plt
from PIL import Image
def sample_title(label):
    """Return the plot title ``'y = <label>'`` for a class label.

    ``int()`` accepts both a plain Python ``int`` (what torchvision's
    FashionMNIST returns as the target) and a one-element integer tensor,
    whereas ``.item()`` only exists on the tensor.
    """
    return 'y = ' + str(int(label))


def show_data(data_sample, image_size=None):
    """Draw one ``(image, label)`` sample as a grayscale image titled with its label.

    Args:
        data_sample: indexable pair whose element 0 is an image tensor holding
            ``image_size * image_size`` values and element 1 is the class label.
        image_size: side length used to reshape the image; when omitted, the
            module-level ``IMAGE_SIZE`` is read at call time.
    """
    if image_size is None:
        image_size = IMAGE_SIZE
    image, label = data_sample[0], data_sample[1]
    plt.imshow(image.numpy().reshape(image_size, image_size), cmap='gray')
    plt.title(sample_title(label))
IMAGE_SIZE = 16

# Resize every image to IMAGE_SIZE x IMAGE_SIZE, then convert it to a tensor.
# (Two stray expression statements that built a Resize/ToTensor and threw
# them away were removed here; they had no effect.)
composed = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor()])

# NOTE(review): '.fashion/data' is a hidden directory relative to the CWD;
# './fashion/data' may have been intended. Confirm before changing, since a
# different root forces a fresh download.
dataset_train = dsets.FashionMNIST(root='.fashion/data', train=True, transform=composed, download=True)
dataset_val = dsets.FashionMNIST(root='.fashion/data', train=False, transform=composed, download=True)

# Preview the first three validation samples, one figure each.
for n, data_sample in enumerate(dataset_val):
    show_data(data_sample)
    plt.show()
    if n == 2:
        break
class CNN_batch(nn.Module):
    """Two-stage convolutional classifier with batch normalisation.

    Each stage is Conv2d(5x5, padding=2) -> BatchNorm2d -> ReLU -> MaxPool2d(2).
    The flattened features feed a single Linear layer whose logits are then
    passed through BatchNorm1d.

    Input is single-channel (``in_channels=1``). Because the 5x5/padding-2
    convolutions keep spatial size and each pool halves it, the Linear layer's
    ``out_2 * 4 * 4`` input size implies 16x16 input images.
    """

    # Constructor
    def __init__(self, out_1=16, out_2=32, number_of_classes=10):
        """Build the layers.

        Args:
            out_1: number of channels produced by the first convolution.
            out_2: number of channels produced by the second convolution.
            number_of_classes: number of output logits.
        """
        super().__init__()
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=out_1, kernel_size=5, padding=2)
        self.conv1_bn = nn.BatchNorm2d(out_1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.cnn2 = nn.Conv2d(in_channels=out_1, out_channels=out_2, kernel_size=5, stride=1, padding=2)
        self.conv2_bn = nn.BatchNorm2d(out_2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(out_2 * 4 * 4, number_of_classes)
        # Normalises the logits, so its width must follow number_of_classes.
        # It was hard-coded to 10, which crashed for any other class count.
        self.bn_fc1 = nn.BatchNorm1d(number_of_classes)

    # Prediction
    def forward(self, x):
        """Return class logits of shape ``(batch, number_of_classes)`` for ``x``.

        ``x`` is expected to be ``(batch, 1, 16, 16)``. In training mode,
        BatchNorm1d requires a batch of more than one sample.
        """
        x = self.cnn1(x)
        x = self.conv1_bn(x)
        x = torch.relu(x)
        x = self.maxpool1(x)
        x = self.cnn2(x)
        x = self.conv2_bn(x)
        x = torch.relu(x)
        x = self.maxpool2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, out_2 * 4 * 4)
        x = self.fc1(x)
        x = self.bn_fc1(x)
        return x
# Shuffle the training set so each epoch sees a different batch order, which
# is standard for SGD. Evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=100)

# Use the batch-normalised model defined above. The name ``CNN`` used here
# previously is not defined anywhere in this script and raised NameError.
model = CNN_batch(out_1=16, out_2=32, number_of_classes=10)

import time

start_time = time.time()
cost_list = []       # summed training loss, one entry per epoch
accuracy_list = []   # validation accuracy, one entry per epoch
N_test = len(dataset_val)
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
n_epochs = 5

for epoch in range(n_epochs):
    # --- training pass ---
    cost = 0
    model.train()  # BatchNorm uses batch statistics and updates running stats
    for x, y in train_loader:
        optimizer.zero_grad()
        z = model(x)
        loss = criterion(z, y)
        loss.backward()
        optimizer.step()
        cost += loss.item()

    # --- perform a prediction on the validation data ---
    correct = 0
    model.eval()  # BatchNorm switches to its running statistics
    with torch.no_grad():  # no autograd graph needed for evaluation
        for x_test, y_test in test_loader:
            z = model(x_test)
            yhat = z.argmax(dim=1)
            correct += (yhat == y_test).sum().item()
    accuracy = correct / N_test
    accuracy_list.append(accuracy)
    cost_list.append(cost)

print(f"Training took {time.time() - start_time:.1f} s")

# Plot the per-epoch cost (left axis, red) and validation accuracy (right axis, blue).
fig, ax1 = plt.subplots()
color = 'tab:red'
ax1.plot(cost_list, color=color)
ax1.set_xlabel('epoch', color=color)
ax1.set_ylabel('Cost', color=color)
ax1.tick_params(axis='y', colors=color)  # colour tick marks and tick labels
ax2 = ax1.twinx()  # shares ax1's x-axis, so only the y-axis needs labelling
color = 'tab:blue'
ax2.set_ylabel('accuracy', color=color)
ax2.plot(accuracy_list, color=color)
ax2.tick_params(axis='y', colors=color)
fig.tight_layout()
plt.show()