18.6 使用官方slim包中的預訓練模型進行finetune(微調)

本文接着博客18.1,請先參考https://blog.csdn.net/u010397980/article/details/84930880

如果嫌自己寫網絡結構麻煩,可以直接調用tensorflow的slim包中定義好的模型結構。從tensorflow的github下載slim包,該包中包含很多常用模型的定義,官方也提供了這些模型在ImageNet上預訓練的權重。詳見https://github.com/tensorflow/models/tree/master/research/slim#Pretrained

在這裏下載slim,別忘了解壓:鏈接: https://pan.baidu.com/s/1rkMcl4bYimFJAQoyZ3z1WA 提取碼: ebbi

或者:git clone https://github.com/tensorflow/models/

不用按照官方說明安裝slim——無需安裝!下面的代碼只是通過sys.path.append把slim目錄加入Python的搜索路徑,然後直接導入其中的網絡定義。

我們這裏只是使用官方slim包中的網絡定義代碼(mobilenetv2,當然也可以用其他網絡),不用他們的slim直接進行訓練。

新建train2.py文件內容如下,注意修改sys.path.append("xxxxxx/models/research/slim")爲自己下好slim包的路徑。

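MobileNetV2預訓練的權重mobilenet_v2_1.0_224.ckpt可以在這裏下載,別忘了解壓:https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz 。注意不要下載mobilenet_v1_1.0_224.tgz(http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz),那是MobileNetV1的權重,變量名與代碼中的mobilenet_v2網絡不匹配;由於加載時設置了ignore_missing_vars=True,用錯權重不會報錯,而是悄悄地什麼都不加載。解壓後把mobilenet_v2_1.0_224.ckpt.*文件放在運行目錄下(代碼中finetune_model = 'mobilenet_v2_1.0_224.ckpt')。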

#coding:utf-8
import os, sys
import numpy as np
import tensorflow as tf
import glob
import tensorflow.contrib.slim as slim

# import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
# from tensorflow.contrib.slim.python.slim.nets import resnet_v2

sys.path.append("/home/ming/models/research/slim")
from nets.mobilenet import mobilenet_v2


def get_files(file_dir, label_map=None):
    """Collect ``*.jpg`` image paths under ``file_dir`` with their integer labels.

    Each immediate sub-directory of ``file_dir`` is treated as one class. Its
    name is looked up in ``label_map``, and the mapped value is converted with
    ``int()``. Plain files directly inside ``file_dir`` are ignored.

    Args:
        file_dir: root directory containing one sub-directory per class.
        label_map: optional mapping from sub-directory name to a label
            (anything ``int()`` accepts). Defaults to the module-level
            ``label_dict``.

    Returns:
        ``(image_list, label_list)``: two parallel lists that are shuffled
        together (via ``np.random``), so image i still has label i.

    Raises:
        KeyError: if a sub-directory has no entry in ``label_map``.
    """
    if label_map is None:
        label_map = label_dict
    image_list, label_list = [], []
    for label in os.listdir(file_dir):
        if os.path.isfile(os.path.join(file_dir, label)):
            continue
        for img in glob.glob(os.path.join(file_dir, label, "*.jpg")):
            image_list.append(img)
            label_list.append(int(label_map[label]))
    print('There are %d data' % len(image_list))
    # Shuffle through a permutation of indices. Packing paths and labels into
    # one numpy array gives it a string dtype, which turns every label into a
    # str and needs an int() round-trip afterwards.
    order = np.random.permutation(len(image_list))
    image_list = [image_list[i] for i in order]
    label_list = [label_list[i] for i in order]
    return image_list, label_list

label_dict, label_dict_res = {}, {}
# Manually specified mapping from class folder name to label (and its
# inverse), read from label.txt. Each non-blank line looks like
# "folder:label". Labels stay strings here; get_files applies int() to them.
with open("label.txt", 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing with IndexError on the missing ':' field.
            continue
        fields = line.split(':')
        folder, label = fields[0], fields[1]
        label_dict[folder] = label
        label_dict_res[label] = folder
print(label_dict)


finetune_model = 'mobilenet_v2_1.0_224.ckpt'  # checkpoint prefix of the pre-trained weights
checkpoint_not_load_scope = 'MobilenetV2/Logits'  # not load fc layer (comma-separated scope prefixes)
trainable_scope = 'MobilenetV2/Logits'  # train fc layer when finetune (used only if freeze_basemodel)

train_dir = "/home/ming/data/yourdatapath"
logs_train_dir = './model_save'
init_lr = 0.01
weight_decay = 0.0001
BATCH_SIZE = 128
freeze_basemodel = False  # True: only train the variables under trainable_scope
train, train_label = get_files(train_dir)
# Floor division keeps the step counts integral on Python 3 as well
# (range(MAX_STEP) needs an int). Keep at least one step per epoch: with fewer
# images than BATCH_SIZE the quotient is 0, which would make decay_steps 0 and
# the epoch counter in main() divide by zero.
one_epoch_step = max(1, len(train) // BATCH_SIZE)
decay_steps = int(30 * one_epoch_step)
MAX_STEP = 100 * one_epoch_step
N_CLASSES = len(label_dict)
IMG_W = 224
IMG_H = 224
CAPACITY = 1000 + 3 * BATCH_SIZE
display_step = 100
batch_norm_params = {
        # Decay for the moving averages.
        'decay': 0.997,
        # epsilon to prevent 0s in variance.
        'epsilon': 0.001,
        # force in-place updates of mean and variance estimates
        'updates_collections': None,
        # Moving averages ends up in the trainable variables collection.
        # NOTE(review): they presumably receive no gradient, so the optimizer
        # skips them; confirm this placement is intended.
        'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ],
    }
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # index of the GPU to use
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand instead of all at once


def get_batch(image, label, image_W, image_H, batch_size, capacity):
    """Build a queue-based input pipeline that yields augmented image batches.

    Args:
        image: list of JPEG file paths.
        label: list of integer labels, parallel to ``image``.
        image_W: output width in pixels.
        image_H: output height in pixels.
        batch_size: number of examples per batch.
        capacity: maximum number of elements held in the batching queue.

    Returns:
        ``(image_batch, label_batch)``: a float32 tensor of shape
        ``[batch_size, image_H, image_W, 3]`` and an int32 tensor of shape
        ``[batch_size]``.
    """
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)
    # Make an input queue. shuffle=False: the caller is expected to pass an
    # already-shuffled list (get_files shuffles its output).
    input_queue = tf.train.slice_input_producer([image, label], shuffle=False)
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)

    # ---- Data augmentation ----
    #image = tf.image.resize_image_with_pad(image, target_height=image_H, target_width=image_W)
    # resize_images takes size=(new_height, new_width).
    image = tf.image.resize_images(image, (image_H, image_W))
    # Random transpose, decided per image inside the graph (together with the
    # random flips below this covers 90-degree rotations). A Python-side random
    # draw would run only once while the graph is built, so either every image
    # or no image would be transposed. Transposing swaps height and width, so
    # it is only applied to square outputs; this keeps the static shape that
    # tf.train.batch requires.
    if image_W == image_H:
        image = tf.cond(tf.random_uniform([]) < 0.5,
                        lambda: tf.image.transpose_image(image),
                        lambda: image)
    # Random left-right flip.
    image = tf.image.random_flip_left_right(image)
    # Random up-down flip.
    image = tf.image.random_flip_up_down(image)
    # Random brightness. NOTE(review): per_image_standardization below
    # subtracts the per-image mean, which cancels a uniform brightness shift,
    # so this step does not change the final tensor.
    image = tf.image.random_brightness(image, max_delta=32/255.0)
    # Random contrast.
    image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    # Optional: random hue / saturation (disabled).
    #image = tf.image.random_hue(image, max_delta=0.05)
    #image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    # Standardize each image to zero mean and unit variance.
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=capacity)
    tf.summary.image("input_img", image_batch, max_outputs=5)
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)
    return image_batch, label_batch


def get_finetuned_variables():
    """Return the model variables to restore from the pre-trained checkpoint.

    Every variable from ``slim.get_model_variables()`` is kept unless its op
    name starts with one of the comma-separated scope prefixes in
    ``checkpoint_not_load_scope``. An example is the classifier layer, which
    is rebuilt for N_CLASSES outputs.

    Empty entries in that setting are ignored. Otherwise an empty prefix would
    match, and therefore exclude, every variable.
    """
    exclusions = tuple(
        prefix for prefix in (s.strip() for s in checkpoint_not_load_scope.split(','))
        if prefix)
    # Walk every model variable and drop those under an excluded scope.
    # str.startswith(()) is False, so no exclusions means restore everything.
    return [var for var in slim.get_model_variables()
            if not var.op.name.startswith(exclusions)]


def get_trainable_variables():
    """Gather the trainable variables that live under ``trainable_scope``.

    ``trainable_scope`` is a comma-separated list of scope prefixes. For each
    prefix, in order, the matching members of the TRAINABLE_VARIABLES
    collection are appended to the result. Duplicates are not removed.
    """
    to_train = []
    for prefix in (s.strip() for s in trainable_scope.split(',')):
        to_train += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, prefix)
    return to_train


def main():
    """Build the MobileNetV2 fine-tuning graph and run the training loop.

    Uses the module-level configuration (data, BATCH_SIZE, learning-rate
    schedule, freeze_basemodel, ...).

    - Restores pre-trained weights from ``finetune_model``, except the scopes
      listed in ``checkpoint_not_load_scope``.
    - Trains for MAX_STEP steps.
    - Every ``display_step`` steps, prints loss/accuracy and writes summaries
      to ``logs_train_dir``.
    - Every 500 steps, and at the last step, saves a checkpoint there.
    """
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Labels are integer class ids (not one-hot).
    batch_train, batch_labels = get_batch(train,
                                          train_label,
                                          IMG_W,
                                          IMG_H,
                                          BATCH_SIZE,
                                          CAPACITY)

    # Network. Set is_training=False when predicting images. Other slim nets
    # (e.g. inception_v3, resnet_v2) can be swapped in here.
    # Alternative arg scope: slim.arg_scope(mobilenet_v2.training_scope()).
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params,
                        weights_regularizer=slim.l2_regularizer(weight_decay)):
        logits, _ = mobilenet_v2.mobilenet(batch_train, num_classes=N_CLASSES, is_training=True)
    # Single-argument print() behaves the same on Python 2 and 3.
    print(logits.get_shape())

    # Loss: mean cross-entropy plus the L2 weight-decay terms.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=batch_labels)
    loss = tf.reduce_mean(cross_entropy, name='loss')
    regularization_losses_n = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    loss = tf.add_n([loss] + regularization_losses_n, name='total_loss')
    tf.summary.scalar('train_loss', loss)

    # Learning rate: multiplied by 0.1 every decay_steps. The decay is smooth
    # rather than stepwise because staircase is left at its default (False).
    lr = tf.train.exponential_decay(learning_rate=init_lr, global_step=global_step,
                                    decay_steps=decay_steps, decay_rate=0.1)
    tf.summary.scalar('learning_rate', lr)

    # Optimizer. Any ops in UPDATE_OPS run before each step. batch_norm_params
    # sets updates_collections=None, so batch-norm statistics are updated in
    # place rather than through this collection.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        if freeze_basemodel:
            trainable_variable = get_trainable_variables()
            for var in trainable_variable:
                print("only train variable: %s" % var)
            optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(
                loss, global_step=global_step, var_list=trainable_variable)
        else:
            print("train all variable")
            optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(
                loss, global_step=global_step)  # train all variables

    # Top-1 accuracy on the training batch. float32 avoids needless precision
    # loss compared with float16.
    correct = tf.nn.in_top_k(logits, batch_labels, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    tf.summary.scalar('train_acc', accuracy)

    summary_op = tf.summary.merge_all()
    sess = tf.Session(config=config)
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)

    # Restore the pre-trained weights AFTER the global initializer so they are
    # not overwritten. Variables missing from the checkpoint are skipped.
    load_finetune_model = slim.assign_from_checkpoint_fn(finetune_model, get_finetuned_variables(),
                                                         ignore_missing_vars=True)
    saver = tf.train.Saver(max_to_keep=100)
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # To resume training instead: saver.restore(sess, logs_train_dir+'/model.ckpt-174000')
    print('Loading finetune model from %s' % finetune_model)
    load_finetune_model(sess)

    checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            if step % display_step == 0:
                # Fetch the summaries in the same run as the train op. A
                # separate sess.run(summary_op) would dequeue, and run a
                # forward pass on, a different batch than the one logged here.
                _, learning_rate, tra_loss, tra_acc, summary_str = sess.run(
                    [optimizer, lr, loss, accuracy, summary_op])
                print('Epoch:%3d/%d, Step:%6d/%d, lr:%f, train loss:%.4f, train acc:%.2f%%'
                      % (step // one_epoch_step + 1, MAX_STEP // one_epoch_step,
                         step, MAX_STEP, learning_rate, tra_loss, tra_acc * 100.0))
                train_writer.add_summary(summary_str, step)
            else:
                sess.run(optimizer)
            if step % 500 == 0 or (step + 1) == MAX_STEP:
                saver.save(sess, checkpoint_path, global_step=step)
                print('save model %s' % checkpoint_path)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Stop the queue-runner threads and release the writer and session
        # even when training aborts with an unexpected exception.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            train_writer.close()
            sess.close()
    

if __name__ == "__main__":
    # Script entry point: build the graph and start training.
    main()

運行 python train2.py 即可開始訓練(代碼使用 tf.contrib.slim 和基於隊列的輸入接口,需要 TensorFlow 1.x)。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章