data_loader.py
from parameters import DATASET, HYPERPARAMS
import numpy as np


def load_data(validation=False, test=False):
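    """Load pre-extracted features and labels for the dataset defined in parameters.

    Always builds the training set as a dict with keys 'X' (features) and 'Y'
    (labels); the validation and/or test dicts are loaded and returned as well
    when the corresponding flags are True.
    """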
data_dict = dict()
validation_dict = dict()
test_dict = dict()
if DATASET.name == "Fer2013":
# load train set
if HYPERPARAMS.features == "landmarks_and_hog":
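            # landmarks come as one array of (x, y) points per image: flatten each
            # to a single feature row, then append the HOG descriptor columns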
data_dict['X'] = np.load(DATASET.train_folder + '/landmarks.npy')
data_dict['X'] = np.array([x.flatten() for x in data_dict['X']])
data_dict['X'] = np.concatenate((data_dict['X'], np.load(DATASET.train_folder + '/hog_features.npy')), axis=1)
elif HYPERPARAMS.features == "landmarks":
data_dict['X'] = np.load(DATASET.train_folder + '/landmarks.npy')
data_dict['X'] = np.array([x.flatten() for x in data_dict['X']])
elif HYPERPARAMS.features == "hog":
data_dict['X'] = np.load(DATASET.train_folder + '/hog_features.npy')
        else:
            print("Error '{}' features not recognized".format(HYPERPARAMS.features))
            exit()
data_dict['Y'] = np.load(DATASET.train_folder + '/labels.npy')
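        # optionally keep only the first trunc_trainset_to samples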
if DATASET.trunc_trainset_to > 0:
data_dict['X'] = data_dict['X'][0:DATASET.trunc_trainset_to, :]
data_dict['Y'] = data_dict['Y'][0:DATASET.trunc_trainset_to]
if validation:
# load validation set
if HYPERPARAMS.features == "landmarks_and_hog":
validation_dict['X'] = np.load(DATASET.validation_folder + '/landmarks.npy')
validation_dict['X'] = np.array([x.flatten() for x in validation_dict['X']])
validation_dict['X'] = np.concatenate((validation_dict['X'], np.load(DATASET.validation_folder + '/hog_features.npy')), axis=1)
elif HYPERPARAMS.features == "landmarks":
validation_dict['X'] = np.load(DATASET.validation_folder + '/landmarks.npy')
validation_dict['X'] = np.array([x.flatten() for x in validation_dict['X']])
elif HYPERPARAMS.features == "hog":
validation_dict['X'] = np.load(DATASET.validation_folder + '/hog_features.npy')
            else:
                print("Error '{}' features not recognized".format(HYPERPARAMS.features))
                exit()
validation_dict['Y'] = np.load(DATASET.validation_folder + '/labels.npy')
if DATASET.trunc_validationset_to > 0:
validation_dict['X'] = validation_dict['X'][0:DATASET.trunc_validationset_to, :]
validation_dict['Y'] = validation_dict['Y'][0:DATASET.trunc_validationset_to]
if test:
            # load test set
if HYPERPARAMS.features == "landmarks_and_hog":
test_dict['X'] = np.load(DATASET.test_folder + '/landmarks.npy')
test_dict['X'] = np.array([x.flatten() for x in test_dict['X']])
test_dict['X'] = np.concatenate((test_dict['X'], np.load(DATASET.test_folder + '/hog_features.npy')), axis=1)
elif HYPERPARAMS.features == "landmarks":
test_dict['X'] = np.load(DATASET.test_folder + '/landmarks.npy')
test_dict['X'] = np.array([x.flatten() for x in test_dict['X']])
elif HYPERPARAMS.features == "hog":
test_dict['X'] = np.load(DATASET.test_folder + '/hog_features.npy')
            else:
                print("Error '{}' features not recognized".format(HYPERPARAMS.features))
                exit()
test_dict['Y'] = np.load(DATASET.test_folder + '/labels.npy')
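            # side effect: the test labels are also written to a separate lab.npy file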
np.save(DATASET.test_folder + "/lab.npy", test_dict['Y'])
if DATASET.trunc_testset_to > 0:
test_dict['X'] = test_dict['X'][0:DATASET.trunc_testset_to, :]
test_dict['Y'] = test_dict['Y'][0:DATASET.trunc_testset_to]
if not validation and not test:
return data_dict
elif not test:
return data_dict, validation_dict
else:
return data_dict, validation_dict, test_dict
else:
print( "Unknown dataset")
exit()
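

# Minimal usage sketch: assumes the .npy feature files referenced above
# (landmarks.npy, hog_features.npy, labels.npy) already exist under
# DATASET.train_folder, DATASET.validation_folder and DATASET.test_folder.
if __name__ == "__main__":
    # load all three splits in one call
    train, validation, test = load_data(validation=True, test=True)
    print("train:      X={} Y={}".format(train['X'].shape, train['Y'].shape))
    print("validation: X={} Y={}".format(validation['X'].shape, validation['Y'].shape))
    print("test:       X={} Y={}".format(test['X'].shape, test['Y'].shape))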