TensorFlow 小練手案例：用 RNN（LSTM）擬合週期函數 sin

附上一個簡單 LSTM 結構的網絡，麻雀雖小，五臟俱全，其中包含了多種優化方法（多層 LSTM、Dropout、梯度裁剪、學習率指數衰減、Adam 優化器）：
擬合sin函數

import tensorflow as tf
import numpy as np

# ---- Hyper-parameters -------------------------------------------------------
steps, batch_size = 32, 32       # input window length / mini-batch size
train_size = 7000                # leading windows used for training
LSTM_KEEP_PROB = 0.9             # dropout output keep-probability per layer
NUM_LAYERS, HIDDEN_SIZE = 2, 64  # stacked LSTM depth / units per layer
train_step = 10000               # training-loop length

# ---- Toy data: 10000 evenly spaced samples of sin(x) on [0, 1000] ----------
X = np.linspace(0, 1000, num=10000)
y = np.sin(X)

def data(y, window=None, n_train=None):
    """Cut a 1-D series into sliding windows and split into train/test sets.

    Sample ``i`` is ``y[i:i + window]``, and its target is the value right
    after the window, ``y[i + window]``. Every start index whose target is
    still inside the series is used.

    Args:
        y: 1-D array-like series.
        window: window length; defaults to the module-level ``steps``.
        n_train: number of leading samples put in the training set; defaults
            to the module-level ``train_size``. The remaining samples form the
            test set.

    Returns:
        ``((train_x, train_y), (test_x, test_y))``. The ``*_x`` arrays have
        shape ``(n, window, 1)`` and the ``*_y`` arrays shape ``(n, 1)``.
        A series no longer than ``window`` yields empty (0-row) arrays.
    """
    window = steps if window is None else window
    n_train = train_size if n_train is None else n_train

    series = np.asarray(y, dtype=float).reshape(-1)
    # Start indices 0 .. len-window-1 inclusive: the last window's target is
    # series[-1].
    n_samples = max(len(series) - window, 0)
    # Row i of `idx` holds the indices i, i+1, ..., i+window-1.
    idx = np.arange(n_samples)[:, np.newaxis] + np.arange(window)
    windows = series[idx][:, :, np.newaxis]
    targets = series[window:window + n_samples][:, np.newaxis]

    return ((windows[:n_train], targets[:n_train]),
            (windows[n_train:], targets[n_train:]))
    
# model
tf.reset_default_graph()
# One example = `steps` consecutive scalar samples; target = the next sample.
inputs = tf.placeholder(tf.float32, [None, steps, 1])
outputs = tf.placeholder(tf.float32, [None, 1])
# Dropout keep-probability between LSTM layers. It defaults to the training
# value; feed 1.0 when evaluating so predictions are not randomly perturbed.
keep_prob = tf.placeholder_with_default(LSTM_KEEP_PROB, shape=[],
                                        name='keep_prob')

with tf.variable_scope('cell'):
    layers = [
        tf.nn.rnn_cell.DropoutWrapper(
            tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE),
            output_keep_prob=keep_prob)
        for _ in range(NUM_LAYERS)
    ]
    cell = tf.nn.rnn_cell.MultiRNNCell(layers, state_is_tuple=True)
    ac_y, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
    # Regress the next value from the top layer's output at the last step.
    ac_y = tf.layers.dense(ac_y[:, -1, :], 1)

loss = tf.losses.mean_squared_error(labels=outputs, predictions=ac_y)

# Clip the gradients by their global norm to keep LSTM training stable.
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 5)

# The global step must be handed to apply_gradients, otherwise it never
# advances and the decay schedule is a no-op.
global_step = tf.train.get_or_create_global_step()
# Decay the learning rate by 4% once per epoch of mini-batches. Decaying every
# single step (decay_steps=1) would shrink it to ~1e-6 within one epoch.
steps_per_epoch = -(-train_size // batch_size)  # ceil division
l_n = tf.train.exponential_decay(0.02, global_step, steps_per_epoch, 0.96,
                                 staircase=True)

optimizer = tf.train.AdamOptimizer(l_n)
train_op = optimizer.apply_gradients(zip(grads, tvars),
                                     global_step=global_step)

(train_x, train_y), (test_x, test_y) = data(y)

config = tf.ConfigProto()
# Allocate GPU memory on demand, capped at two thirds of the device.
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.666


def _batches(x, t, size=batch_size):
    """Yield consecutive (x, t) mini-batches, including a final partial one."""
    for start in range(0, len(x), size):
        yield x[start:start + size], t[start:start + size]


import matplotlib.pyplot as plt  # only used for the evaluation plots

with tf.Session(config=config) as sess:
    tf.global_variables_initializer().run(session=sess)
    for i in range(train_step):
        if i % 2 == 0:
            # Evaluate on the whole test set with dropout disabled.
            test_losses, test_sizes, preds = [], [], []
            for bx, by in _batches(test_x, test_y):
                out, nloss = sess.run(
                    [ac_y, loss],
                    feed_dict={inputs: bx, outputs: by, keep_prob: 1.0})
                test_losses.append(nloss)
                test_sizes.append(len(bx))
                preds.append(out)
            if test_losses:
                # Weight by batch size so a short last batch is not over-counted.
                print('loss:', np.average(test_losses, weights=test_sizes))
                plt.plot(np.concatenate(preds).reshape(-1))
                plt.show()

        # One epoch over every training sample (dropout uses the default).
        train_losses = []
        for bx, by in _batches(train_x, train_y):
            nloss, _ = sess.run([loss, train_op],
                                feed_dict={inputs: bx, outputs: by})
            train_losses.append(nloss)
        print('train loss:', np.mean(train_losses))

擬合sin函數

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章