-
Notifications
You must be signed in to change notification settings - Fork 1
/
adversarial_reprogramming_counting_squares.py
126 lines (89 loc) · 3.45 KB
/
adversarial_reprogramming_counting_squares.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
import torch
import torch as T
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from tqdm import tqdm
from random import shuffle
import matplotlib.pyplot as plt
import scipy
import scipy.misc as misc
from utils import ProgrammingNetwork, get_program, train, reg_l1, reg_l2
def create_patchv1(nb_square, patch_size=36, square_size=4, border=True):
    """
    Build a patch in which ``nb_square`` distinct cells of a
    ``square_size`` x ``square_size`` grid are switched on.

    Earlier variant of ``create_patch``: it paints the chosen cells directly
    into a torch tensor instead of upsampling a small grid.

    :param nb_square: number of distinct grid cells to switch on
        (must satisfy ``0 <= nb_square <= square_size ** 2``)
    :param patch_size: side length, in pixels, of the returned patch
    :param square_size: number of grid cells along each side of the patch
    :param border: if True, leave a 1-pixel gap on the bottom/right edge of
        every painted cell so neighbouring squares stay visually separated
    :type nb_square: int
    :type patch_size: int
    :type square_size: int
    :type border: bool
    :return: a ``(patch_size, patch_size)`` float tensor of 0s and 1s
    :rtype: torch.Tensor
    :raises ValueError: if ``nb_square`` is outside ``[0, square_size ** 2]``
    """
    nb_cells = square_size * square_size
    if not 0 <= nb_square <= nb_cells:
        raise ValueError(
            "nb_square must be in [0, %d], got %r" % (nb_cells, nb_square)
        )
    # Pixel width of one grid cell. The previous version hard-coded 9,
    # which is only correct for 36 // 4.
    step = patch_size // square_size
    gap = int(bool(border))
    patch = torch.zeros((patch_size, patch_size))
    # Draw distinct cells in one call. This replaces the rejection sampler,
    # which looped forever when nb_square > nb_cells and always painted at
    # least one cell.
    cells = np.random.choice(nb_cells, size=nb_square, replace=False)
    for cell in cells:
        row, col = divmod(int(cell), square_size)
        patch[row * step:(row + 1) * step - gap,
              col * step:(col + 1) * step - gap] = 1
    return patch
def create_patch(nb_square, patch_size=36, square_size=4, border=True):
    """
    Build a single-channel patch in which ``nb_square`` randomly chosen cells
    of a ``square_size`` x ``square_size`` grid are set to 1.

    :param nb_square: the number of grid cells to switch on
        (must satisfy ``0 <= nb_square <= square_size ** 2``)
    :param patch_size: side length, in pixels, of the patch that will be
        placed in the center of the image
    :param square_size: number of grid cells along each side of the patch
        <!> square_size must not exceed patch_size
    :param border: a boolean that zeroes the first row and column of every
        cell, separating the squares for a clearer display
    :type nb_square: int
    :type patch_size: int
    :type square_size: int
    :type border: bool
    :return: a ``(1, patch_size, patch_size)`` float32 tensor of 0s and 1s
    :rtype: torch.Tensor
    :raises ValueError: if ``nb_square`` is outside ``[0, square_size ** 2]``
    """
    nb_cells = square_size * square_size
    if not 0 <= nb_square <= nb_cells:
        raise ValueError(
            "nb_square must be in [0, %d], got %r" % (nb_cells, nb_square)
        )
    cells = [1] * nb_square + [0] * (nb_cells - nb_square)
    shuffle(cells)
    grid = np.array(cells, dtype=np.float32).reshape(square_size, square_size)
    # Nearest-neighbour upsampling, replacing scipy.misc.imresize (removed in
    # SciPy 1.3). Output pixel k samples the source cell at
    # floor((k + 0.5) * square_size / patch_size). Integer arithmetic keeps
    # that exact. Unlike imresize's bytescaling, it also keeps an all-ones
    # grid as ones.
    src = ((2 * np.arange(patch_size) + 1) * square_size) // (2 * patch_size)
    patch = grid[np.ix_(src, src)]  # fancy indexing -> independent copy
    if border:
        for i in range(0, patch_size, patch_size // square_size):
            patch[i, :] = 0.
            patch[:, i] = 0.
    # Add the leading channel axis: (patch_size, patch_size) -> (1, ...).
    return torch.from_numpy(patch).unsqueeze(0)
class SquaresDataset(Dataset):
    """
    Synthetic "count the squares" dataset generated on the fly.

    Every access draws a new random count in [1, 9]. It returns the
    ``create_patch`` image for that count (drawn without borders) together
    with the count as a long tensor label.

    :param patch_size: side length, in pixels, forwarded to ``create_patch``
    :param square_size: grid cells per side, forwarded to ``create_patch``
    :param dataset_size: length reported by ``__len__``
    """

    def __init__(self, patch_size=36, square_size=4, dataset_size=100000):
        self.patch_size, self.square_size = patch_size, square_size
        # Nominal length only; samples are not stored.
        self.dataset_size = dataset_size

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, item):
        # `item` is unused: every sample is freshly randomised.
        count = np.random.randint(1, 10)  # upper bound exclusive -> 1..9
        image = create_patch(count, self.patch_size, self.square_size, border=False)
        label = T.tensor(count).long()
        return image, label
def get_counting_squares(batch_size, dataset_size=100000):
    """
    Return a shuffled DataLoader over a default-geometry SquaresDataset.

    :param batch_size: number of samples per batch
    :param dataset_size: nominal number of samples per epoch
    :return: the DataLoader
    :rtype: torch.utils.data.DataLoader
    """
    dataset = SquaresDataset(dataset_size=dataset_size)
    return T.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# --- Training script -------------------------------------------------------
# Learn an adversarial "program" (model.p) that makes a pretrained SqueezeNet
# solve the synthetic square-counting task. Only model.p is passed to the
# optimizer, so the pretrained weights are not updated by it.

# Use the first GPU when one is visible. Otherwise fall back to the CPU
# instead of failing on a hard-coded 'cuda:0'.
DEVICE = 'cuda:0' if T.cuda.is_available() else 'cpu'
# Prefix passed as save_path to train() and as the path to get_program().
PATH = "./models/squeezenet1_0_counting_squares_"
batch_size = 16
train_loader = get_counting_squares(batch_size)
test_loader = get_counting_squares(batch_size, 1000)
pretrained_model = torchvision.models.squeezenet1_0(pretrained=True).eval()
input_size = 224  # presumably the network's expected input resolution — confirm in ProgrammingNetwork
patch_size = 36   # matches create_patch's default patch_size used by the loaders
model = ProgrammingNetwork(pretrained_model, input_size, patch_size, blur_sigma=1.5, device=DEVICE)
optimizer = T.optim.Adam([model.p])
nb_epochs = 20
nb_freq = 10  # forwarded to train() as save_freq
model, loss_history = train(
    model, train_loader, nb_epochs, optimizer,
    C=.05, reg_fun=reg_l2,
    save_freq=nb_freq,
    save_path=PATH, test_loader=test_loader, device=DEVICE
)
program = get_program(model, PATH, imshow=True)