diff --git a/datagenerator.py b/datagenerator.py
index 05be261..ef61349 100644
--- a/datagenerator.py
+++ b/datagenerator.py
@@ -63,16 +63,14 @@ def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True,
         self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32)

         # create dataset
-        data = Dataset.from_tensor_slices((self.img_paths, self.labels))
+        data = tf.data.Dataset.from_tensor_slices((self.img_paths, self.labels))

         # distinguish between train/infer. when calling the parsing functions
         if mode == 'training':
-            data = data.map(self._parse_function_train, num_threads=8,
-                            output_buffer_size=100*batch_size)
+            data = data.map(self._parse_function_train, num_parallel_calls=8).prefetch(100 * batch_size)

         elif mode == 'inference':
-            data = data.map(self._parse_function_inference, num_threads=8,
-                            output_buffer_size=100*batch_size)
+            data = data.map(self._parse_function_inference, num_parallel_calls=8).prefetch(100 * batch_size)

         else:
             raise ValueError("Invalid mode '%s'." % (mode))
diff --git a/finetune.py b/finetune.py
index fc7d29f..362269d 100644
--- a/finetune.py
+++ b/finetune.py
@@ -27,8 +27,8 @@ class on any given dataset. Specify the configuration settings at the
 """

 # Path to the textfiles for the trainings and validation set
-train_file = '/path/to/train.txt'
-val_file = '/path/to/val.txt'
+train_file = 'train.txt'
+val_file = 'val.txt'

 # Learning params
 learning_rate = 0.01
@@ -44,8 +44,8 @@ class on any given dataset. Specify the configuration settings at the
 display_step = 20

 # Path for tf.summary.FileWriter and to store model checkpoints
-filewriter_path = "/tmp/finetune_alexnet/tensorboard"
-checkpoint_path = "/tmp/finetune_alexnet/checkpoints"
+filewriter_path = "tmp/tensorboard"
+checkpoint_path = "tmp/checkpoints"

 """
 Main Part of the finetuning Script.
@@ -93,8 +93,8 @@ class on any given dataset. Specify the configuration settings at the

 # Op for calculating the loss
 with tf.name_scope("cross_ent"):
-    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=score,
-                                                                  labels=y))
+    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=score,
+                                                                     labels=y))

 # Train op
 with tf.name_scope("train"):
@@ -117,7 +117,6 @@ class on any given dataset. Specify the configuration settings at the
 # Add the loss to summary
 tf.summary.scalar('cross_entropy', loss)

-
 # Evaluation op: Accuracy of the model
 with tf.name_scope("accuracy"):
     correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
@@ -136,12 +135,11 @@ class on any given dataset. Specify the configuration settings at the
 saver = tf.train.Saver()

 # Get the number of training/validation steps per epoch
-train_batches_per_epoch = int(np.floor(tr_data.data_size/batch_size))
+train_batches_per_epoch = int(np.floor(tr_data.data_size / batch_size))
 val_batches_per_epoch = int(np.floor(val_data.data_size / batch_size))

 # Start Tensorflow session
 with tf.Session() as sess:
-
     # Initialize all variables
     sess.run(tf.global_variables_initializer())

@@ -158,7 +156,7 @@ class on any given dataset. Specify the configuration settings at the

     # Loop over number of epochs
     for epoch in range(num_epochs):
-        print("{} Epoch number: {}".format(datetime.now(), epoch+1))
+        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

         # Initialize iterator with the training dataset
         sess.run(training_init_op)
@@ -179,7 +177,7 @@ class on any given dataset. Specify the configuration settings at the
                                                         y: label_batch,
                                                         keep_prob: 1.})
-                writer.add_summary(s, epoch*train_batches_per_epoch + step)
+                writer.add_summary(s, epoch * train_batches_per_epoch + step)

         # Validate the model on the entire validation set
         print("{} Start validation".format(datetime.now()))

@@ -187,7 +185,6 @@ class on any given dataset. Specify the configuration settings at the
         test_acc = 0.
         test_count = 0
         for _ in range(val_batches_per_epoch):
-
             img_batch, label_batch = sess.run(next_batch)
             acc = sess.run(accuracy, feed_dict={x: img_batch,
                                                 y: label_batch,
@@ -201,7 +198,7 @@ class on any given dataset. Specify the configuration settings at the

         # save checkpoint of the model
         checkpoint_name = os.path.join(checkpoint_path,
-                                       'model_epoch'+str(epoch+1)+'.ckpt')
+                                       'model_epoch' + str(epoch + 1) + '.ckpt')
         save_path = saver.save(sess, checkpoint_name)
         print("{} Model checkpoint saved at {}".format(datetime.now(),
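
Note on the datagenerator.py hunk: the num_threads and output_buffer_size arguments to map() belong to the old tf.contrib.data API and were removed when tf.data was promoted to core in TensorFlow 1.4; the replacements are num_parallel_calls on map() plus an explicit prefetch() transformation on the dataset. A minimal self-contained sketch of the new pattern follows; the file names, labels, and _parse_fn below are hypothetical stand-ins for the repo's image list and _parse_function_train, not part of this change:

    import tensorflow as tf

    def _parse_fn(filename, label):
        # Hypothetical stand-in for _parse_function_train:
        # load one JPEG and resize it to the AlexNet input size.
        img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
        img = tf.image.resize_images(img, [227, 227])
        return img, label

    paths = tf.constant(['a.jpg', 'b.jpg'])  # placeholder file names
    labels = tf.constant([0, 1])
    batch_size = 2

    data = tf.data.Dataset.from_tensor_slices((paths, labels))
    # TF 1.4+: parallelism moves to num_parallel_calls; output buffering
    # becomes an explicit prefetch() on the dataset itself.
    data = data.map(_parse_fn, num_parallel_calls=8).prefetch(100 * batch_size)
    data = data.batch(batch_size)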
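Note on the finetune.py loss hunk: softmax_cross_entropy_with_logits is deprecated in favor of the _v2 variant. The behavioral difference is that _v2 also backpropagates into labels; since y here is a feed_dict placeholder rather than a trainable tensor, the swap is a drop-in replacement. A sketch with score and y as in the script, assuming one-hot labels; the tf.stop_gradient call is optional and only reproduces the old op's behavior explicitly:

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=score,
            labels=tf.stop_gradient(y)))  # stop_gradient blocks the label
                                          # gradient, matching the v1 op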