對於tensordlow 如何製作tfrecord數據集,並且 將數據集導入參考以下博客:http://blog.csdn.net/miaomiaoyuan/article/details/56865361
上篇博客並未使用batch分批導入
參考代碼:在自己的當前目錄下製作兩個文件夾,一個名稱爲ku, 另一個爲 biao 。分別在這兩個文件夾下存放數據圖片
import os
import tensorflow as tf
from PIL import Image #注意Image,後面會用到
cwd= os.getcwd()
classes={'ku','biao'} #人爲 設定 2 類
writer= tf.python_io.TFRecordWriter("example.tfrecords") #要生成的文件
for index,name in enumerate(classes):
class_path=cwd+'/'+name
for img_name in os.listdir(class_path):
img_path=class_path+'/'+img_name #每一個圖片的地址
img=Image.open(img_path)
img= img.resize((128,128))
img_raw=img.tobytes()#將圖片轉化爲二進制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) #example對象對label和image數據進行封裝
writer.write(example.SerializeToString()) #序列化爲字符串
writer.close()
其中enumerate 是對列表進行解析, 對於一個循環, index = 0, name = ku ; 然後是 index = 1, name = biao然後是依據上面所產生的example.tfrecords 文件 進行讀取,讓那後從新對 *.tfrecords 文件提取圖片 ,並且保存。
import os
import tensorflow as tf
from PIL import Image
cwd = os.getcwd()
filename_tfrecords = cwd +'/' + 'example.tfrecords'
def read_and_decode(filename): # 讀入example.tfrecords
filename_queue = tf.train.string_input_producer([filename])#生成一個queue隊列
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})#將image數據和label取出來
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [128, 128, 3]) #reshape爲128*128的3通道圖片
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中拋出img張量
label = tf.cast(features['label'], tf.int32) #在流中拋出label張量
return img, label
image , label = read_and_decode(filename_tfrecords)
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
with tf.Session() as sess: #開始一個會話
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(16):
example, l = sess.run([image,label])#在會話中取出image和label
img=Image.fromarray(example, 'RGB')#這裏Image是之前提到的
img.save(cwd+'/'+str(i)+'_Label_'+str(l)+'.jpg')#存下圖片;注意cwd後邊加上‘/’
print(example, l)
coord.request_stop()
coord.join(threads)
得到的結果如下: