# Train, validate, and test the model.

# coding: utf-8
from BilstmModel.BilstmModel import BilstmModel
from BilstmModel.cnn_model import TextCNN
from DataProcess.DateProcess import process_file, build_word_to_id, build_lables_to_id, batch_iter
from Config.ConfigParameters import ConfigParameters as config
import tensorflow as tf
import os

# Locations of the THUCNews-style dataset splits and the vocabulary file.
base_dir = '../DataProcess/cnews'
train_data_path = os.path.join(base_dir, 'cnews.train.txt')
test_data_path = os.path.join(base_dir, 'cnews.test.txt')
val_data_path = os.path.join(base_dir, 'cnews.val.txt')
vocab_path = os.path.join(base_dir, 'cnews.vocab.txt')
# Checkpoint directory and prefix for the best validation model.
save_dir = 'checkpoints/bilstm'
save_path = os.path.join(save_dir, 'best_validation')

# Token -> id and label -> id lookup tables built once at import time.
word_to_id = build_word_to_id(vocab_path)
lables_to_id = build_lables_to_id()
# Swap the two lines below to train the CNN variant instead of the BiLSTM.
model = BilstmModel()
# model = TextCNN()


def feed_data(contents, lables, keep_prob):
    """Map a batch of inputs, labels and a dropout keep-probability onto
    the model's placeholders, returning the resulting feed dict."""
    return {
        model.x: contents,
        model.y: lables,
        model.keep_prob: keep_prob,
    }


def evaluate(sess, xs, ys):
    """Compute the model's average loss and accuracy over a whole dataset.

    Args:
        sess: active TensorFlow session with the model's variables loaded.
        xs: encoded, padded input sequences.
        ys: one-hot labels aligned with ``xs``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``, each weighted by batch size so
        the final (smaller) batch does not skew the average.
    """
    data_len = len(xs)
    total_loss = 0.0
    total_acc = 0.0
    for contents, lables in batch_iter(xs, ys, config.batch_size):
        batch_len = len(contents)
        # Dropout must be disabled during evaluation: keep_prob=1.0.
        # The original passed config.keep_prob, which randomly drops units
        # and makes validation/test metrics noisy and pessimistic.
        feed_dict = feed_data(contents, lables, 1.0)
        loss, acc = sess.run([model.loss, model.accuracy], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    # Divide both metrics by the same dataset length (the original divided
    # loss by len(xs) and accuracy by len(ys); equivalent but inconsistent).
    return total_loss / data_len, total_acc / data_len


def training():
    """Train the model with periodic validation and best-checkpoint saving.

    Every ``config.print_per_batch`` batches, evaluates on the validation set,
    saves a checkpoint to ``save_path`` whenever validation accuracy improves,
    and stops early after ``config.require_improvement`` batches without
    improvement.
    """
    train_contents, train_lables = process_file(train_data_path, word_to_id, lables_to_id,
                                                config.max_seq_length)
    val_contents, val_lables = process_file(val_data_path, word_to_id, lables_to_id,
                                            config.max_seq_length)
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Context manager guarantees the session (and its GPU/graph resources)
    # is released even on error; the original leaked an unclosed tf.Session.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        total_batch = 0
        best_acc_val = 0.0
        last_improved = 0  # batch index of the last validation improvement
        stop_early = False
        print('training')
        for epoch in range(config.epoches):
            print('epoch %s' % (epoch + 1))
            train_batches = batch_iter(train_contents, train_lables, config.batch_size)
            for xs, ys in train_batches:
                feed_dict = feed_data(xs, ys, config.keep_prob)
                if total_batch % config.print_per_batch == 0:
                    train_loss, train_accuracy = sess.run([model.loss, model.accuracy], feed_dict=feed_dict)
                    val_loss, val_acc = evaluate(sess, val_contents, val_lables)
                    if val_acc > best_acc_val:
                        # Checkpoint the best validation result so far.
                        best_acc_val = val_acc
                        last_improved = total_batch
                        saver.save(sess=sess, save_path=save_path)
                    # Report metrics on every evaluation, not only on
                    # improvement (the original print was mis-indented inside
                    # the improvement branch, hiding most evaluations).
                    msg = 'train_loss:{0:>6.2},   train_accuracy:{1:>7.2%},   val_loss:{2:>6.2},   val_acc:{3:>7.2%}'
                    print(msg.format(train_loss, train_accuracy, val_loss, val_acc))
                # Train on this batch.
                sess.run([model.train_op], feed_dict=feed_dict)
                total_batch += 1
                if total_batch - last_improved > config.require_improvement:
                    print("No optimization for a long time, auto-stopping...")
                    stop_early = True
                    break  # leave the batch loop
            if stop_early:
                break




def test():
    """Restore the best saved checkpoint and report loss/accuracy on the
    held-out test set."""
    print("Loading test data...")
    x_test, y_test = process_file(test_data_path, word_to_id, lables_to_id, config.max_seq_length)
    # Context manager ensures the session is closed (the original leaked it).
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=save_path)  # load the saved best model
        print('Testing...')
        loss_test, acc_test = evaluate(session, x_test, y_test)
        msg = 'loss_test: {0:>6.2}, acc_test: {1:>7.2%}'
        print(msg.format(loss_test, acc_test))


if __name__ == '__main__':
    # Run training() first to produce a checkpoint, then test() to evaluate it.
    # training()
    test()
