TensorFlow數據加載

TensorFlow程序讀取數據一共有3種方法:

  • 供給數據(Feeding): 在TensorFlow程序運行的每一步, 讓Python代碼來供給數據。
  • 從文件讀取數據: 在TensorFlow圖的起始, 讓一個輸入管線從文件中讀取數據。
  • 預加載數據: 在TensorFlow圖中定義常量或變量來保存所有數據(僅適用於數據量比較小的情況)

第一種大家很熟悉不用多說, 第三種採用比如np.loadtxt 等函數一次性載入所有數據,如果數據很大會很慢。

重點談一談對第二種方法的體會:

假如我們的數據爲一個.csv文件:

那麼用以下的代碼可以讀取:
# coding=utf-8
import tensorflow as tf
import os
import csv
#要保存後csv格式的文件名
filenames = ['./new.csv']
#file_name_string="'D:/dataTest.csv'"
#filename_queue = tf.train.string_input_producer([file_name_string])
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=1)

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
print(value)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1.0],[1.0]]
data = tf.decode_csv(value, record_defaults=record_defaults)
data_batch = tf.train.batch([data], batch_size=4, capacity=200, num_threads=2)
#features = tf.concat(0, [col1, col2, col3])
init_local_op = tf.local_variables_initializer()
with tf.Session() as sess:
  # Start populating the filename queue.
   sess.run(init_local_op)
   tf.global_variables_initializer().run()
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(coord=coord)


   for i in range(2):
    # Retrieve a single instance:
        example= sess.run([data_batch])
        print(example)

   coord.request_stop()
   coord.join(threads)

執行的結果爲:

如果有數據有很多列, 列舉不過來呢?

只需把 record_defaults = [[1.0],[1.0]] 修改爲 

record_defaults = list([1.0] for i in range(2))

就可以了, 數據有多少列就填多大!

Done!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章