I used a flower dataset (5 classes, 2500 images for training and 500 for validation) to create TFRecord files as input for training, but the loss does not decrease and validation accuracy stays at 20%. Is there a bug in my code when reading the TFRecords?
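Before suspecting the reader, it may be worth checking the records themselves. A minimal sketch (assuming the writer stored 'image' as JPEG-encoded bytes and 'target' as an int64 class index, matching the reader below):

import tensorflow as tf
from collections import Counter

label_counts = Counter()
for record in tf.python_io.tf_record_iterator("/home/coolpad/juzhitao/shufflenet/mg2033/ShuffleNet/train1.tfrecords"):
    example = tf.train.Example.FromString(record)
    label = example.features.feature['target'].int64_list.value[0]
    label_counts[label] += 1
# For a balanced 5-class set of 2500 images, expect roughly 500 per label.
print(label_counts)

If the label histogram is wrong here, the problem is in how the records were written, not in the trainer.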
import tensorflow as tf
from tqdm import tqdm
import numpy as np
from utils import load_obj
import matplotlib.pyplot as plt
class Train:
"""Trainer class for the CNN.
It's also responsible for loading/saving the model checkpoints from/to experiments/experiment_name/checkpoint_dir"""
    def __init__(self, sess, model, data, summarizer):
        self.sess = sess
        self.model = model
        self.args = self.model.args
        self.saver = tf.train.Saver(max_to_keep=self.args.max_to_keep,
                                    keep_checkpoint_every_n_hours=10,
                                    save_relative_paths=True)
        # Summarizer references
        self.data = data
        self.summarizer = summarizer
        # Initializing the model
        self.init = None
        self.__init_model()
        # Loading the model checkpoint if it exists
        self.__load_imagenet_weights()
        self.__load_model()
    IMAGE_SIZE = 224
    NUM_CLASSES = 5
    ############################################################################################################
    # Model related methods
    def __init_model(self):
        print("Initializing the model...")
        self.init = tf.group(tf.global_variables_initializer())
        self.sess.run(self.init)
        print("Model initialized\n\n")
    def save_model(self):
        """
        Save Model Checkpoint
        :return:
        """
        print("Saving a checkpoint")
        self.saver.save(self.sess, self.args.checkpoint_dir, self.model.global_step_tensor)
        print("Checkpoint Saved\n\n")
    def __load_model(self):
        latest_checkpoint = tf.train.latest_checkpoint(self.args.checkpoint_dir)
        if latest_checkpoint:
            print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver.restore(self.sess, latest_checkpoint)
            print("Checkpoint loaded\n\n")
        else:
            print("First time to train!\n\n")
    def __load_imagenet_weights(self):
        # Stub: pretrained-weight loading is not implemented; the variable
        # collection is fetched but never used.
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        print("No pretrained ImageNet weights exist. Skipping...\n\n")
    ############################################################################################################
    # Train and Test methods
    def read_and_decode(self, filename_queue):
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            # Defaults are not specified since both keys are required.
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'target': tf.FixedLenFeature([], tf.int64),
            })
        # Decode the JPEG-encoded string tensor into a uint8 image tensor.
        image = tf.image.decode_jpeg(features['image'], channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
        image = tf.clip_by_value(image, 0.0, 1.0)
        # Convert the label from a scalar int64 tensor to an int32 scalar.
        label = tf.cast(features['target'], tf.int32)
        return image, label
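    # A quick way to verify this decode path is to evaluate a single example
    # and plot it (matplotlib is already imported above), e.g.:
    #   img, lbl = sess.run([image, label])   # after starting queue runners
    #   plt.imshow(img); plt.title(str(lbl)); plt.show()
    # If images or labels look wrong here, the bug is in how the TFRecords
    # were written rather than in the training loop.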
    def train(self):
        # Train data
        filename_queue = tf.train.string_input_producer(["/home/coolpad/juzhitao/shufflenet/mg2033/ShuffleNet/train1.tfrecords"])
        image, label = self.read_and_decode(filename_queue)
        images, labels = tf.train.shuffle_batch([image, label], batch_size=50, num_threads=2, capacity=2500, min_after_dequeue=250)
        # Only initialize local variables here (the queue runners need them).
        # Re-running tf.global_variables_initializer() at this point would
        # overwrite any weights restored from a checkpoint in __load_model().
        self.sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
        for cur_epoch in range(self.model.global_epoch_tensor.eval(self.sess) + 1, self.args.num_epochs + 1, 1):
            num_iterations = self.args.train_data_size // self.args.batch_size
            print("Epoch", cur_epoch, "| num_iterations:", num_iterations,
                  "train_data_size:", self.args.train_data_size,
                  "batch_size:", self.args.batch_size)
            # Initialize the current iteration counter
            cur_iteration = 0
            # Initialize classification accuracy and loss lists
            loss_list = []
            acc_list = []
            # Loop over the number of iterations
            for step in tqdm(range(0, num_iterations), initial=1, total=num_iterations):
                # Get the current iteration for summarizing it
                cur_step = self.model.global_step_tensor.eval(self.sess)
                # Evaluate the batch tensors to get concrete NumPy arrays.
                image_train, label_train = self.sess.run([images, labels])
                # Feed the evaluated arrays to the network. Feeding the batch
                # tensors `images`/`labels` themselves here was a bug:
                # feed_dict values must be arrays, not tf.Tensor objects.
                feed_dict = {self.model.X: image_train,
                             self.model.y: label_train,
                             self.model.is_training: True
                             }
                # Run the forward and backward pass
                _, loss, acc = self.sess.run(
                    [self.model.train_op, self.model.loss, self.model.accuracy],
                    feed_dict=feed_dict)
                # Append loss and accuracy
                loss_list += [loss]
                acc_list += [acc]
                # Update the global step
                self.model.global_step_assign_op.eval(session=self.sess,
                                                      feed_dict={self.model.global_step_input: cur_step + 1})
                if step >= num_iterations - 1:
                    avg_loss = np.mean(loss_list)
                    avg_acc = np.mean(acc_list)
                    # Update the current epoch tensor
                    self.model.global_epoch_assign_op.eval(session=self.sess,
                                                           feed_dict={self.model.global_epoch_input: cur_epoch + 1})
                    # Print in console
                    print("Epoch-" + str(cur_epoch) + " | loss: " + str(avg_loss) + " - acc: " + str(avg_acc)[:7])
                # Update the current iteration
                cur_iteration += 1
            # Save the current checkpoint
            if cur_epoch % self.args.save_model_every == 0 and cur_epoch != 0:
                self.save_model()
            # Test the model on validation or test data
            if cur_epoch % self.args.test_every == 0:
                self.test('val')
        coord.request_stop()
        coord.join(threads)
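    # Note: test() below rebuilds the input pipeline and starts fresh queue
    # runners on every call, so the graph grows with each validation run.
    # A cleaner design would build both pipelines once up front, e.g.
    #   train_images, train_labels = build_pipeline('train1.tfrecords')
    #   val_images, val_labels = build_pipeline('val1.tfrecords')
    # where build_pipeline is a hypothetical helper wrapping read_and_decode
    # and tf.train.shuffle_batch.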
    def test(self, test_type='val'):
        # Validation data
        filename_queue = tf.train.string_input_producer(["/home/coolpad/juzhitao/shufflenet/mg2033/ShuffleNet/val1.tfrecords"])
        image, label = self.read_and_decode(filename_queue)
        images, labels = tf.train.shuffle_batch([image, label], batch_size=50, num_threads=2, capacity=200, min_after_dequeue=50)
        # As in train(), only initialize local variables. Re-running
        # tf.global_variables_initializer() here (as before) reset every
        # trained weight each time test() was called -- a likely reason the
        # validation accuracy stayed at chance level (20% for 5 classes).
        self.sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
        num_iterations = self.args.test_data_size // self.args.batch_size
        # Initialize classification accuracy and loss lists
        loss_list = []
        acc_list = []
        cur_iteration = 0
        for step in tqdm(range(0, num_iterations), initial=1, total=num_iterations):
            # Evaluate the batch tensors to get concrete NumPy arrays.
            image_val, label_val = self.sess.run([images, labels])
            # Feed the evaluated arrays to the network
            feed_dict = {self.model.X: image_val,
                         self.model.y: label_val,
                         self.model.is_training: False
                         }
            # Run the forward pass only
            loss, acc = self.sess.run(
                [self.model.loss, self.model.accuracy],
                feed_dict=feed_dict)
            # Append loss and accuracy
            loss_list += [loss]
            acc_list += [acc]
            if step >= num_iterations - 1:
                avg_loss = np.mean(loss_list)
                avg_acc = np.mean(acc_list)
                print('Test results | test_loss: ' + str(avg_loss) + ' - test_acc: ' + str(avg_acc)[:7])
            cur_iteration += 1
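For reference, the feed_dict mistake in isolation: in TF 1.x, feed_dict values must be concrete arrays, not the batch tensors returned by tf.train.shuffle_batch. A minimal sketch with toy names standing in for the model's placeholders and the queue's batch tensors:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 2], name='X')
batch = tf.constant(np.ones((4, 2), dtype=np.float32))  # stand-in for a shuffle_batch output
total = tf.reduce_sum(x)

with tf.Session() as sess:
    # Wrong: feeding the Tensor object itself raises
    # "TypeError: The value of a feed cannot be a tf.Tensor object."
    # sess.run(total, feed_dict={x: batch})
    # Right: evaluate the batch tensor first, then feed the NumPy array.
    batch_value = sess.run(batch)
    print(sess.run(total, feed_dict={x: batch_value}))  # prints 8.0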