TensorFlow——Fashion mnist示例

在这里插入图片描述
在这里插入图片描述
配置文件

{
  "project_name": "fashion-mnist",
  "save_root": "../experiments/",
  "num_workers": 8,
  "train_batch_size": 64,
  "test_batch_size": 128,
  "prefecth_times": 1,
  "max_epochs": 200,
  "test_by_epochs": 10,
  "steps_per_epoch": 100,
  "max_to_keep": 5,
  "lr": 0.1
}

模型训练与保存

#! -*- coding: utf-8 -*-
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
from tensorflow.contrib import layers
from tensorflow import keras
from bunch import Bunch
import tensorflow as tf
import numpy as np
import time
import json
import os
import shutil


def logger(info, level=0):
    """
    打印日志
    :param info: 日志信息
    :param level: 日志等级,默认为INFO
    :return: None
    """
    levels = ["INFO", "WARNING", "ERROR"]
    print("[{} {}] {}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), levels[level], info))


def read_config(json_file):
    """
    将config的json文件转换为config类
    :param json_file: json文件
    :return: config类
    """
    with open(json_file) as config_file:
        config_json = json.load(config_file)
        config = Bunch(config_json)
        return config


def mkdir(dir_name, delete):
    """
    创建目录
    :param dir_name: 目录名称
    :param delete: 是否删除已存在的目录
    :return: None
    """
    if os.path.exists(dir_name):
        if delete:
            logger("%s is existed. Deleting ..." % dir_name)
            shutil.rmtree(dir_name)
            logger("Delete succeed. Recreating %s ..." % dir_name)
            os.makedirs(dir_name)
            logger("Create succeed.")
        else:
            logger("%s is existed.")
    else:
        logger("Creating %s ..." % dir_name)
        os.makedirs(dir_name)
        logger("Create succeed.")


def get_config(json_file, delete):
    """
    获取config类
    :param json_file: json文件
    :param delete: 是否删除已存在的目录
    :return: config类
    """
    config = read_config(json_file)
    # 保存日志文件目录
    config.log_dir = os.path.join(config.save_root, config.project_name, "logs/")
    # 保存检查点文件目录
    config.ck_dir = os.path.join(config.save_root, config.project_name, "checkpoints/")

    # 创建目录
    mkdir(config.log_dir, delete=delete)
    mkdir(config.ck_dir, delete=delete)
    return config


def _parse_fn(x, label):
    """
    数据转换函数
    :param x: 图像数据
    :param label: 图片标签
    :return: x, label
    """
    x = tf.expand_dims(x, axis=2)
    return x, label


def make_variable(name, shape, initializer, trainable=True):
    """
    创建变量
    :param name:
    :param shape:
    :param initializer:
    :param trainable:
    :return:
    """
    with tf.variable_scope(name):
        return tf.get_variable(name, shape, tf.float32, initializer=initializer, trainable=trainable)


def batch_normal(x, name, training, activation=tf.nn.relu):
    """
    Batch Normalization层
    :param x:
    :param name:
    :param training:
    :param activation:
    :return:
    """
    with tf.name_scope(name), tf.variable_scope(name):
        return layers.batch_norm(x, decay=0.9, activation_fn=activation, is_training=training, scope="batch_normal")


def max_pool(x, k, s, name, padding="SAME"):
    """
    Max pooling层
    :param x:
    :param k:
    :param s:
    :param name:
    :param padding:
    :return:
    """
    with tf.name_scope(name):
        return tf.nn.max_pool(x, [1, k, k, 1], [1, s, s, 1], padding, name="max_pool")


def fc(x, c_out, name, use_bias=True, activation=True):
    """
    全连接层
    :param x:
    :param c_out:
    :param name:
    :param use_bias:
    :param activation:
    :return:
    """
    c_in = x.get_shape().as_list()[-1]
    with tf.name_scope(name), tf.variable_scope(name):
        weights = make_variable("weights", [c_in, c_out], initializer=tf.random_normal_initializer())
        outputs = tf.matmul(x, weights, name="matmul")
        if use_bias:
            biases = make_variable("biases", [c_out], initializer=tf.constant_initializer(0.001))
            outputs = tf.nn.bias_add(outputs, biases)
        if activation:
            outputs = tf.nn.relu(outputs, "relu")
        return outputs


def conv(x, k, s, c_in, c_out, name, use_bias=True, padding="SAME", activation=None):
    """
    卷积层
    :param x:
    :param k:
    :param s:
    :param c_in:
    :param c_out:
    :param name:
    :param use_bias:
    :param padding:
    :param activation:
    :return:
    """
    with tf.name_scope(name), tf.variable_scope(name):
        weights = make_variable("weights", [k, k, c_in, c_out], tf.random_normal_initializer())
        outputs = tf.nn.conv2d(x, weights, [1, s, s, 1], padding, name="conv")
        if use_bias:
            biases = make_variable("biases", [c_out], tf.constant_initializer(0.001))
            outputs = tf.nn.bias_add(outputs, biases)

        if activation:
            outputs = tf.nn.relu(outputs)
        return outputs


def train(config_file="fashion-mnist.json"):
    """
    主函数
    :param config_file: 配置文件
    :return:
    """
    # 解析配置文件,将配置项转换为配置类
    config = get_config(config_file, delete=True)

    # Fashiong mnist数据集
    fashion_mnist = keras.datasets.fashion_mnist
    # 加载数据集
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

    # 数据集预处理
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=1000). \
        map(_parse_fn, config.num_workers). \
        batch(batch_size=config.train_batch_size). \
        repeat(). \
        prefetch(buffer_size=config.train_batch_size * config.prefecth_times)

    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).map(_parse_fn,
                                                                            num_parallel_calls=config.num_workers). \
        batch(batch_size=config.test_batch_size). \
        prefetch(config.test_batch_size * config.prefecth_times)

    # 创建数据迭代器
    train_iterator = train_dataset.make_one_shot_iterator()
    test_iterator = test_dataset.make_initializable_iterator()

    # 获取下一批数据
    train_next = train_iterator.get_next()
    test_next = test_iterator.get_next()

    # 创建占位符
    x = tf.placeholder(tf.float32, [None, 28, 28, 1])
    y = tf.placeholder(tf.int64, [None, ])
    training = tf.placeholder(tf.bool)

    # --------------- 构建网络 ----------------------
    with tf.name_scope("conv1"):
        conv1 = conv(x, 5, 2, 1, 32, name="conv1")
        print("conv1: {}".format(conv1.shape))
        bn1 = batch_normal(conv1, "bn1", training)
        pool1 = max_pool(bn1, 2, 2, "pool1")
        print("pool1: {}".format(pool1.shape))

    with tf.name_scope("conv2"):
        conv2 = conv(pool1, 3, 1, 32, 64, "conv2")
        print("conv2: {}".format(conv2.shape))
        bn2 = batch_normal(conv2, "bn2", training)
        pool2 = batch_normal(bn2, "pool_2", training)
        print("pool2: {}".format(pool2.shape))

    with tf.name_scope("conv3"):
        conv3 = conv(pool2, 3, 1, 64, 128, "conv3")
        print("conv3: {}".format(conv3.shape))
        bn3 = batch_normal(conv3, "bn3", training)
        pool3 = max_pool(bn3, 2, 2, "pool3")
        print("pool3: {}".format(pool3.shape))

    with tf.name_scope("flatten"):
        flatten = layers.flatten(pool3, scope="flatten")
        print("flatten: {}".format(flatten.shape))

    with tf.name_scope("fc1"):
        fc1 = fc(flatten, 1024, "fc1")
        print("fc1: {}".format(fc1.shape))

    with tf.name_scope("outputs"):
        outputs = fc(fc1, 10, "fc2", activation=False)
        print("outputs: {}".format(outputs.shape))
    # -------------- 网络构建完毕 -----------------------

    # 全局训练步数
    global_step_tensor = tf.Variable(0, trainable=False, name="global_step")

    # 当前训练的epoch次数
    cur_epoch_tensor = tf.Variable(0, trainable=False, name="cur_epoch")

    # epoch自增操作
    cur_epoch_increment = tf.assign_add(cur_epoch_tensor, 1)

    # 定义loss
    with tf.name_scope("loss"):
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=outputs))

    # 定义训练step操作
    with tf.name_scope("train_op"):
        # 学习率指数衰减
        lr = tf.train.exponential_decay(config.lr,
                                        global_step=global_step_tensor,
                                        decay_steps=1000,
                                        decay_rate=0.9,
                                        staircase=True)
        tf.summary.scalar("lr", lr)  # 记录学习率

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(lr).minimize(loss, global_step=global_step_tensor)

    # 计算accuracy操作
    with tf.name_scope("acc"):
        pred_corrects = tf.equal(tf.argmax(outputs, 1), y)
        acc = tf.reduce_mean(tf.cast(pred_corrects, tf.float32))

    # 定义模型保存器
    # 注意:需要在模型构建完成后定义(初始化)
    saver = tf.train.Saver(max_to_keep=config.max_to_keep)

    # 定义测试集最小的loss
    best_loss = 10.0

    # 初始化global和local变量操作
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)  # 初始化

        # 写入日志操作
        train_writer = tf.summary.FileWriter(config.log_dir + "train/", sess.graph)
        test_writer = tf.summary.FileWriter(config.log_dir + "test/")
        merged = tf.summary.merge_all()

        # 训练和验证
        while cur_epoch_tensor.eval(sess) < config.max_epochs:
            # --------------------------------------- 训练一个epoch ---------------------------------------------------
            train_losses = []
            train_accs = []
            for _ in range(config.steps_per_epoch):
                x_train, y_train = sess.run(train_next)
                summary_merged, _, train_loss, train_acc = sess.run([merged, train_op, loss, acc],
                                                                    feed_dict={x: x_train, y: y_train, training: True})
                train_writer.add_summary(summary_merged, global_step=global_step_tensor.eval(sess))
                train_losses.append(train_loss)
                train_accs.append(train_acc)
            train_loss = np.mean(train_losses)
            train_acc = np.mean(train_accs)
            logger("Train {}/{}, loss: {:.3f}, acc: {:.3f}".format(global_step_tensor.eval(sess),
                                                                   cur_epoch_tensor.eval(sess), train_loss, train_acc))
            train_summary = tf.summary.Summary(
                value=[
                    tf.summary.Summary.Value(tag="train/loss", simple_value=train_loss),
                    tf.summary.Summary.Value(tag="train/acc", simple_value=train_acc)
                ]
            )
            train_writer.add_summary(train_summary, global_step=global_step_tensor.eval(sess))
            # --------------------------------------- 一个epoch结束 ---------------------------------------------------

            # --------------------------------------- 测试集全量测试 ---------------------------------------------------
            if cur_epoch_tensor.eval(sess) % config.test_by_epochs == 0:
                sess.run(test_iterator.initializer)
                test_losses = []
                test_accs = []
                while True:
                    try:
                        x_test, y_test = sess.run(test_next)
                    except tf.errors.OutOfRangeError:
                        break
                    test_loss, test_acc = sess.run([loss, acc], feed_dict={x: x_test, y: y_test, training: False})
                    test_losses.append(test_loss)
                    test_accs.append(test_acc)
                test_loss = np.mean(test_losses)
                test_acc = np.mean(test_accs)

                logger("Test {}/{}, loss: {:.3f}, acc: {:.3f}".format(global_step_tensor.eval(sess),
                                                                      cur_epoch_tensor.eval(sess), test_loss, test_acc))

                test_summary = tf.summary.Summary(
                    value=[
                        tf.summary.Summary.Value(tag="test/loss", simple_value=test_loss),
                        tf.summary.Summary.Value(tag="test/acc", simple_value=test_acc)
                    ]
                )
                test_writer.add_summary(test_summary, global_step=global_step_tensor.eval(sess))
                # --------------------------------------- 测试结束 ---------------------------------------------------

                # ---------------------- 保存模型 ----------------------
                if test_acc < best_loss:
                    logger("Saving model ...")
                    saver.save(sess, config.ck_dir + "model", global_step=global_step_tensor)
                    best_loss = test_acc
                    logger("Model saved.")
                # ---------------------- 保存完成 ----------------------

            sess.run(cur_epoch_increment)


def test(config_file="fashion-mnist.json"):
    """
    查看训练好的模型中的参数
    :param config_file:
    :return:
    """
    # 解析配置文件,将配置项转换为配置类
    config = get_config(config_file, delete=False)
    latest_checkpoint = tf.train.latest_checkpoint(config.ck_dir)
    print_tensors_in_checkpoint_file(latest_checkpoint,
                                     tensor_name="",
                                     all_tensors=True,
                                     all_tensor_names=True)


if __name__ == '__main__':
    train()  # 训练模型
    test()  # 查看模型文件

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章