From f89604a5f3622a97bba21b1fb2624d11d33fdd4c Mon Sep 17 00:00:00 2001
From: haoyangz
Date: Tue, 9 May 2017 11:00:13 -0400
Subject: [PATCH] Update to work with Keras 2

---
 example/model.py |  4 ++--
 main.py          | 43 ++++++++++++++++++++++++++++++++++---------
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/example/model.py b/example/model.py
index 908cc2f..bbcb6fe 100644
--- a/example/model.py
+++ b/example/model.py
@@ -1,7 +1,7 @@
 import h5py
 from os.path import join,exists
 from keras.models import Sequential
-from keras.layers.core import Dense, Dropout, Activation,Flatten,Merge
+from keras.layers.core import Dense, Dropout, Activation,Flatten
 from keras.layers.convolutional import Convolution2D,MaxPooling2D
 from keras.optimizers import Adadelta,RMSprop
 from hyperas.distributions import choice, uniform, conditional
@@ -41,7 +41,7 @@ def model(X_train, Y_train, X_test, Y_test):
     model.add(Activation('softmax'))
 
     myoptimizer = RMSprop(lr={{choice([0.01,0.001,0.0001])}}, rho=0.9, epsilon=1e-06)
-    mylossfunc = 'binary_crossentropy'
+    mylossfunc = 'categorical_crossentropy'
     model.compile(loss=mylossfunc, optimizer=myoptimizer,metrics=['accuracy'])
 
     model.fit(X_train, Y_train, batch_size=100, nb_epoch=5,validation_split=0.1)
diff --git a/main.py b/main.py
index 22052d5..e64f1aa 100644
--- a/main.py
+++ b/main.py
@@ -30,6 +30,8 @@ def parse_args():
     parser.add_argument("-l", "--lweightfile",default=None,help="Weight file after training")
     parser.add_argument("-r", "--retrain",default=None,help="codename for the retrain run")
     parser.add_argument("-rw", "--rweightfile",default='',help="Weight file to load for retraining")
+    parser.add_argument("-dm", "--datamode",default='generator',help="'generator' to stream batches from disk; anything else to load all data into memory")
+
     return parser.parse_args()
 
 def probedata(dataprefix):
@@ -43,6 +45,22 @@ def probedata(dataprefix):
         samplecnt += len(data['label'])
     return (cnt,samplecnt)
 
+def readdata(dataprefix):
+    allfiles = subprocess.check_output('ls '+dataprefix+'*', shell=True).split('\n')[:-1]
+    cnt = 0
+    samplecnt = 0
+    for x in allfiles:
+        if x.split(dataprefix)[1].isdigit():
+            cnt += 1
+            dataall = h5py.File(x,'r')
+            if cnt == 1:
+                label = np.asarray(dataall['label'])
+                data = np.asarray(dataall['data'])
+            else:
+                label = np.vstack((label,dataall['label']))
+                data = np.vstack((data,dataall['data']))
+    return (label,data)
+
 if __name__ == "__main__":
 
     args = parse_args()
@@ -90,11 +108,18 @@ def probedata(dataprefix):
     model.compile(loss=best_lossfunc, optimizer=best_optim,metrics=['accuracy'])
 
     checkpointer = ModelCheckpoint(filepath=weight_file, verbose=1, save_best_only=True)
-    trainbatch_num,train_size = probedata(data1prefix+'.train.h5.batch')
-    validbatch_num,valid_size = probedata(data1prefix+'.valid.h5.batch')
-    history_callback = model.fit_generator(mymodel.BatchGenerator2(args.batchsize,trainbatch_num,'train',topdir,data_code)\
-            ,train_size,args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
-            ,nb_val_samples=valid_size,callbacks = [checkpointer])
+
+    if args.datamode == 'generator':
+        trainbatch_num,train_size = probedata(data1prefix+'.train.h5.batch')
+        validbatch_num,valid_size = probedata(data1prefix+'.valid.h5.batch')
+        history_callback = model.fit_generator(mymodel.BatchGenerator2(args.batchsize,trainbatch_num,'train',topdir,data_code)\
+                ,np.ceil(float(train_size)/args.batchsize),args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
+                ,validation_steps=np.ceil(float(valid_size)/args.batchsize),callbacks = [checkpointer])
+    else:
+        Y_train, traindata = readdata(data1prefix+'.train.h5.batch')
+        Y_valid, validdata = readdata(data1prefix+'.valid.h5.batch')
+        history_callback = model.fit(traindata, Y_train, batch_size=args.batchsize, epochs=args.trainepoch,validation_data=(validdata,Y_valid),callbacks = [checkpointer])
+
     model.save_weights(last_weight_file, overwrite=True)
     system('touch '+join(outdir,model_arch+'.traindone'))
 
@@ -116,8 +141,8 @@ def probedata(dataprefix):
     trainbatch_num,train_size = probedata(data1prefix+'.train.h5.batch')
     validbatch_num,valid_size = probedata(data1prefix+'.valid.h5.batch')
     history_callback = model.fit_generator(mymodel.BatchGenerator2(args.batchsize,trainbatch_num,'train',topdir,data_code)\
-            ,train_size,args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
-            ,nb_val_samples=valid_size,callbacks = [checkpointer])
+            ,np.ceil(float(train_size)/args.batchsize),args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
+            ,validation_steps=np.ceil(float(valid_size)/args.batchsize),callbacks = [checkpointer])
 
     model.save_weights(new_last_weight_file, overwrite=True)
     system('touch '+join(outdir,model_arch+'.traindone'))
@@ -149,8 +174,8 @@ def probedata(dataprefix):
     ## Predict on new data
     model = model_from_json(open(architecture_file).read())
     model.load_weights(weight_file)
-    best_optim = cPickle.load(open(optimizer_file,'rb'))
-    model.compile(loss='binary_crossentropy', optimizer=best_optim,metrics=['accuracy'])
+    best_optim, best_lossfunc = cPickle.load(open(optimizer_file,'rb'))
+    model.compile(loss=best_lossfunc, optimizer=best_optim,metrics=['accuracy'])
 
     predict_batch_num = len([ 1 for x in subprocess.check_output('ls '+args.infile+'*', shell=True).split('\n')[:-1] if args.infile in x if x.split(args.infile)[1].isdigit()])
     print('Total number of batch to predict:',predict_batch_num)
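
For reference, a minimal sketch (not part of the patch) of the Keras 1 to
Keras 2 fit_generator migration applied above. The toy data, model, and
batch_generator below are hypothetical stand-ins; only the argument mapping
samples_per_epoch -> steps_per_epoch and nb_val_samples -> validation_steps
mirrors what the patch does:

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense

    def batch_generator(X, Y, batch_size):
        # fit_generator expects a generator that loops over the data forever.
        n = len(X)
        while True:
            for start in range(0, n, batch_size):
                yield X[start:start + batch_size], Y[start:start + batch_size]

    # Toy stand-ins for the HDF5 batches that main.py reads from disk.
    X_train = np.random.rand(1000, 8)
    Y_train = np.random.randint(0, 2, size=(1000, 1))
    batch_size = 100

    model = Sequential([Dense(1, activation='sigmoid', input_dim=8)])
    model.compile(loss='binary_crossentropy', optimizer='rmsprop')

    # Keras 1: fit_generator(gen, samples_per_epoch=train_size, nb_epoch=5, ...)
    # Keras 2: the second positional argument counts batches per epoch, hence
    # the np.ceil(float(train_size)/args.batchsize) introduced by this patch.
    steps = int(np.ceil(float(len(X_train)) / batch_size))
    model.fit_generator(batch_generator(X_train, Y_train, batch_size),
                        steps_per_epoch=steps, epochs=5)

Note also that the predict path now unpacks an (optimizer, loss) pair from
optimizer_file, so whatever writes that pickle during training is assumed to
dump the tuple accordingly, e.g.
cPickle.dump((best_optim, best_lossfunc), open(optimizer_file, 'wb')).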