-
Notifications
You must be signed in to change notification settings - Fork 47
/
setup_mnist.py
100 lines (82 loc) · 3.34 KB
/
setup_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
## setup_mnist.py -- mnist data and model loading code
##
## Copyright (C) IBM Corp, 2017-2018
## Copyright (C) 2016, Nicholas Carlini <[email protected]>.
##
## This program is licenced under the BSD 2-Clause licence,
## contained in the LICENCE file in this directory.
import tensorflow as tf
import numpy as np
import os
import pickle
import gzip
import urllib.request
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras.models import load_model
def extract_data(filename, num_images, image_size=28):
    """Read images from a gzip-compressed idx3-ubyte file as centred floats.

    The first 16 bytes of the decompressed stream are skipped as a header;
    the following bytes are interpreted as one unsigned byte per pixel.

    Args:
        filename: Path to the gzip-compressed image file.
        num_images: Number of images to read from the start of the file.
        image_size: Side length, in pixels, of each square single-channel
            image. Defaults to 28.

    Returns:
        A float32 array of shape (num_images, image_size, image_size, 1)
        whose values are the raw bytes mapped from [0, 255] to [-0.5, 0.5].

    Raises:
        ValueError: If the file holds fewer pixel bytes than requested.
    """
    header_size = 16
    expected_bytes = num_images * image_size * image_size
    with gzip.open(filename) as bytestream:
        bytestream.read(header_size)
        buf = bytestream.read(expected_bytes)
    # A short read means a truncated or mismatched file; report it here
    # instead of letting reshape fail with a less helpful message.
    if len(buf) != expected_bytes:
        raise ValueError(
            "{}: expected {} pixel bytes for {} images, got {}".format(
                filename, expected_bytes, num_images, len(buf)))
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    data = (data / 255) - 0.5
    return data.reshape(num_images, image_size, image_size, 1)
def extract_labels(filename, num_images, num_labels=10):
    """Read labels from a gzip-compressed idx1-ubyte file as one-hot rows.

    The first 8 bytes of the decompressed stream are skipped as a header;
    each following byte is taken as one class index.

    Args:
        filename: Path to the gzip-compressed label file.
        num_images: Number of labels to read from the start of the file.
        num_labels: Number of classes, i.e. the width of each one-hot row.
            Defaults to 10.

    Returns:
        A float32 array of shape (num_images, num_labels). Row i has a 1.0
        in the column equal to label i; a label byte outside
        0..num_labels-1 produces an all-zero row.

    Raises:
        ValueError: If the file holds fewer label bytes than requested.
    """
    header_size = 8
    with gzip.open(filename) as bytestream:
        bytestream.read(header_size)
        buf = bytestream.read(num_images)
    # Without this check a truncated file would silently yield fewer rows
    # than requested, leaving labels out of step with the image array.
    if len(buf) != num_images:
        raise ValueError(
            "{}: expected {} label bytes, got {}".format(
                filename, num_images, len(buf)))
    labels = np.frombuffer(buf, dtype=np.uint8)
    # Broadcasting compares every label against every class index.
    return (np.arange(num_labels) == labels[:, None]).astype(np.float32)
class MNIST:
    """MNIST data set, fetched on demand and split into three parts.

    Attributes (arrays as returned by extract_data / extract_labels):
        test_data, test_labels: the 10000 examples from the t10k archives.
        validation_data, validation_labels: the first 5000 training examples.
        train_data, train_labels: the remaining 55000 training examples.
    """

    def __init__(self):
        """Download any missing archive into ./data and load every split.

        Propagates whatever urllib.request.urlretrieve raises when a
        missing archive cannot be fetched, and whatever the extract_*
        helpers raise for unreadable archives.
        """
        data_dir = "data"
        files = ["train-images-idx3-ubyte.gz",
                 "t10k-images-idx3-ubyte.gz",
                 "train-labels-idx1-ubyte.gz",
                 "t10k-labels-idx1-ubyte.gz"]

        os.makedirs(data_dir, exist_ok=True)
        for name in files:
            target = data_dir + "/" + name
            # Decide per archive rather than per directory: an existing
            # "data" folder does not imply that every archive is present.
            if os.path.exists(target):
                continue
            # Download to a side file and rename only on completion so an
            # interrupted transfer is never mistaken for a finished archive.
            # NOTE(review): plain HTTP with no checksum verification.
            partial = target + ".part"
            urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + name, partial)
            os.replace(partial, target)

        train_data = extract_data("data/train-images-idx3-ubyte.gz", 60000)
        train_labels = extract_labels("data/train-labels-idx1-ubyte.gz", 60000)
        self.test_data = extract_data("data/t10k-images-idx3-ubyte.gz", 10000)
        self.test_labels = extract_labels("data/t10k-labels-idx1-ubyte.gz", 10000)

        # Hold out the leading slice of the training set for validation.
        VALIDATION_SIZE = 5000

        self.validation_data = train_data[:VALIDATION_SIZE, :, :, :]
        self.validation_labels = train_labels[:VALIDATION_SIZE]
        self.train_data = train_data[VALIDATION_SIZE:, :, :, :]
        self.train_labels = train_labels[VALIDATION_SIZE:]
class MNISTModel:
    """Keras convolutional classifier for 28x28 single-channel images.

    Layer stack: two 3x3 conv+ReLU pairs with 32 filters, 2x2 max-pool,
    two 3x3 conv+ReLU pairs with 64 filters, 2x2 max-pool, flatten, two
    200-unit dense+ReLU pairs, and a final 10-unit dense layer.

    Attributes:
        num_channels: number of input channels (1).
        image_size: input height and width in pixels (28).
        num_labels: number of output classes (10).
        model: the assembled keras Sequential model.
    """

    def __init__(self, restore = None, session=None, use_log=False):
        """Assemble the network and optionally load saved weights.

        Args:
            restore: path handed to model.load_weights; skipped when falsy.
            session: accepted for interface compatibility; not used here.
            use_log: when true, a softmax activation is appended so the
                model emits class probabilities rather than the raw output
                of the last dense layer (used for the black-box attack).
        """
        self.num_channels = 1
        self.image_size = 28
        self.num_labels = 10

        network = Sequential()

        # Only the very first layer declares the input shape.
        first_layer_kwargs = {
            "input_shape": (self.image_size, self.image_size, self.num_channels),
        }
        for filters in (32, 64):
            for _ in range(2):
                network.add(Conv2D(filters, (3, 3), **first_layer_kwargs))
                network.add(Activation('relu'))
                first_layer_kwargs = {}
            network.add(MaxPooling2D(pool_size=(2, 2)))

        network.add(Flatten())
        for _ in range(2):
            network.add(Dense(200))
            network.add(Activation('relu'))
        network.add(Dense(self.num_labels))

        if use_log:
            network.add(Activation('softmax'))
        if restore:
            network.load_weights(restore)

        self.model = network

    def predict(self, data):
        """Apply the underlying Keras model to data and return its output."""
        return self.model(data)