Inception-ResNet-V2 Pre-train Summary

I've been a bit busy since starting work and haven't had time to update the blog, so this is my first post since starting the job. Marking the occasion for myself: it's been almost a month, payday is almost here, and R神 is very happy.

During a work training task today, I needed to use Inception-ResNet-V2 to classify images.

Since Inception-ResNet-V2 is a very deep network, training it from scratch would be unwise, so I downloaded a pre-trained model.

Paper: http://arxiv.org/abs/1602.07261

Source code: https://download.csdn.net/download/weixin_41153216/10591023

If you need the pre-trained model, you can download it from GitHub (the TF-Slim model zoo).
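For reference, here is a rough sketch of fetching and unpacking the checkpoint with plain Python; the URL is the one listed in the TF-Slim model zoo README, so verify it before relying on it:

import os
import tarfile
import urllib.request

#Sketch only: download and unpack the Inception-ResNet-V2 checkpoint published by TF-Slim.
#The URL is an assumption taken from the TF-Slim model zoo README.
CKPT_URL = 'http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz'
CKPT_DIR = './pretrained'

if not os.path.exists(CKPT_DIR):
    os.makedirs(CKPT_DIR)
tar_path = os.path.join(CKPT_DIR, os.path.basename(CKPT_URL))
if not os.path.exists(tar_path):
    urllib.request.urlretrieve(CKPT_URL, tar_path)
with tarfile.open(tar_path, 'r:gz') as tar:
    tar.extractall(CKPT_DIR)   #yields inception_resnet_v2_2016_08_30.ckpt
#checkpoint_file in the training script below would then point at that .ckpt file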

The main program is as follows:

#Imports needed by main(); the data-preparation helpers (SaveFile, TrTsSet, ANB, inputs)
#and the hyperparameters (log_dir, checkpoint_file, num_classes, batch_size, num_steps,
#display_step, initial_learning_rate, decay_steps, learning_rate_decay_factor) are
#defined elsewhere in the full script.
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging
#Assumes the TF-Slim model definition file inception_resnet_v2.py is in the working directory
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
slim = tf.contrib.slim

def main():
    #Gather image paths/labels and split them into training and test sets
    #(SaveFile, TrTsSet and ANB are the author's data-preparation helpers)
    ImageInform = SaveFile()
    #print(ImageInform[0])
    Train_Set, Test_Set = TrTsSet(ImageInform)
    Train_Num = len(Train_Set)
    #Create the log directory here. Must be done here otherwise import will activate this unneededly.
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    # ------- Training Process --------
    with tf.Graph().as_default() as graph:
        tf.logging.set_verbosity(tf.logging.INFO) #Set the verbosity to INFO level
        x, y_ = inputs()
        #Create the model inference
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2(x, num_classes = num_classes, is_training = True)

        #Define the scopes that you want to exclude for restoration
        exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
        variables_to_restore = slim.get_variables_to_restore(exclude = exclude)

        #tf.losses.softmax_cross_entropy expects one-hot labels and also adds this loss to the losses collection
        loss = tf.losses.softmax_cross_entropy(onehot_labels = y_, logits = logits)
        total_loss = tf.losses.get_total_loss() #obtain the regularization losses as well

        #Create the global step for monitoring the learning_rate and training.
        global_step = get_or_create_global_step()

        lr = tf.train.exponential_decay(learning_rate = initial_learning_rate,
            global_step = global_step, decay_steps = decay_steps,
            decay_rate = learning_rate_decay_factor, staircase = True)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        #Create the train_op.
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        #State the metrics that you want to track. The predictions here are class indices, not one-hot encoded.
        predictions = tf.argmax(end_points['Predictions'], 1)
        probabilities = end_points['Predictions']
        real_label = tf.argmax(y_, 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, real_label)
        metrics_op = tf.group(accuracy_update, probabilities)

        #Now finally create all the summaries you need to monitor and group them into one summary op.
        tf.summary.scalar('losses/Total_Loss', total_loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('learning_rate', lr)
        my_summary_op = tf.summary.merge_all()

        #Now create a training step function that runs the train_op and metrics_op and updates the global_step concurrently.
        def train_step(sess, train_op, global_step,batch_x,batch_y):
            '''
            Simply runs a session for the three arguments provided and gives a logging on the time elapsed for each global step
            '''
            #Check the time for each sess run
            start_time = time.time()
            total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op],feed_dict={x:batch_x,y_:batch_y})
            time_elapsed = time.time() - start_time

            #Run the logging to print some results
            logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)

            return total_loss, global_step_count

        #Now we create a saver function that actually restores the variables from a checkpoint file in a sess
        saver = tf.train.Saver(variables_to_restore)
        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir = log_dir, summary_op = None, init_fn = restore_fn)
        #print('I have done')
        #Run the managed session
        with sv.managed_session() as sess:
            for step in range(num_steps):
                #Reshuffle the training indices at the start of every epoch
                if ((step*batch_size)%Train_Num == 0):
                    permutation = np.random.permutation(Train_Num)
                #NOTE: the hard-coded 15 should equal Train_Num // batch_size (batches per epoch)
                Batch = permutation[(step%15)*batch_size:((step%15)+1)*batch_size]
                #import pdb; pdb.set_trace()
                batch_x, batch_y = ANB(Batch,Train_Set)
                #At the start of every epoch, show the vital information:
                if step % display_step == 0:
                    logging.info('Steps: %s', step)
                    learning_rate_value, accuracy_value = sess.run([lr, accuracy],feed_dict={x:batch_x,y_:batch_y})
                    #logging.info('Current Learning Rate: %s', learning_rate_value)
                    logging.info('Current Streaming Accuracy: %s', accuracy_value)

                    # optionally, print your logits and predictions for a sanity check that things are going fine.
                    logits_value, probabilities_value, predictions_value, labels_value = \
                    sess.run([logits, probabilities, predictions, real_label],feed_dict={x:batch_x,y_:batch_y})
                    #print('logits:', logits_value)
                    #print('Probabilities:', probabilities_value)
                    #print('predictions:', predictions_value)
                    #print('Labels:', labels_value)

                    loss, _ = train_step(sess, train_op, sv.global_step,batch_x,batch_y)
                    summaries = sess.run(my_summary_op,feed_dict={x:batch_x,y_:batch_y})
                    sv.summary_computed(sess, summaries)

                #If not, simply run the training step
                else:
                    loss, _ = train_step(sess, train_op, sv.global_step,batch_x,batch_y)

                #For testing: save a checkpoint every few steps
                '''
                if (step%1) == 0:
                    sv.saver.save(sess, sv.save_path, global_step = sv.global_step)
                '''
            #We log the final training loss and accuracy
            logging.info('Final Loss: %s', loss)
            logging.info('Final Accuracy: %s', sess.run(accuracy,feed_dict={x:batch_x,y_:batch_y}))

            #Once all the training has been done, save the log files and checkpoint model
            logging.info('Finished training! Saving model to disk now.')
            saver.save(sess, "./sc15_model.ckpt")
            #sv.saver.save(sess, sv.save_path, global_step = sv.global_step)
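
The script presumably just calls main() when it is run directly, e.g.:

if __name__ == '__main__':
    main()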

The experimental results are shown below:

The accuracy shown in the figure is measured on the training set.
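The script never evaluates on Test_Set. As a rough sketch, test accuracy could be measured inside the same managed session; the ANB signature below is assumed to match the training loop, and note that the graph was built with is_training=True, so this would only be a sanity check:

import numpy as np

#Sketch only: accuracy over the held-out Test_Set, run inside the managed session from main().
#Assumes ANB(indices, dataset) returns (batch_x, batch_y) exactly as in the training loop.
def test_accuracy(sess, Test_Set, ANB, x, y_, predictions, real_label, batch_size):
    test_num = len(Test_Set)
    correct, seen = 0, 0
    for start in range(0, test_num - batch_size + 1, batch_size):
        batch_idx = np.arange(start, start + batch_size)
        batch_x, batch_y = ANB(batch_idx, Test_Set)
        preds, labels = sess.run([predictions, real_label],
                                 feed_dict={x: batch_x, y_: batch_y})
        correct += int(np.sum(preds == labels))
        seen += batch_size
    return correct / float(seen) if seen else 0.0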

Notes from the debugging process:

As the figure shows, the training accuracy reaches 90.86% after 115 batches. The loss, however, has stopped decreasing, so it may be worth changing the initial learning rate.
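For reference, with staircase=True the schedule in the code decays the learning rate as initial_learning_rate * learning_rate_decay_factor ** (global_step // decay_steps), so decay_steps and the decay factor can be tuned alongside the initial rate. A plain-Python view of the schedule (the three values below are placeholders, not the ones used in this post):

#Plain-Python equivalent of tf.train.exponential_decay with staircase=True.
#The hyperparameter values are placeholders for illustration only.
initial_learning_rate = 0.0002
learning_rate_decay_factor = 0.7
decay_steps = 100

def decayed_lr(global_step):
    return initial_learning_rate * learning_rate_decay_factor ** (global_step // decay_steps)

for step in (0, 100, 200, 500):
    print(step, decayed_lr(step))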

Each training step takes about 30 s, so the whole training run is fairly slow.

A possible improvement: run the training on the server instead.

How do we resume training from a checkpoint?

The model is saved automatically: every 10 minutes the Supervisor writes four files. The checkpoint file is overwritten each time, while the other three files hold the model itself; the next run then starts from the most recent checkpoint (a minimal sketch of this follows below).

For saving and loading models, see: https://blog.csdn.net/lwplwf/article/details/62419087 (saving and restoring network models).
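A minimal sketch of resuming from the newest checkpoint in log_dir; tf.train.Supervisor already does this automatically when logdir contains a checkpoint, so this only spells out what happens under the hood:

import tensorflow as tf

#Sketch only: restore the most recent checkpoint written to log_dir and continue from there.
#Assumes the same graph as in main() has already been built and log_dir is the training log directory.
ckpt_path = tf.train.latest_checkpoint(log_dir)   #reads the 'checkpoint' index file
resume_saver = tf.train.Saver()                   #covers all variables, including the global step

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if ckpt_path is not None:
        resume_saver.restore(sess, ckpt_path)     #weights and step counter pick up where training stopped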
