TensorFlow讀取數據

TensorFlow讀取數據

TensorFlow程序讀取數據一共有3種方法:

  • 供給數據(feed)
  • 從文件讀取數據
  • 預加載數據:適用於數據量較小的情況

1、供給數據
TensorFlow的數據供給機制允許在TensorFlow運算圖中將數據輸入到任一張量中。如下所示:

import tensorflow as tf
x1 = tf.placeholder(tf.float32,shape=[2,3]) #佔位符,實際數據有feed給定
x2 = tf.placeholder(tf.float32,shape=[2,3])
y = tf.add(x1,x2)
with tf.Session() as sess:
    print sess.run(y,feed_dict={x1:[[1,2,3],[1,2,3]],x2:[[4,5,6],[4,5,6]]})

2、從文件讀取數據

  • 構建文件名列表,將文件名列表傳給tf.train.string_input_producer() 函數,從而生成一個先入先出的隊列,對應的文件閱讀器會根據隊列來獲取數據。

    tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None)
    該函數返回的是一個隊列,各參數的意義如下所示:
    string_tensor: 存儲的是要構造的隊列的一組文件名
    num_epochs: 如果沒有指定,則在這些文件名中一直循環不停。若指定,則在每一個 string 都被生成指定次數後產生 out_of_range 錯誤
    shuffle: 是否開啓亂序,默認開啓
    seed:隨機種子數,控制隨機數的
    
  • 文件格式,對於不同的文件格式選擇相應的文件閱讀器
    從CSV文件中讀取數據,如下所示:
import tensorflow as tf
filenames = ['a.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=10, min_after_dequeue=5, num_threads=2)
with tf.Session() as sess:
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    for i in range(10):
        x,y = sess.run([example, label])
        x_,y_ = sess.run([example_batch, label_batch])
        print x,y#不亂序
        print x_,y_#亂序
    coord.request_stop()
    coord.join(threads)
注:在調用run或者eval去執行read之前, 你必須調用tf.train.start_queue_runners來將文件名填充到隊列。否則read操作會被阻塞到文件名隊列中有值爲止。

3、預加載數據
將數據全部加載到內存中,適用於小數據集,一般將數據集存儲在常量中或存儲在變量中,若存儲在變量中,則一般初始化之後,就不會再改變數據。

import tensorflow as tf
x1 = tf.constant([1,2,3])
x2 = tf.constant([4,5,6])
y = tf.add(x1,x2)
with tf.Session() as sess:
    print sess.run(y)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章