forked from marian42/butterflies
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_classifier.py
67 lines (54 loc) · 1.9 KB
/
train_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import sys
from torch.utils.data import DataLoader
from torchvision import utils
from tqdm import tqdm
from classifier import Classifier
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path the classifier weights are loaded from at start-up and saved to later.
CLASSIFIER_FILENAME = 'trained_models/classifier.to'
# When True, train() writes masked example images on every 100th epoch.
SAVE_EXAMPLES = False

classifier = Classifier()
try:
    # map_location lets a checkpoint that was saved on a GPU machine be
    # restored on whatever device was selected above.
    classifier.load_state_dict(torch.load(CLASSIFIER_FILENAME, map_location=device))
except FileNotFoundError:
    # Only a missing checkpoint means "start from scratch"; any other failure
    # (corrupt file, mismatched architecture) should surface instead of being
    # silently overwritten by a freshly initialised model.
    print("Found no model, training a new one.")
# Honour the CPU fallback instead of unconditionally requiring CUDA.
classifier.to(device)

from mask_loader import MaskDataset
dataset = MaskDataset()
# Shuffled batches of 8 samples, loaded by 8 worker processes.
data_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)

optimizer = optim.Adam(classifier.parameters(), lr=0.0002)
# Binary cross-entropy between the classifier output and the target mask.
criterion = nn.BCELoss()
def save_example(epoch, hash, image, mask):
    """Save ``image`` multiplied by a partly binarised copy of ``mask``.

    The mask is split at the midpoint of its first two dimensions. In the
    first/first and second/second quadrants the values are thresholded at 0.5
    (1 above the threshold, 0 otherwise); the two remaining quadrants keep
    their raw values. The product of the image and this mask is written to
    ``data/test/<hash>.png``.

    Args:
        epoch: Unused; kept so existing callers keep working.
        hash: String used as the output file name (without extension).
        image: Image tensor; a leading dimension of size 1 is squeezed away.
            Presumably (channels, height, width) -- confirm against caller.
        mask: Tensor with at least two dimensions holding mask values.
            It is not modified.
    """
    # Work on a detached copy: the caller passes a view of the network output,
    # and writing into it would mutate a tensor tracked by autograd.
    mask = mask.detach().clone()
    # Build the thresholded mask on the same device/dtype as ``mask`` so the
    # function works for both CPU and GPU tensors.
    mask_binary = (mask > 0.5).to(mask.dtype)

    w, h = mask.shape[0] // 2, mask.shape[1] // 2
    mask[:w, :h] = mask_binary[:w, :h]
    mask[w:, h:] = mask_binary[w:, h:]

    result = image.clone().squeeze(0)
    result *= mask
    utils.save_image(result, 'data/test/{:s}.png'.format(hash))
def train():
    """Run the training loop forever, one pass over ``data_loader`` per epoch.

    Each batch is moved to ``device``, passed through ``classifier`` and
    scored with ``criterion``; the optimizer then takes one step. After every
    epoch the epoch number and the mean batch loss are printed and the
    classifier's state dict is written to ``CLASSIFIER_FILENAME``. The
    function never returns on its own.
    """
    for epoch in count():
        batch_losses = []
        for images, targets, names in tqdm(data_loader):
            images = images.to(device)
            targets = targets.to(device)

            classifier.zero_grad()
            # squeeze(1) drops dimension 1 of the prediction when it has size 1.
            predictions = classifier(images).squeeze(1)
            batch_loss = criterion(predictions, targets)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())

            # Example images are only written on every 100th epoch.
            if SAVE_EXAMPLES and epoch % 100 == 0:
                # Index explicitly rather than iterating the tensors:
                # save_example receives per-sample views of the prediction.
                for index in range(images.shape[0]):
                    save_example(
                        epoch,
                        names[index],
                        images[index, :, :],
                        predictions[index, :, :],
                    )

        print(epoch, np.mean(batch_losses))
        torch.save(classifier.state_dict(), CLASSIFIER_FILENAME)
# Start training only when the file is run as a script. The DataLoader above
# uses worker processes; on platforms that start workers by re-importing the
# main module, an unguarded call would launch training again in every worker.
if __name__ == "__main__":
    train()