tensorflow tf.train.batch

原文鏈接:https://blog.csdn.net/liweibin1994/article/details/78306417

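原文代碼需要修改一處:get_batch_data() 裏把 images 和 label 的順序寫反了(generate_data() 的返回順序是 images, label),應將 `label, images = generate_data()` 改為 `images, label = generate_data()`。此外,這裏把隨機生成的數據改成了固定數據,以便核對每個 label 是否與對應的 images 行匹配;epochs 次數也改大了(現為 num_epochs=20)。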


import numpy as np
import tensorflow as tf


def generate_data(num=25, width=5):
    """Build a small deterministic (images, label) dataset and print it.

    Row ``i`` of ``images`` is ``[10*i, 10*i + 1, ..., 10*i + width - 1]`` and
    ``label[i] == i``, so each label can be matched to its row when inspecting
    dequeued batches. Row values are distinct across rows only while
    ``width <= 10``.

    Args:
        num: Number of samples (rows). Defaults to 25.
        width: Number of values per sample (columns). Defaults to 5.

    Returns:
        Tuple ``(images, label)``: ``images`` is an integer array of shape
        ``(num, width)``; ``label`` is ``np.arange(num)``.
    """
    label = np.arange(num)
    # Broadcast the column of row offsets (0, 10, 20, ...) against the column
    # indices (0 .. width-1) to build the whole grid in one expression.
    images = label[:, np.newaxis] * 10 + np.arange(width)
    print('label:', label)
    print('images:', images)
    print('label size :{}, image size {}'.format(label.shape, images.shape))
    return images, label

def get_batch_data(batch_size=5, num_epochs=20):
    """Create a TF1 queue-based input pipeline yielding fixed-size batches.

    Samples are produced in their original order (``shuffle=False``) and the
    dataset is replayed ``num_epochs`` times; once exhausted, evaluating the
    batch tensors raises ``tf.errors.OutOfRangeError``.

    Args:
        batch_size: Samples per dequeued batch. Defaults to 5.
        num_epochs: Passes over the dataset before the producer is exhausted.
            Defaults to 20. Because it is not None, the producer keeps its
            epoch counter in a local variable, so callers must run
            ``tf.local_variables_initializer()`` before starting queue runners.

    Returns:
        ``(image_batch, label_batch)`` tensors; each ``sess.run`` on them
        dequeues one batch of ``batch_size`` rows.
    """
    # generate_data returns (images, label) in that order; unpacking it the
    # other way round silently pairs the data with the wrong names.
    images, label = generate_data()
    input_queue = tf.train.slice_input_producer(
        [images, label], shuffle=False, num_epochs=num_epochs)
    # allow_smaller_final_batch=False: a trailing partial batch is discarded
    # instead of emitted (with the defaults, 25 * 20 samples divide evenly).
    image_batch, label_batch = tf.train.batch(
        input_queue, batch_size=batch_size, num_threads=1, capacity=64,
        allow_smaller_final_batch=False)
    return image_batch, label_batch


# Driver: build the pipeline, start the queue-runner threads and print every
# label batch until the epoch-limited producer is exhausted.
images, label = get_batch_data()
with tf.Session() as sess:  # closes the session even if the loop raises
    sess.run(tf.global_variables_initializer())
    # This line is essential: num_epochs in slice_input_producer is tracked by
    # a *local* variable, which must be initialized before the runners start.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Despite its name (kept for the printed output), this counts dequeued
    # batches starting at 1, not epochs.
    epochs = 1
    try:
        while not coord.should_stop():
            # Fetch both tensors in one run so they come from the same dequeue.
            _, batch_labels = sess.run([images, label])
            print('epochs:', epochs, 'l:', batch_labels)
            epochs += 1
    except tf.errors.OutOfRangeError:
        print('Done training')
    finally:
        # Stop and join the runner threads on every exit path, before the
        # `with` block closes the session they are using.
        coord.request_stop()
        coord.join(threads)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章