Commit f89604a: update to work with keras 2
haoyangz committed May 9, 2017 (1 parent: c1f32ce)

Showing 2 changed files with 36 additions and 11 deletions.
4 changes: 2 additions & 2 deletions example/model.py
@@ -1,7 +1,7 @@
 import h5py
 from os.path import join,exists
 from keras.models import Sequential
-from keras.layers.core import Dense, Dropout, Activation,Flatten,Merge
+from keras.layers.core import Dense, Dropout, Activation,Flatten
 from keras.layers.convolutional import Convolution2D,MaxPooling2D
 from keras.optimizers import Adadelta,RMSprop
 from hyperas.distributions import choice, uniform, conditional
@@ -41,7 +41,7 @@ def model(X_train, Y_train, X_test, Y_test):
     model.add(Activation('softmax'))

     myoptimizer = RMSprop(lr={{choice([0.01,0.001,0.0001])}}, rho=0.9, epsilon=1e-06)
-    mylossfunc = 'binary_crossentropy'
+    mylossfunc = 'categorical_crossentropy'
     model.compile(loss=mylossfunc, optimizer=myoptimizer,metrics=['accuracy'])
     model.fit(X_train, Y_train, batch_size=100, nb_epoch=5,validation_split=0.1)

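For orientation (not part of the commit): Keras 2 renamed nb_epoch to epochs and removed the Merge layer, which is why the unused Merge import is dropped above; the model.fit call in this diff keeps nb_epoch, which early Keras 2 releases still accepted with a deprecation warning. A minimal runnable sketch of the rename, assuming the Keras 2.0 Sequential API:

    # Not part of the commit: Keras 1 -> Keras 2 rename of nb_epoch.
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.utils import to_categorical

    X = np.random.rand(200, 8)
    Y = to_categorical(np.random.randint(0, 2, 200), 2)  # toy two-class labels

    m = Sequential()
    m.add(Dense(2, activation='softmax', input_dim=8))
    m.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

    # Keras 1 spelled this m.fit(X, Y, batch_size=20, nb_epoch=2, validation_split=0.1)
    m.fit(X, Y, batch_size=20, epochs=2, validation_split=0.1)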
43 changes: 34 additions & 9 deletions main.py
@@ -30,6 +30,8 @@ def parse_args():
     parser.add_argument("-l", "--lweightfile",default=None,help="Weight file after training")
     parser.add_argument("-r", "--retrain",default=None,help="codename for the retrain run")
     parser.add_argument("-rw", "--rweightfile",default='',help="Weight file to load for retraining")
+    parser.add_argument("-dm", "--datamode",default='generator',help="Data mode: 'generator' streams batches; any other value loads all data into memory")
+
     return parser.parse_args()

 def probedata(dataprefix):
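A usage note, not from the commit: the new --datamode flag only distinguishes the default 'generator' from everything else, so any other token selects the in-memory path. A hypothetical invocation (the script's other required arguments are elided):

    python main.py ... -dm memory    # 'memory' is arbitrary; anything other than 'generator' works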
@@ -43,6 +45,22 @@ def probedata(dataprefix):
         samplecnt += len(data['label'])
     return (cnt,samplecnt)

+def readdata(dataprefix):
+    allfiles = subprocess.check_output('ls '+dataprefix+'*', shell=True).split('\n')[:-1]
+    cnt = 0
+    samplecnt = 0
+    for x in allfiles:
+        if x.split(dataprefix)[1].isdigit():
+            cnt += 1
+            dataall = h5py.File(x,'r')
+            if cnt == 1:
+                label = np.asarray(dataall['label'])
+                data = np.asarray(dataall['data'])
+            else:
+                label = np.vstack((label,dataall['label']))
+                data = np.vstack((data,dataall['data']))
+    return (label,data)
+
 if __name__ == "__main__":

     args = parse_args()
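Aside, not part of the commit: readdata globs <prefix>*, keeps files whose suffix after the prefix is all digits, and stacks each file's 'data' and 'label' datasets along axis 0 with np.vstack. A hypothetical sketch of a batch file it could consume (file name and array shapes are made up):

    import h5py
    import numpy as np

    # readdata() expects files named <prefix>1, <prefix>2, ...
    with h5py.File('mycode.train.h5.batch1', 'w') as f:
        f.create_dataset('data', data=np.zeros((100, 1, 4, 500), dtype='float32'))
        f.create_dataset('label', data=np.zeros((100, 2), dtype='float32'))
    # Every later batch file must match these shapes beyond axis 0 for vstack to work.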
@@ -90,11 +108,18 @@ def probedata(dataprefix):
     model.compile(loss=best_lossfunc, optimizer=best_optim,metrics=['accuracy'])

     checkpointer = ModelCheckpoint(filepath=weight_file, verbose=1, save_best_only=True)
-    trainbatch_num,train_size = probedata(data1prefix+'.train.h5.batch')
-    validbatch_num,valid_size = probedata(data1prefix+'.valid.h5.batch')
-    history_callback = model.fit_generator(mymodel.BatchGenerator2(args.batchsize,trainbatch_num,'train',topdir,data_code)\
-            ,train_size,args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
-            ,nb_val_samples=valid_size,callbacks = [checkpointer])
+    if args.datamode == 'generator':
+        trainbatch_num,train_size = probedata(data1prefix+'.train.h5.batch')
+        validbatch_num,valid_size = probedata(data1prefix+'.valid.h5.batch')
+        history_callback = model.fit_generator(mymodel.BatchGenerator2(args.batchsize,trainbatch_num,'train',topdir,data_code)\
+                ,np.ceil(float(train_size)/args.batchsize),args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
+                ,validation_steps=np.ceil(float(valid_size)/args.batchsize),callbacks = [checkpointer])
+    else:
+        Y_train, traindata = readdata(data1prefix+'.train.h5.batch')
+        Y_valid, validdata = readdata(data1prefix+'.valid.h5.batch')
+        history_callback = model.fit(traindata, Y_train, batch_size=args.batchsize, epochs=args.trainepoch,validation_data=(validdata,Y_valid),callbacks = [checkpointer])
+

     model.save_weights(last_weight_file, overwrite=True)
     system('touch '+join(outdir,model_arch+'.traindone'))
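For orientation (not part of the commit): Keras 1's fit_generator counted samples per epoch, while Keras 2 counts steps, i.e. batches drawn from the generator, which is why train_size and valid_size are now divided by the batch size above. A minimal sketch of the conversion:

    import numpy as np

    train_size, batchsize = 25000, 100                     # hypothetical sizes
    steps_per_epoch = int(np.ceil(float(train_size) / batchsize))
    print(steps_per_epoch)                                 # 250 batches per epoch
    # Keras 1: fit_generator(gen, samples_per_epoch=train_size, nb_epoch=E,
    #                        validation_data=vgen, nb_val_samples=valid_size)
    # Keras 2: fit_generator(gen, steps_per_epoch, E,
    #                        validation_data=vgen, validation_steps=valid_steps)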
@@ -116,8 +141,8 @@ def probedata(dataprefix):
     trainbatch_num,train_size = probedata(data1prefix+'.train.h5.batch')
     validbatch_num,valid_size = probedata(data1prefix+'.valid.h5.batch')
     history_callback = model.fit_generator(mymodel.BatchGenerator2(args.batchsize,trainbatch_num,'train',topdir,data_code)\
-            ,train_size,args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
-            ,nb_val_samples=valid_size,callbacks = [checkpointer])
+            ,np.ceil(float(train_size)/args.batchsize),args.trainepoch,validation_data=mymodel.BatchGenerator2(args.batchsize,validbatch_num,'valid',topdir,data_code)\
+            ,validation_steps=np.ceil(float(valid_size)/args.batchsize),callbacks = [checkpointer])

     model.save_weights(new_last_weight_file, overwrite=True)
     system('touch '+join(outdir,model_arch+'.traindone'))
@@ -149,8 +174,8 @@ def probedata(dataprefix):
     ## Predict on new data
     model = model_from_json(open(architecture_file).read())
     model.load_weights(weight_file)
-    best_optim = cPickle.load(open(optimizer_file,'rb'))
-    model.compile(loss='binary_crossentropy', optimizer=best_optim,metrics=['accuracy'])
+    best_optim, best_lossfunc = cPickle.load(open(optimizer_file,'rb'))
+    model.compile(loss=best_lossfunc, optimizer=best_optim,metrics=['accuracy'])

     predict_batch_num = len([ 1 for x in subprocess.check_output('ls '+args.infile+'*', shell=True).split('\n')[:-1] if args.infile in x if x.split(args.infile)[1].isdigit()])
     print('Total number of batch to predict:',predict_batch_num)
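Aside, not from the commit: the prediction path now unpacks a two-element pickle, so the training side is presumably expected to save the optimizer and the loss-function name together. A hypothetical sketch of the matching save (Python 2, as in this repository):

    import cPickle
    # Hypothetical: what the training code would need to write for the load above.
    cPickle.dump((best_optim, best_lossfunc), open(optimizer_file, 'wb'))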