-
Notifications
You must be signed in to change notification settings - Fork 5
/
caltech_101.py
110 lines (92 loc) · 3.86 KB
/
caltech_101.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Manages the loading, labeling, and caching of the Caltech 101 dataset.
"""
from keras.preprocessing import image
from config import args
import h5py
import json
import numpy as np
import os
import pickle
""" Several convenience functions to assist with loading and configuration """
def caltech_101_cache():
    """Return the path of the dataset cache file.

    The path is ``args.caltech_101_cache`` joined onto ``args.data_dir``;
    load_data() reads and writes its HDF5 cache at this location.
    """
    data_root = args.data_dir
    cache_name = args.caltech_101_cache
    return os.path.join(data_root, cache_name)
def caltech_101_dir():
    """Return the path of the image root directory.

    The path is ``args.caltech_101_dir`` joined onto ``args.data_dir``.
    """
    path_parts = (args.data_dir, args.caltech_101_dir)
    return os.path.join(*path_parts)
def labels_cache():
    """Return the path of the JSON file holding the label mapping.

    The path is ``args.labels_file`` joined onto ``args.data_dir``.
    """
    labels_path = os.path.join(args.data_dir, args.labels_file)
    return labels_path
def get_directories():
    """Return the path of every directory found beneath the image root.

    os.walk() yields the root itself first, so that first entry is dropped
    and only the directories below it are returned.
    """
    walked = [dirpath
              for dirpath, _dirnames, _filenames in os.walk(caltech_101_dir())]
    return walked[1:]
def class_labels():
    """ Return a mapping of label indices -> class name in caltech 101 dataset

    The mapping is built once from the directory names returned by
    get_directories() and saved to the labels cache file so that it stays
    consistent across calls.

    The result is always read back from the JSON cache, so the index keys
    are strings (e.g. "0", "1", ...), not ints; convert with int() as needed.
    """
    cache_path = labels_cache()
    if not os.path.isfile(cache_path):
        labels = {idx: os.path.basename(d)
                  for idx, d in enumerate(get_directories())}
        with open(cache_path, "w") as f:
            json.dump(labels, f)
    # Use a context manager so the file handle is closed promptly rather
    # than being left for the garbage collector.
    with open(cache_path, "r") as f:
        return json.load(f)
def num_class_labels():
    """Return the number of entries in the class-label mapping."""
    label_map = class_labels()
    return len(label_map)
def _load_class_images(class_path):
    """Load every regular file directly inside class_path as an image array.

    Each file is loaded at 224x224 and the results are stacked along a new
    leading axis. Returns None when the directory contains no files, because
    np.stack() cannot stack an empty list.
    """
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # train/test split reproducible between runs and machines.
    candidates = (os.path.join(class_path, f) for f in os.listdir(class_path))
    files = sorted(path for path in candidates if os.path.isfile(path))
    if not files:
        return None
    return np.stack(
        [image.img_to_array(image.load_img(path, target_size=(224, 224)))
         for path in files],
        axis=0)


def load_images():
    """ Load all the images from JPG format into numpy arrays

    Returns ((training data, training labels), (test data, test labels)),
    data shape: [N, 224, 224, 3]
    labels shape: [N]
    where N is the number of examples

    Within each class, the first floor(n * args.train_test_ratio) images (in
    sorted filename order) go to the training set and the rest to the test
    set. Either side may be empty for a class; class directories with no
    files are skipped.

    Raises ValueError (from np.concatenate) if no class yields any image.
    """
    labels = class_labels()

    # These will assemble the training and test data, to be concatenated later
    train_x_arr = []
    train_y_arr = []
    test_x_arr = []
    test_y_arr = []

    # .items() (rather than .iteritems()) works on both Python 2 and 3.
    for class_idx, name in labels.items():
        print("Loading {} images...".format(name))
        imgs = _load_class_images(os.path.join(caltech_101_dir(), name))
        if imgs is None:
            continue

        # Decide on the index to split training and test sets at
        split = int(np.floor(len(imgs) * args.train_test_ratio))

        # Slicing the already-stacked array keeps the image dimensions even
        # when one side of the split is empty, so concatenation still works.
        train_x = imgs[:split]
        test_x = imgs[split:]
        train_x_arr.append(train_x)
        test_x_arr.append(test_x)

        # Label keys come from JSON and are strings, hence the int().
        label = int(class_idx)
        train_y_arr.append(np.full((len(train_x),), label, dtype=int))
        test_y_arr.append(np.full((len(test_x),), label, dtype=int))

    return ((np.concatenate(train_x_arr), np.concatenate(train_y_arr)),
            (np.concatenate(test_x_arr), np.concatenate(test_y_arr)))
def load_data(recache=False):
    """ Returns the caltech_101 data as
    ((train_x, train_y), (test_x, test_y)), the format produced by
    load_images().

    If the cache file doesn't exist or recache == True, then load the images
    into numpy arrays and save them to the cache file first. In either case
    the returned arrays are read back from the cache file.
    """
    cache_path = caltech_101_cache()

    if recache or not os.path.isfile(cache_path):
        (train_x, train_y), (test_x, test_y) = load_images()
        print("Saving file...")
        # Context manager guarantees the HDF5 file is closed even if one of
        # the writes fails, so a locked file handle is not left open.
        with h5py.File(cache_path, mode="w") as f:
            f.create_dataset("train_x", data=train_x)
            f.create_dataset("train_y", data=train_y)
            f.create_dataset("test_x", data=test_x)
            f.create_dataset("test_y", data=test_y)

    with h5py.File(cache_path, mode="r") as f:
        # [:] copies each dataset into memory as a numpy array, so the
        # results remain valid after the file is closed.
        train_x = f["train_x"][:]
        test_x = f["test_x"][:]
        train_y = f["train_y"][:]
        test_y = f["test_y"][:]

    return ((train_x, train_y), (test_x, test_y))
# Running this module as a script forces the cache file to be rebuilt from
# the image files, even if a cache already exists.
if __name__ == "__main__":
    load_data(recache=True)