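TensorFlow's Data API

Example 1: build a dataset from in-memory NumPy arrays, then repeat, shuffle, and batch it (TensorFlow 1.x graph mode)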

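import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)
import numpy as np

BATCH_SIZE = 2

train_x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
train_y = np.sin(train_x)

# Each dataset element is one (x, y) pair, sliced along the first axis.
ds = tf.data.Dataset.from_tensor_slices((train_x, train_y))
# Repeat indefinitely, shuffle through a 1000-element buffer, then group into batches of 2.
ds = ds.repeat().shuffle(buffer_size=1000).batch(BATCH_SIZE)
# A one-shot iterator needs no explicit initialization.
x, y = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    a, b = sess.run([x, y])  # a and b are NumPy arrays of shape (2,)
    print(a, b)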
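
The order `repeat().shuffle(...)` matters. Because `repeat()` comes first, the shuffle buffer is filled from an endless stream. With `buffer_size=1000` and only 5 examples, the buffer holds many copies of every example, so successive passes over the data get mixed together and a single batch can even contain the same example twice. The plain-Python sketch below imitates the buffer-based shuffle that `tf.data` describes: fill a buffer, emit a random slot, and refill that slot with the next input. It is a simplified model, not TensorFlow's implementation, and the helper names `repeat`, `shuffle`, and `batch` are hypothetical.

```python
import itertools
import random


def repeat(elements):
    """Yield the elements over and over, like Dataset.repeat() with no count."""
    while True:
        yield from elements


def shuffle(stream, buffer_size, rng):
    """Buffer-based shuffle: fill a buffer, emit a random slot, refill that slot."""
    stream = iter(stream)
    buffer = list(itertools.islice(stream, buffer_size))
    for item in stream:
        i = rng.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = item
    # The input is exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def batch(stream, batch_size):
    """Group consecutive elements into lists of batch_size (the last one may be short)."""
    stream = iter(stream)
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        if not chunk:
            return
        yield chunk


train_x = [1.0, 2.0, 3.0, 4.0, 5.0]
rng = random.Random(0)
pipeline = batch(shuffle(repeat(train_x), buffer_size=1000, rng=rng), batch_size=2)
first_batches = list(itertools.islice(pipeline, 3))
print(first_batches)
```

Writing `shuffle(...).repeat()` instead shuffles each complete pass before the next one starts, which keeps epoch boundaries intact.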

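Example 2: read a text file line by line and turn each line of space-separated numbers into an int32 vector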

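import tensorflow as tf  # TensorFlow 1.x API

SRC_PATH = "data.txt"  # hypothetical path: a text file with space-separated integers on each line

# Each element is one line of the file, as a scalar string tensor.
dataset = tf.data.TextLineDataset(SRC_PATH)
# Split the line on spaces; .values extracts the tokens from the returned SparseTensor.
dataset = dataset.map(lambda string: tf.string_split([string]).values)
# Convert every token to a 32-bit integer.
dataset = dataset.map(lambda string: tf.string_to_number(string, tf.int32))
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        a = sess.run(one_element)  # raises tf.errors.OutOfRangeError if the file has fewer than 5 lines
        print(a)
        print(type(a))  # <class 'numpy.ndarray'>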
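
To make the per-line transformation in Example 2 concrete, it can be imitated in plain Python. `tf.string_split` with its default delimiter splits on single spaces and drops empty tokens, and `tf.string_to_number(..., tf.int32)` parses each token. The sketch below assumes an in-memory string with made-up contents in place of the file at `SRC_PATH`, and uses Python lists where TensorFlow would return int32 NumPy arrays. The helper names are hypothetical. It models only the transformation, not TensorFlow's file reading.

```python
import io


def parse_line(line):
    """Mimic tf.string_split([line]).values followed by tf.string_to_number(..., tf.int32)."""
    tokens = [t for t in line.split(" ") if t]  # split on spaces, drop empty tokens
    return [int(t) for t in tokens]


def text_line_elements(f):
    """Yield one parsed element per line, like TextLineDataset plus the two map() calls."""
    for line in f:
        yield parse_line(line.rstrip("\n"))


# Made-up file contents standing in for the file at SRC_PATH.
sample = io.StringIO("1 2 3\n10  20\n7\n")
elements = list(text_line_elements(sample))
print(elements)  # [[1, 2, 3], [10, 20], [7]]
```

Each element can have a different length, which is one reason Example 2 fetches elements one at a time. Plain `batch()` cannot stack vectors of different lengths; `padded_batch()` is the usual choice for such data.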

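Example 3: an initializable iterator, which must be explicitly initialized before get_next() can be evaluated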

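import tensorflow as tf  # TensorFlow 1.x API

data = tf.data.Dataset.range(5)  # placeholder dataset for illustration; any tf.data.Dataset works here

# Unlike a one-shot iterator, this one must be initialized (and can be re-initialized).
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    a = sess.run(next_element)  # first element: 0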
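
What distinguishes the initializable iterator is that running `iterator.initializer` restarts it from the first element. A common TF 1.x pattern runs the initializer once per epoch and calls `get_next()` until `tf.errors.OutOfRangeError` signals the end of the data. The class below is a hypothetical plain-Python stand-in, not a TensorFlow API. It models only this control flow, with `StopIteration` playing the role of `OutOfRangeError`.

```python
class InitializableIterator:
    """Hypothetical stand-in: get_next() works only after initialize()."""

    def __init__(self, elements):
        self._elements = list(elements)
        self._it = None

    def initialize(self):
        # Like sess.run(iterator.initializer): restart from the first element.
        self._it = iter(self._elements)

    def get_next(self):
        if self._it is None:
            raise RuntimeError("iterator has not been initialized")
        return next(self._it)  # raises StopIteration at the end, like OutOfRangeError


iterator = InitializableIterator(range(3))
epochs = []
for _ in range(2):
    iterator.initialize()  # re-initialize at the start of every epoch
    seen = []
    while True:
        try:
            seen.append(iterator.get_next())
        except StopIteration:  # end of this epoch's data
            break
    epochs.append(seen)
print(epochs)  # [[0, 1, 2], [0, 1, 2]]
```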
