# coding: utf-8
from BilstmModel.BilstmModel import BilstmModel
from BilstmModel.cnn_model import TextCNN
from DataProcess.DateProcess import process_file, build_word_to_id, build_lables_to_id, batch_iter
from Config.ConfigParameters import ConfigParameters as config
import tensorflow as tf
import os
# Dataset location, relative to the current working directory.
base_dir = '../DataProcess/cnews'

# Train / test / validation splits and the vocabulary file all live in base_dir.
train_data_path, test_data_path, val_data_path, vocab_path = (
    os.path.join(base_dir, filename)
    for filename in (
        'cnews.train.txt',
        'cnews.test.txt',
        'cnews.val.txt',
        'cnews.vocab.txt',
    )
)

# Checkpoint directory and the checkpoint prefix used by training() and test().
save_dir = 'checkpoints/bilstm'
save_path = os.path.join(save_dir, 'best_validation')

# Lookup tables built once at import time and shared by training() and test().
lables_to_id = build_lables_to_id()
word_to_id = build_word_to_id(vocab_path)

# The network under training/evaluation.
# Alternative architecture: model = TextCNN()
model = BilstmModel()
def feed_data(contents, lables, keep_prob):
    """Build the feed dictionary that binds one batch to the model's inputs.

    Args:
        contents: Batch of inputs fed to ``model.x``.
        lables: Batch of labels fed to ``model.y``.
        keep_prob: Value fed to ``model.keep_prob``.

    Returns:
        A dict mapping the three model inputs to the given values.
    """
    model_inputs = (model.x, model.y, model.keep_prob)
    batch_values = (contents, lables, keep_prob)
    return dict(zip(model_inputs, batch_values))
def evaluate(sess, xs, ys):
    """Compute the example-weighted mean loss and accuracy over a dataset.

    The data is processed in batches of ``config.batch_size``; each batch's
    loss and accuracy are weighted by the batch length so that a smaller
    final batch does not skew the averages.

    Args:
        sess: Session used to run ``model.loss`` and ``model.accuracy``.
        xs: Inputs to evaluate on.
        ys: Labels aligned with ``xs``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``.
    """
    total_loss = 0
    total_acc = 0
    for contents, lables in batch_iter(xs, ys, config.batch_size):
        batch_len = len(contents)
        # Evaluation must run with dropout disabled (keep probability 1.0);
        # feeding the training keep_prob here would randomly drop units and
        # make the reported metrics noisy and pessimistic.
        feed_dict = feed_data(contents, lables, 1.0)
        loss, acc = sess.run([model.loss, model.accuracy], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    return total_loss / len(xs), total_acc / len(ys)
def training():
    """Train the model with periodic validation, checkpointing and early stopping.

    Loads the training and validation splits, then runs up to
    ``config.epoches`` epochs of mini-batch training. Every
    ``config.print_per_batch`` batches the current batch and the whole
    validation set are scored; whenever validation accuracy improves, the
    model is saved to ``save_path``. Training stops early once more than
    ``config.require_improvement`` batches pass without an improvement.

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    train_contents, train_lables = process_file(train_data_path, word_to_id, lables_to_id,
                                                config.max_seq_length)
    val_contents, val_lables = process_file(val_data_path, word_to_id, lables_to_id,
                                            config.max_seq_length)

    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    total_batch = 0       # number of training batches processed so far
    best_acc_val = 0.0    # best validation accuracy seen so far
    last_improved = 0     # value of total_batch at the last improvement
    msg = 'train_loss:{0:>6.2}, train_accuracy:{1:>7.2%}, val_loss:{2:>6.2}, val_acc:{3:>7.2%}'

    # The context manager guarantees the session (and the resources it holds)
    # is released on normal completion, early stopping, or an exception.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('training')
        for epoch in range(config.epoches):
            print('epoch %s' % (epoch + 1))
            train_batches = batch_iter(train_contents, train_lables, config.batch_size)
            # Train one batch at a time.
            for xs, ys in train_batches:
                if total_batch % config.print_per_batch == 0:
                    # Score the current batch with dropout disabled so the
                    # reported training metrics are comparable to validation.
                    report_feed = feed_data(xs, ys, 1.0)
                    train_loss, train_accuracy = sess.run([model.loss, model.accuracy],
                                                          feed_dict=report_feed)
                    val_loss, val_acc = evaluate(sess, val_contents, val_lables)
                    if val_acc > best_acc_val:
                        # Keep the best result seen so far.
                        best_acc_val = val_acc
                        last_improved = total_batch
                        saver.save(sess=sess, save_path=save_path)
                    print(msg.format(train_loss, train_accuracy, val_loss, val_acc))

                # The optimisation step itself uses the configured dropout rate.
                sess.run([model.train_op], feed_dict=feed_data(xs, ys, config.keep_prob))
                total_batch += 1

                if total_batch - last_improved > config.require_improvement:
                    print("No optimization for a long time, auto-stopping...")
                    # Leave both loops at once; the session is closed by `with`.
                    return
def test():
    """Evaluate the saved best checkpoint on the test split and print the result.

    Loads the test data, restores the variables stored at ``save_path`` and
    reports the mean loss and accuracy returned by ``evaluate``.

    Returns:
        None. The metrics are printed to stdout.
    """
    print("Loading test data...")
    x_test, y_test = process_file(test_data_path, word_to_id, lables_to_id, config.max_seq_length)

    # The context manager closes the session even if restoring or evaluating
    # raises (for example when no checkpoint exists yet).
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=save_path)  # load the saved model

        print('Testing...')
        loss_test, acc_test = evaluate(session, x_test, y_test)

    msg = 'loss_test: {0:>6.2}, acc_test: {1:>7.2%}'
    print(msg.format(loss_test, acc_test))
if __name__ == '__main__':
    # Script entry point: only evaluation runs by default. test() restores a
    # saved checkpoint, so enable the training() call below first if none
    # has been written yet.
    # training()
    test()
# 訓練,驗證,測試模型
# 發表評論
# 所有評論
# 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.