TensorFlow——Fashion mnist示例

在這裏插入圖片描述
在這裏插入圖片描述
配置文件

{
  "project_name": "fashion-mnist",
  "save_root": "../experiments/",
  "num_workers": 8,
  "train_batch_size": 64,
  "test_batch_size": 128,
  "prefecth_times": 1,
  "max_epochs": 200,
  "test_by_epochs": 10,
  "steps_per_epoch": 100,
  "max_to_keep": 5,
  "lr": 0.1
}

模型訓練與保存

#! -*- coding: utf-8 -*-
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
from tensorflow.contrib import layers
from tensorflow import keras
from bunch import Bunch
import tensorflow as tf
import numpy as np
import time
import json
import os
import shutil


def logger(info, level=0):
    """
    打印日誌
    :param info: 日誌信息
    :param level: 日誌等級,默認爲INFO
    :return: None
    """
    levels = ["INFO", "WARNING", "ERROR"]
    print("[{} {}] {}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), levels[level], info))


def read_config(json_file):
    """
    將config的json文件轉換爲config類
    :param json_file: json文件
    :return: config類
    """
    with open(json_file) as config_file:
        config_json = json.load(config_file)
        config = Bunch(config_json)
        return config


def mkdir(dir_name, delete):
    """
    創建目錄
    :param dir_name: 目錄名稱
    :param delete: 是否刪除已存在的目錄
    :return: None
    """
    if os.path.exists(dir_name):
        if delete:
            logger("%s is existed. Deleting ..." % dir_name)
            shutil.rmtree(dir_name)
            logger("Delete succeed. Recreating %s ..." % dir_name)
            os.makedirs(dir_name)
            logger("Create succeed.")
        else:
            logger("%s is existed.")
    else:
        logger("Creating %s ..." % dir_name)
        os.makedirs(dir_name)
        logger("Create succeed.")


def get_config(json_file, delete):
    """
    獲取config類
    :param json_file: json文件
    :param delete: 是否刪除已存在的目錄
    :return: config類
    """
    config = read_config(json_file)
    # 保存日誌文件目錄
    config.log_dir = os.path.join(config.save_root, config.project_name, "logs/")
    # 保存檢查點文件目錄
    config.ck_dir = os.path.join(config.save_root, config.project_name, "checkpoints/")

    # 創建目錄
    mkdir(config.log_dir, delete=delete)
    mkdir(config.ck_dir, delete=delete)
    return config


def _parse_fn(x, label):
    """
    數據轉換函數
    :param x: 圖像數據
    :param label: 圖片標籤
    :return: x, label
    """
    x = tf.expand_dims(x, axis=2)
    return x, label


def make_variable(name, shape, initializer, trainable=True):
    """
    創建變量
    :param name:
    :param shape:
    :param initializer:
    :param trainable:
    :return:
    """
    with tf.variable_scope(name):
        return tf.get_variable(name, shape, tf.float32, initializer=initializer, trainable=trainable)


def batch_normal(x, name, training, activation=tf.nn.relu):
    """
    Batch Normalization層
    :param x:
    :param name:
    :param training:
    :param activation:
    :return:
    """
    with tf.name_scope(name), tf.variable_scope(name):
        return layers.batch_norm(x, decay=0.9, activation_fn=activation, is_training=training, scope="batch_normal")


def max_pool(x, k, s, name, padding="SAME"):
    """
    Max pooling層
    :param x:
    :param k:
    :param s:
    :param name:
    :param padding:
    :return:
    """
    with tf.name_scope(name):
        return tf.nn.max_pool(x, [1, k, k, 1], [1, s, s, 1], padding, name="max_pool")


def fc(x, c_out, name, use_bias=True, activation=True):
    """
    全連接層
    :param x:
    :param c_out:
    :param name:
    :param use_bias:
    :param activation:
    :return:
    """
    c_in = x.get_shape().as_list()[-1]
    with tf.name_scope(name), tf.variable_scope(name):
        weights = make_variable("weights", [c_in, c_out], initializer=tf.random_normal_initializer())
        outputs = tf.matmul(x, weights, name="matmul")
        if use_bias:
            biases = make_variable("biases", [c_out], initializer=tf.constant_initializer(0.001))
            outputs = tf.nn.bias_add(outputs, biases)
        if activation:
            outputs = tf.nn.relu(outputs, "relu")
        return outputs


def conv(x, k, s, c_in, c_out, name, use_bias=True, padding="SAME", activation=None):
    """
    卷積層
    :param x:
    :param k:
    :param s:
    :param c_in:
    :param c_out:
    :param name:
    :param use_bias:
    :param padding:
    :param activation:
    :return:
    """
    with tf.name_scope(name), tf.variable_scope(name):
        weights = make_variable("weights", [k, k, c_in, c_out], tf.random_normal_initializer())
        outputs = tf.nn.conv2d(x, weights, [1, s, s, 1], padding, name="conv")
        if use_bias:
            biases = make_variable("biases", [c_out], tf.constant_initializer(0.001))
            outputs = tf.nn.bias_add(outputs, biases)

        if activation:
            outputs = tf.nn.relu(outputs)
        return outputs


def train(config_file="fashion-mnist.json"):
    """
    主函數
    :param config_file: 配置文件
    :return:
    """
    # 解析配置文件,將配置項轉換爲配置類
    config = get_config(config_file, delete=True)

    # Fashiong mnist數據集
    fashion_mnist = keras.datasets.fashion_mnist
    # 加載數據集
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

    # 數據集預處理
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=1000). \
        map(_parse_fn, config.num_workers). \
        batch(batch_size=config.train_batch_size). \
        repeat(). \
        prefetch(buffer_size=config.train_batch_size * config.prefecth_times)

    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).map(_parse_fn,
                                                                            num_parallel_calls=config.num_workers). \
        batch(batch_size=config.test_batch_size). \
        prefetch(config.test_batch_size * config.prefecth_times)

    # 創建數據迭代器
    train_iterator = train_dataset.make_one_shot_iterator()
    test_iterator = test_dataset.make_initializable_iterator()

    # 獲取下一批數據
    train_next = train_iterator.get_next()
    test_next = test_iterator.get_next()

    # 創建佔位符
    x = tf.placeholder(tf.float32, [None, 28, 28, 1])
    y = tf.placeholder(tf.int64, [None, ])
    training = tf.placeholder(tf.bool)

    # --------------- 構建網絡 ----------------------
    with tf.name_scope("conv1"):
        conv1 = conv(x, 5, 2, 1, 32, name="conv1")
        print("conv1: {}".format(conv1.shape))
        bn1 = batch_normal(conv1, "bn1", training)
        pool1 = max_pool(bn1, 2, 2, "pool1")
        print("pool1: {}".format(pool1.shape))

    with tf.name_scope("conv2"):
        conv2 = conv(pool1, 3, 1, 32, 64, "conv2")
        print("conv2: {}".format(conv2.shape))
        bn2 = batch_normal(conv2, "bn2", training)
        pool2 = batch_normal(bn2, "pool_2", training)
        print("pool2: {}".format(pool2.shape))

    with tf.name_scope("conv3"):
        conv3 = conv(pool2, 3, 1, 64, 128, "conv3")
        print("conv3: {}".format(conv3.shape))
        bn3 = batch_normal(conv3, "bn3", training)
        pool3 = max_pool(bn3, 2, 2, "pool3")
        print("pool3: {}".format(pool3.shape))

    with tf.name_scope("flatten"):
        flatten = layers.flatten(pool3, scope="flatten")
        print("flatten: {}".format(flatten.shape))

    with tf.name_scope("fc1"):
        fc1 = fc(flatten, 1024, "fc1")
        print("fc1: {}".format(fc1.shape))

    with tf.name_scope("outputs"):
        outputs = fc(fc1, 10, "fc2", activation=False)
        print("outputs: {}".format(outputs.shape))
    # -------------- 網絡構建完畢 -----------------------

    # 全局訓練步數
    global_step_tensor = tf.Variable(0, trainable=False, name="global_step")

    # 當前訓練的epoch次數
    cur_epoch_tensor = tf.Variable(0, trainable=False, name="cur_epoch")

    # epoch自增操作
    cur_epoch_increment = tf.assign_add(cur_epoch_tensor, 1)

    # 定義loss
    with tf.name_scope("loss"):
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=outputs))

    # 定義訓練step操作
    with tf.name_scope("train_op"):
        # 學習率指數衰減
        lr = tf.train.exponential_decay(config.lr,
                                        global_step=global_step_tensor,
                                        decay_steps=1000,
                                        decay_rate=0.9,
                                        staircase=True)
        tf.summary.scalar("lr", lr)  # 記錄學習率

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(lr).minimize(loss, global_step=global_step_tensor)

    # 計算accuracy操作
    with tf.name_scope("acc"):
        pred_corrects = tf.equal(tf.argmax(outputs, 1), y)
        acc = tf.reduce_mean(tf.cast(pred_corrects, tf.float32))

    # 定義模型保存器
    # 注意:需要在模型構建完成後定義(初始化)
    saver = tf.train.Saver(max_to_keep=config.max_to_keep)

    # 定義測試集最小的loss
    best_loss = 10.0

    # 初始化global和local變量操作
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)  # 初始化

        # 寫入日誌操作
        train_writer = tf.summary.FileWriter(config.log_dir + "train/", sess.graph)
        test_writer = tf.summary.FileWriter(config.log_dir + "test/")
        merged = tf.summary.merge_all()

        # 訓練和驗證
        while cur_epoch_tensor.eval(sess) < config.max_epochs:
            # --------------------------------------- 訓練一個epoch ---------------------------------------------------
            train_losses = []
            train_accs = []
            for _ in range(config.steps_per_epoch):
                x_train, y_train = sess.run(train_next)
                summary_merged, _, train_loss, train_acc = sess.run([merged, train_op, loss, acc],
                                                                    feed_dict={x: x_train, y: y_train, training: True})
                train_writer.add_summary(summary_merged, global_step=global_step_tensor.eval(sess))
                train_losses.append(train_loss)
                train_accs.append(train_acc)
            train_loss = np.mean(train_losses)
            train_acc = np.mean(train_accs)
            logger("Train {}/{}, loss: {:.3f}, acc: {:.3f}".format(global_step_tensor.eval(sess),
                                                                   cur_epoch_tensor.eval(sess), train_loss, train_acc))
            train_summary = tf.summary.Summary(
                value=[
                    tf.summary.Summary.Value(tag="train/loss", simple_value=train_loss),
                    tf.summary.Summary.Value(tag="train/acc", simple_value=train_acc)
                ]
            )
            train_writer.add_summary(train_summary, global_step=global_step_tensor.eval(sess))
            # --------------------------------------- 一個epoch結束 ---------------------------------------------------

            # --------------------------------------- 測試集全量測試 ---------------------------------------------------
            if cur_epoch_tensor.eval(sess) % config.test_by_epochs == 0:
                sess.run(test_iterator.initializer)
                test_losses = []
                test_accs = []
                while True:
                    try:
                        x_test, y_test = sess.run(test_next)
                    except tf.errors.OutOfRangeError:
                        break
                    test_loss, test_acc = sess.run([loss, acc], feed_dict={x: x_test, y: y_test, training: False})
                    test_losses.append(test_loss)
                    test_accs.append(test_acc)
                test_loss = np.mean(test_losses)
                test_acc = np.mean(test_accs)

                logger("Test {}/{}, loss: {:.3f}, acc: {:.3f}".format(global_step_tensor.eval(sess),
                                                                      cur_epoch_tensor.eval(sess), test_loss, test_acc))

                test_summary = tf.summary.Summary(
                    value=[
                        tf.summary.Summary.Value(tag="test/loss", simple_value=test_loss),
                        tf.summary.Summary.Value(tag="test/acc", simple_value=test_acc)
                    ]
                )
                test_writer.add_summary(test_summary, global_step=global_step_tensor.eval(sess))
                # --------------------------------------- 測試結束 ---------------------------------------------------

                # ---------------------- 保存模型 ----------------------
                if test_acc < best_loss:
                    logger("Saving model ...")
                    saver.save(sess, config.ck_dir + "model", global_step=global_step_tensor)
                    best_loss = test_acc
                    logger("Model saved.")
                # ---------------------- 保存完成 ----------------------

            sess.run(cur_epoch_increment)


def test(config_file="fashion-mnist.json"):
    """
    查看訓練好的模型中的參數
    :param config_file:
    :return:
    """
    # 解析配置文件,將配置項轉換爲配置類
    config = get_config(config_file, delete=False)
    latest_checkpoint = tf.train.latest_checkpoint(config.ck_dir)
    print_tensors_in_checkpoint_file(latest_checkpoint,
                                     tensor_name="",
                                     all_tensors=True,
                                     all_tensor_names=True)


if __name__ == '__main__':
    train()  # 訓練模型
    test()  # 查看模型文件

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章