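I've been a bit busy since starting work and haven't had time to update the blog, so this is my first post since then. A small milestone: I've been at the job for almost a month, payday is coming, and the great R is very happy.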
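In today's training session at work I needed to classify images with Inception-ResNet-v2, the residual Inception network introduced together with Inception-v4.
Because the network is very deep, training it from scratch would be unwise, so I downloaded a pre-trained model and fine-tuned it.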
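Paper (Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning): http://arxiv.org/abs/1602.07261
Source code: https://download.csdn.net/download/weixin_41153216/10591023
If you need the pre-trained model, you can download it from GitHub.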
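Fine-tuning here means restoring every pre-trained weight except the classification heads (Logits and AuxLogits), whose shape depends on the number of classes and which are therefore trained from scratch. The training code does this with slim.get_variables_to_restore(exclude=...). The pure-Python sketch below only illustrates that filtering on a handful of made-up variable names (a real checkpoint has many more); the function name variables_to_restore is hypothetical, and startswith is a simplification of slim's own matching, which compares each excluded scope against the beginning of the variable name.

```python
# Scopes excluded from restoration, as in the training code.
EXCLUDE = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']


def variables_to_restore(var_names, exclude):
    """Keep every variable whose name does not start with an excluded scope.

    Hypothetical pure-Python stand-in for slim.get_variables_to_restore(exclude=...).
    """
    return [name for name in var_names
            if not any(name.startswith(scope) for scope in exclude)]


# Illustrative variable names only; not read from a real checkpoint.
names = [
    'InceptionResnetV2/Conv2d_1a_3x3/weights',
    'InceptionResnetV2/Mixed_5b/Branch_0/Conv2d_1x1/weights',
    'InceptionResnetV2/AuxLogits/Logits/weights',
    'InceptionResnetV2/Logits/Logits/weights',
    'InceptionResnetV2/Logits/Logits/biases',
]
restored = variables_to_restore(names, EXCLUDE)
for name in restored:
    print(name)  # only the two feature-extractor variables are printed
```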
The main program is as follows:
import os
import time

import numpy as np
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and tf.train.Supervisor)
from tensorflow.python.platform import tf_logging as logging
# Assumed module: the Inception-ResNet-v2 definition from the TF-slim models repository, saved next to this script
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope

slim = tf.contrib.slim

# Not shown in this post (they live in the author's source): the helpers SaveFile, TrTsSet, inputs, ANB
# and the settings log_dir, checkpoint_file, num_classes, initial_learning_rate, decay_steps,
# learning_rate_decay_factor, batch_size, num_steps, display_step.

def main():
    ImageInform = SaveFile()
    #print(ImageInform[0])
    Train_Set, Test_Set = TrTsSet(ImageInform)
    Train_Num = len(Train_Set)
    steps_per_epoch = Train_Num // batch_size  # full batches per epoch (leftover samples are skipped)
    # Create the log directory here. Must be done here, otherwise importing this module would create it unnecessarily.
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    # ------- Training Process --------
    with tf.Graph().as_default() as graph:
        tf.logging.set_verbosity(tf.logging.INFO)  # Set the verbosity to INFO level
        x, y_ = inputs()
        # Create the model inference
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2(x, num_classes=num_classes, is_training=True)
        # Scopes to exclude from restoration: the classification heads depend on num_classes, so they are trained from scratch
        exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
        # Equivalent to tf.nn.softmax_cross_entropy_with_logits on one-hot labels, but also adds the loss to the losses collection
        loss = tf.losses.softmax_cross_entropy(onehot_labels=y_, logits=logits)
        total_loss = tf.losses.get_total_loss()  # cross-entropy plus the regularization losses
        # Create the global step for monitoring the learning rate and training
        global_step = tf.train.get_or_create_global_step()
        lr = tf.train.exponential_decay(learning_rate=initial_learning_rate,
                                        global_step=global_step, decay_steps=decay_steps,
                                        decay_rate=learning_rate_decay_factor, staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        # Create the train_op; evaluating it runs one update and returns the total loss
        train_op = slim.learning.create_train_op(total_loss, optimizer)
        # Metrics: argmax turns the softmax output (and the one-hot labels) into integer class ids
        predictions = tf.argmax(end_points['Predictions'], 1)
        probabilities = end_points['Predictions']
        real_label = tf.argmax(y_, 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, real_label)
        metrics_op = tf.group(accuracy_update, probabilities)
        # Create all the summaries to monitor and group them into one summary op
        tf.summary.scalar('losses/Total_Loss', total_loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('learning_rate', lr)
        my_summary_op = tf.summary.merge_all()

        # Training step function: runs train_op and metrics_op and reads the global_step in one session call
        def train_step(sess, train_op, global_step, batch_x, batch_y):
            '''
            Runs one training step on the given batch and logs the time elapsed for this global step
            '''
            start_time = time.time()
            total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op],
                                                        feed_dict={x: batch_x, y_: batch_y})
            time_elapsed = time.time() - start_time
            logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
            return total_loss, global_step_count

        # Saver that restores the pre-trained weights (everything except the excluded heads) from checkpoint_file
        saver = tf.train.Saver(variables_to_restore)
        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # Supervisor for a managed session. Do not run the summary_op automatically, or it will consume too much memory.
        # init_fn is only called when log_dir holds no checkpoint; otherwise the Supervisor resumes from the latest one.
        sv = tf.train.Supervisor(logdir=log_dir, summary_op=None, init_fn=restore_fn)
        # Run the managed session
        with sv.managed_session() as sess:
            for step in range(num_steps):
                # Reshuffle the training indices at the start of every epoch, then take consecutive slices
                position = step % steps_per_epoch
                if position == 0:
                    permutation = np.random.permutation(Train_Num)
                Batch = permutation[position * batch_size:(position + 1) * batch_size]
                #import pdb; pdb.set_trace()
                batch_x, batch_y = ANB(Batch, Train_Set)
                # Every display_step steps, log the vital information
                if step % display_step == 0:
                    logging.info('Steps: %s', step)
                    learning_rate_value, accuracy_value = sess.run([lr, accuracy], feed_dict={x: batch_x, y_: batch_y})
                    #logging.info('Current Learning Rate: %s', learning_rate_value)
                    logging.info('Current Streaming Accuracy: %s', accuracy_value)
                    # Optionally, print the logits and predictions as a sanity check
                    logits_value, probabilities_value, predictions_value, labels_value = \
                        sess.run([logits, probabilities, predictions, real_label], feed_dict={x: batch_x, y_: batch_y})
                    #print('logits:', logits_value)
                    #print('Probabilities:', probabilities_value)
                    #print('predictions:', predictions_value)
                    #print('Labels:', labels_value)
                    loss, _ = train_step(sess, train_op, sv.global_step, batch_x, batch_y)
                    summaries = sess.run(my_summary_op, feed_dict={x: batch_x, y_: batch_y})
                    sv.summary_computed(sess, summaries)
                # Otherwise, simply run the training step
                else:
                    loss, _ = train_step(sess, train_op, sv.global_step, batch_x, batch_y)
                # Test: save a checkpoint every few steps (disabled)
                # if step % 1 == 0:
                #     sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

            # Log the final training loss and accuracy
            logging.info('Final Loss: %s', loss)
            logging.info('Final Accuracy: %s', sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y}))
            # Once training is done, save the checkpoint with the Supervisor's saver, which covers all variables;
            # `saver` above only holds variables_to_restore and would leave out the new Logits/AuxLogits layers.
            logging.info('Finished training! Saving model to disk now.')
            sv.saver.save(sess, "./sc15_model.ckpt")
            #sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
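The batching in the main loop reshuffles the training indices once per epoch and then slices consecutive batches out of that permutation. Here is that idea in isolation as a minimal NumPy sketch; the generator name epoch_batches and the toy sizes are illustrative, and, as an assumption carried over from the fixed-size slicing, only full batches are used so the last train_num % batch_size samples of each epoch are skipped.

```python
import numpy as np


def epoch_batches(train_num, batch_size, num_steps, seed=0):
    """Yield (step, indices) for num_steps steps, reshuffling at the start of every epoch.

    Only full batches are used; train_num % batch_size samples are skipped each epoch.
    """
    rng = np.random.default_rng(seed)
    steps_per_epoch = train_num // batch_size
    permutation = None
    for step in range(num_steps):
        position = step % steps_per_epoch
        if position == 0:
            permutation = rng.permutation(train_num)  # new random order for this epoch
        yield step, permutation[position * batch_size:(position + 1) * batch_size]


# Toy example: 10 samples, batches of 3 -> 3 full batches per epoch, so step 3 starts a new epoch.
for step, indices in epoch_batches(train_num=10, batch_size=3, num_steps=4):
    print(step, indices)
```

Deriving steps_per_epoch from the data size, rather than hard-coding it, keeps the reshuffle aligned with epoch boundaries when the dataset or batch size changes.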
The experimental results are as follows:
The accuracy in the figure is measured on the training set.
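One detail worth keeping in mind when reading that number: the accuracy logged by the script comes from tf.contrib.metrics.streaming_accuracy, which keeps running totals (a count of correct predictions and a count of samples) and reports their ratio. If I read the code correctly, the logged value is therefore cumulative over every training batch since the session started, not the accuracy of the latest batch. The NumPy sketch below mimics that bookkeeping; the class name StreamingAccuracy is hypothetical.

```python
import numpy as np


class StreamingAccuracy:
    """Running accuracy over every batch seen so far (total correct / total count)."""

    def __init__(self):
        self.total = 0  # number of correct predictions so far
        self.count = 0  # number of samples so far

    def update(self, predictions, labels):
        predictions = np.asarray(predictions)
        labels = np.asarray(labels)
        self.total += int(np.sum(predictions == labels))
        self.count += labels.size
        return self.result()

    def result(self):
        return self.total / self.count if self.count else 0.0


acc = StreamingAccuracy()
print(acc.update([0, 1, 1, 2], [0, 1, 2, 2]))  # 3 of 4 correct -> 0.75
print(acc.update([1, 1], [1, 1]))              # (3 + 2) / (4 + 2), about 0.833
```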
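Notes from debugging:
As the figure shows, after 115 training steps (batches) the accuracy reaches 90.86%, but the loss is no longer changing, so it may be worth changing the initial learning rate.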
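Before tuning the initial learning rate, it can help to check what the schedule actually produces. With staircase=True, tf.train.exponential_decay computes initial_learning_rate * decay_rate ** floor(global_step / decay_steps), so the rate drops in discrete steps. The sketch below reproduces that formula in plain Python; the function name and the example settings (1e-3, decay every 100 steps, factor 0.5) are hypothetical, not the values used in this post.

```python
import math


def staircase_exponential_decay(initial_lr, global_step, decay_steps, decay_rate):
    """Learning rate as computed by tf.train.exponential_decay(..., staircase=True)."""
    return initial_lr * decay_rate ** math.floor(global_step / decay_steps)


# Hypothetical settings, not the values used in this post.
for step in (0, 99, 100, 250, 1000):
    print(step, staircase_exponential_decay(1e-3, step, decay_steps=100, decay_rate=0.5))
```

If decay_steps is small relative to the run, the rate can shrink by orders of magnitude (here to under 1e-6 by step 1000), which is one possible reason for a flat loss worth ruling out alongside the choice of initial rate.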
Each training step takes about 30 s, so a full training run takes quite a long time.
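A quick back-of-the-envelope estimate shows how fast 30 s per step adds up, assuming the per-step time stays constant: 115 steps × 30 s = 3450 s, about 57.5 minutes, and a longer run of 10,000 steps (a hypothetical figure) would take roughly 83 hours. The helper name below is illustrative.

```python
def estimate_training_time(num_steps, seconds_per_step=30.0):
    """Return (seconds, minutes, hours) for num_steps steps at a fixed cost per step."""
    seconds = num_steps * seconds_per_step
    return seconds, seconds / 60, seconds / 3600


seconds, minutes, hours = estimate_training_time(115)
print(f"115 steps: {seconds:.0f} s = {minutes:.1f} min")   # 115 steps: 3450 s = 57.5 min

# A hypothetical longer run of 10,000 steps at the same speed:
_, _, long_hours = estimate_training_time(10_000)
print(f"10,000 steps: about {long_hours:.0f} h")           # 10,000 steps: about 83 h
```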
Possible improvement: run the training on a server instead.
How do I resume training from a checkpoint?
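The Supervisor saves the model automatically every 10 minutes, writing 4 files each time. The checkpoint file is overwritten on every save and records which save is the newest; the other three files (.data, .index and .meta) contain the model itself. Each later run then resumes training from the last checkpoint.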
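The checkpoint file is just a small text index naming the newest checkpoint prefix, and the other three files share that prefix. In TensorFlow this lookup is tf.train.latest_checkpoint(log_dir), and the Supervisor performs it on its own at start-up, calling init_fn (which loads the pre-trained ImageNet weights) only when no checkpoint is found. The pure-Python sketch below mimics the lookup to show what "resuming" reads; the function name latest_checkpoint_prefix is hypothetical, and it assumes the usual text format written by tf.train.Saver (model_checkpoint_path: "...").

```python
import os
import re
import tempfile


def latest_checkpoint_prefix(checkpoint_dir):
    """Return the newest checkpoint prefix listed in <checkpoint_dir>/checkpoint, or None.

    Roughly what tf.train.latest_checkpoint(checkpoint_dir) does, assuming the usual text format.
    """
    index_file = os.path.join(checkpoint_dir, 'checkpoint')
    if not os.path.exists(index_file):
        return None  # no checkpoint yet: training would start from the pre-trained weights
    with open(index_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                prefix = match.group(1)
                # Relative prefixes are resolved against the checkpoint directory
                return prefix if os.path.isabs(prefix) else os.path.join(checkpoint_dir, prefix)
    return None


# Demo with a hand-written checkpoint index file (the step numbers are made up).
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, 'checkpoint'), 'w') as f:
    f.write('model_checkpoint_path: "model.ckpt-1200"\n'
            'all_model_checkpoint_paths: "model.ckpt-600"\n'
            'all_model_checkpoint_paths: "model.ckpt-1200"\n')
print(latest_checkpoint_prefix(demo_dir))  # <demo_dir>/model.ckpt-1200
```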
For saving and loading models, see "Saving and reading network models": https://blog.csdn.net/lwplwf/article/details/62419087