tensorflow中的数据对象Dataset

注:学习中遇到了tensorflow中的Dataset,在此记录相关内容(仅自己遇到)

问题描述: 使用pandas读取了excel表格中的数据,提取其中的几列后将其转换为numpy数组。在此数组基础上使用了tf.data.Dataset.from_tensor_slices()shuffle()batch()make_one_shot_iterator()方法。以下为代码实例。

代码片段:

import pandas as pd
import numpy as np

# 读取excel表格中数据
df = pd.read_excel('./test.xlsx')
x = np.array(df[['band1', 'band2', 'band3', 'band4', 'band5', 'band6', 'band7']])
# 输出x
print(x)

输出结果为:

[[ 423  332  643  460 2909 1973  895]
 [ 395  309  617  452 2863 1997  908]
 [ 374  291  599  448 2823 2013  919]
 [ 394  304  612  465 2820 2042  943]
 [ 393  304  613  466 2814 2048  951]
 [ 399  311  621  469 2826 2049  955]
 [ 395  311  622  467 2816 2029  953]
 [ 398  316  629  473 2798 2002  956]
 [ 351  293  617  477 2712 1965  976]
 [ 268  250  595  489 2561 1924 1021]]

切片:

# 注意此处被切片的数据的第一个维度值必须相同,例中都为10
db_train = tf.data.Dataset.from_tensor_slices((x, [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]))
print(db_train)

输出结果为:

'''
输出结果中显示db_train是一个Dataset实例。
(7,)表示变量x按照第一个维度值——10切分后,单个切片是秩为7的数组
()表示原本秩为10的数组按照第一个维度值——10切分后,单个切片变是标量
'''
<DatasetV1Adapter shapes: ((7,), ()), types: (tf.int64, tf.int32)>

查看每一个切片:

# 创建迭代器,枚举此数据集元素
iterator = db_train.make_one_shot_iterator()
# 获取下一次迭代值
ele = iterator.get_next()
# 创建会话
with tf.Session() as sess:
	for i in range(10):
		(x, y) = sess.run(ele)
		print(x)
		print(y)

输出为:

[ 423  332  643  460 2909 1973  895]
1
-----
[ 395  309  617  452 2863 1997  908]
2
-----
[ 374  291  599  448 2823 2013  919]
3
-----
[ 394  304  612  465 2820 2042  943]
4
-----
[ 393  304  613  466 2814 2048  951]
5
-----
[ 399  311  621  469 2826 2049  955]
6
-----
[ 395  311  622  467 2816 2029  953]
7
-----
[ 398  316  629  473 2798 2002  956]
8
-----
[ 351  293  617  477 2712 1965  976]
9
-----
[ 268  250  595  489 2561 1924 1021]
0
-----

查看每一个切片前,可以对db_train做shuffle()或batch()操作,如下:

# 打乱元素顺序,按照新顺序分批
# 打乱切片顺序,参数buffer_size表示从现有dataset中采样到固定数量的元素到维护的buffer中,并从buffer中随机选出一个元素
db_train = db_train.shuffle(buffer_size=10) 
db_train = db_train.batch(2) # 分批,参数为2表示在元素集中将2个连续元素分为1个批次
# 逐个输出
iterator = db_train.make_one_shot_iterator()
ele = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        (x, y) = sess.run(ele)
        print(x)
        print(y)
        print('-----')

输出为:

[[ 423  332  643  460 2909 1973  895]
 [ 394  304  612  465 2820 2042  943]]
[1 4]
-----
[[ 395  309  617  452 2863 1997  908]
 [ 374  291  599  448 2823 2013  919]]
[2 3]
-----
[[ 395  311  622  467 2816 2029  953]
 [ 398  316  629  473 2798 2002  956]]
[7 8]
-----
[[ 393  304  613  466 2814 2048  951]
 [ 399  311  621  469 2826 2049  955]]
[5 6]
-----
[[ 351  293  617  477 2712 1965  976]
 [ 268  250  595  489 2561 1924 1021]]
[9 0]
-----

可以看到上面输出,总体顺序是打乱的,且按照打乱后的顺序,每两个元素分为1个批。

完整代码:

import pandas as pd
import numpy as np

# 读取excel表格中数据
df = pd.read_excel('./test.xlsx')
x = np.array(df[['band1', 'band2', 'band3', 'band4', 'band5', 'band6', 'band7']])
# 输出x
print(x)

# 注意此处被切片的数据的第一个维度值必须相同,例中都为10
db_train = tf.data.Dataset.from_tensor_slices((x, [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]))
print(db_train)

# 创建迭代器,枚举此数据集元素
iterator = db_train.make_one_shot_iterator()
# 获取下一次迭代值
ele = iterator.get_next()
# 创建会话
with tf.Session() as sess:
	for i in range(10):
		(x, y) = sess.run(ele)
		print(x)
		print(y)

# 打乱元素顺序,按照新顺序分批
# 打乱切片顺序,参数buffer_size表示从现有dataset中采样到固定数量的元素到维护的buffer中,并从buffer中随机选出一个元素
db_train = db_train.shuffle(buffer_size=10) 
db_train = db_train.batch(2) # 分批,参数为2表示在元素集中将2个连续元素分为1个批次
# 逐个输出
iterator = db_train.make_one_shot_iterator()
ele = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        (x, y) = sess.run(ele)
        print(x)
        print(y)
        print('-----')

遇到再补充…

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章