TensorFlow程序读取数据一共有3种方法:
- 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
- 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
- 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)
第一种大家很熟悉不用多说, 第三种采用比如np.loadtxt 等函数一次性载入所有数据,如果数据很大会很慢。
重点谈一谈对第二种方法的体会:
假如我们的数据为一个.csv文件:
那么用以下的代码可以读取:
# coding=utf-8
import tensorflow as tf
import os
import csv
#要保存后csv格式的文件名
filenames = ['./new.csv']
#file_name_string="'D:/dataTest.csv'"
#filename_queue = tf.train.string_input_producer([file_name_string])
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=1)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
print(value)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1.0],[1.0]]
data = tf.decode_csv(value, record_defaults=record_defaults)
data_batch = tf.train.batch([data], batch_size=4, capacity=200, num_threads=2)
#features = tf.concat(0, [col1, col2, col3])
init_local_op = tf.local_variables_initializer()
with tf.Session() as sess:
# Start populating the filename queue.
sess.run(init_local_op)
tf.global_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(2):
# Retrieve a single instance:
example= sess.run([data_batch])
print(example)
coord.request_stop()
coord.join(threads)
执行的结果为:
如果有数据有很多列, 列举不过来呢?
只需把 record_defaults = [[1.0],[1.0]] 修改为
record_defaults = list([1.0] for i in range(2))
就可以了, 数据有多少列就填多大!
Done!