-
Notifications
You must be signed in to change notification settings - Fork 0
/
read_data.py
146 lines (110 loc) · 4.88 KB
/
read_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Parses the MNIST binary file to read image and label data."""
from __future__ import absolute_import
from __future__ import print_function
import os
from six.moves import xrange
import tensorflow as tf
# Module-wide constants for the MNIST data set.

# Geometry of one image: square, IMAGE_SIZE pixels per side.
IMAGE_SIZE = 28
IMAGE_CHANNELS = 1
# Total number of pixel values stored for a single image.
NUM_PIXELS = IMAGE_CHANNELS * IMAGE_SIZE ** 2

# Number of distinct label classes.
NUM_CLASSES = 10

# Number of examples in each of the two splits.
NUM_EXAMPLES_IN_TRAIN_SET = 60000
NUM_EXAMPLES_IN_TEST_SET = 10000

# Directory that holds the four idx-ubyte data files.
FOLDER_PATH = r"/root/binLearning/database/MNIST/"
def _read_images(test_data=False, as_image=True, for_show=False):
    """Reads and parses the binary file which contains training/test images.

    Builds the graph ops that yield one image each time they are evaluated.

    Args:
        test_data: If True, read 't10k-images.idx3-ubyte' (test set);
            otherwise read 'train-images.idx3-ubyte' (training set).
        as_image: If True, return the image reshaped to
            [IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS] and standardized per
            image (for a CNN). If False, return a flat float32 vector of
            NUM_PIXELS values divided by 255 (for a linear classifier / ANN).
        for_show: If True, return the raw uint8 pixels reshaped to
            [IMAGE_SIZE, IMAGE_SIZE] with no preprocessing. Takes precedence
            over `as_image`.

    Returns:
        A tensor holding a single image, in the layout selected above.

    Raises:
        ValueError: If the selected file does not exist under FOLDER_PATH.
    """
    if test_data:
        filename = os.path.join(FOLDER_PATH, 't10k-images.idx3-ubyte')
    else:
        filename = os.path.join(FOLDER_PATH, 'train-images.idx3-ubyte')
    if not os.path.exists(filename):
        raise ValueError('The file does not exist: %s' % filename)

    # Create a queue that produces the filenames to read.
    filename_queue = tf.train.string_input_producer([filename])

    # The first 16 bytes contain file information:
    # [offset] [type]          [value]            [description]
    # 0000     32 bit integer  0x00000803(2051)   magic number
    # 0004     32 bit integer  60000/10000        number of images
    # 0008     32 bit integer  28                 number of rows
    # 0012     32 bit integer  28                 number of columns
    # ...(pixel value)
    header_bytes = 16
    # Every record is one image of one byte per pixel value. NUM_PIXELS keeps
    # this in step with the [IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS] reshape.
    record_bytes = NUM_PIXELS

    # Create a FixedLengthRecordReader to read record.
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes,
                                        header_bytes=header_bytes)
    _, value = reader.read(filename_queue)

    # Convert from a string to a vector of uint8.
    image = tf.decode_raw(value, tf.uint8)

    if for_show:
        # Raw 2-D pixels, no preprocessing.
        return tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE])

    if as_image:  # for CNN
        # Reshape from [height * width * channels] to [height, width, channels].
        reshape_image = tf.reshape(
            image, [IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
        # Subtract the per-image mean and divide by the per-image standard
        # deviation. `per_image_whitening` was renamed
        # `per_image_standardization` in later TensorFlow releases, so use
        # whichever one this installation provides.
        standardize = getattr(tf.image, 'per_image_standardization', None)
        if standardize is None:
            standardize = tf.image.per_image_whitening
        preproc_image = standardize(reshape_image)
    else:  # for linear classifier / ANN
        # To avoid ValueError: All shapes must be fully defined:...
        image.set_shape([NUM_PIXELS])
        # Cast image pixel value from tf.uint8 to tf.float32.
        float_image = tf.cast(image, tf.float32)
        # Normalize pixel values by dividing by 255.
        preproc_image = tf.div(float_image, 255.0)

    return preproc_image
def _read_labels(test_data=False):
    """Reads and parses the binary file which contains training/test labels.

    Builds the graph ops that yield one label each time they are evaluated.

    Args:
        test_data: If True, read 't10k-labels.idx1-ubyte' (test set);
            otherwise read 'train-labels.idx1-ubyte' (training set).

    Returns:
        A scalar int32 tensor holding a single label.

    Raises:
        ValueError: If the selected file does not exist under FOLDER_PATH.
    """
    if test_data:
        filename = os.path.join(FOLDER_PATH, 't10k-labels.idx1-ubyte')
    else:
        filename = os.path.join(FOLDER_PATH, 'train-labels.idx1-ubyte')
    if not os.path.exists(filename):
        raise ValueError('The file does not exist: %s' % filename)

    # Create a queue that produces the filenames to read.
    filename_queue = tf.train.string_input_producer([filename])

    # The first 8 bytes contain file information:
    # [offset] [type]          [value]            [description]
    # 0000     32 bit integer  0x00000801(2049)   magic number
    # 0004     32 bit integer  60000/10000        number of items
    # ...(label value)
    header_bytes = 8
    # Every record consists of a label, with a fixed number of bytes for each.
    record_bytes = 1

    # Create a FixedLengthRecordReader to read record.
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes,
                                        header_bytes=header_bytes)
    _, value = reader.read(filename_queue)

    # Convert from a string to a vector of uint8, then cast to int32.
    record = tf.cast(tf.decode_raw(value, tf.uint8), tf.int32)
    # Reshape from [1] to a scalar shape [].
    return tf.reshape(record, [])
def generate_batch(model, batch_size, test_data=False):
    """Assemble queue-fed batches of images paired with their labels.

    Args:
        model: 'cnn' requests images via `_read_images(as_image=True)`;
            any other value requests `as_image=False`.
        batch_size: Number of examples per batch.
        test_data: Forwarded to the readers to pick the test or training files.

    Returns:
        A pair (images_batch, labels_batch); the labels are reshaped to
        [batch_size].
    """
    # Only the 'cnn' model takes its input in image layout.
    use_image_layout = True if model == 'cnn' else False

    single_image = _read_images(test_data=test_data,
                                as_image=use_image_layout)
    single_label = _read_labels(test_data=test_data)

    # One thread feeds the batching queue, which holds up to eight batches.
    batched = tf.train.batch(
        [single_image, single_label],
        batch_size=batch_size,
        num_threads=1,
        capacity=batch_size * 8,
    )
    batched_images, batched_labels = batched

    flat_labels = tf.reshape(batched_labels, [batch_size])
    return batched_images, flat_labels
"""
Usage sketch: the queue runners must be started before a batch is evaluated.

def func():
    image, label = generate_batch(...)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # run training step or whatever
        coord.request_stop()
        coord.join(threads)
"""