TensorFlow读取数据

TensorFlow读取数据

TensorFlow程序读取数据一共有3种方法:

  • 供给数据(feed)
  • 从文件读取数据
  • 预加载数据:适用于数据量较小的情况

1、供给数据
TensorFlow的数据供给机制允许在TensorFlow运算图中将数据输入到任一张量中。如下所示:

import tensorflow as tf
x1 = tf.placeholder(tf.float32,shape=[2,3]) #占位符,实际数据有feed给定
x2 = tf.placeholder(tf.float32,shape=[2,3])
y = tf.add(x1,x2)
with tf.Session() as sess:
    print sess.run(y,feed_dict={x1:[[1,2,3],[1,2,3]],x2:[[4,5,6],[4,5,6]]})

2、从文件读取数据

  • 构建文件名列表,将文件名列表传给tf.train.string_input_producer() 函数,从而生成一个先入先出的队列,对应的文件阅读器会根据队列来获取数据。

    tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None)
    该函数返回的是一个队列,各参数的意义如下所示:
    string_tensor: 存储的是要构造的队列的一组文件名
    num_epochs: 如果没有指定,则在这些文件名中一直循环不停。若指定,则在每一个 string 都被生成指定次数后产生 out_of_range 错误
    shuffle: 是否开启乱序,默认开启
    seed:随机种子数,控制随机数的
    
  • 文件格式,对于不同的文件格式选择相应的文件阅读器
    从CSV文件中读取数据,如下所示:
import tensorflow as tf
filenames = ['a.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=10, min_after_dequeue=5, num_threads=2)
with tf.Session() as sess:
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    for i in range(10):
        x,y = sess.run([example, label])
        x_,y_ = sess.run([example_batch, label_batch])
        print x,y#不乱序
        print x_,y_#乱序
    coord.request_stop()
    coord.join(threads)
注:在调用run或者eval去执行read之前, 你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

3、预加载数据
将数据全部加载到内存中,适用于小数据集,一般将数据集存储在常量中或存储在变量中,若存储在变量中,则一般初始化之后,就不会再改变数据。

import tensorflow as tf
x1 = tf.constant([1,2,3])
x2 = tf.constant([4,5,6])
y = tf.add(x1,x2)
with tf.Session() as sess:
    print sess.run(y)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章