tensorflow 學習資料收集

對於tensordlow 如何製作tfrecord數據集,並且 將數據集導入參考以下博客:http://blog.csdn.net/miaomiaoyuan/article/details/56865361

上篇博客並未使用batch分批導入


參考代碼:在自己的當前目錄下製作兩個文件夾,一個名稱爲ku, 另一個爲 biao 。分別在這兩個文件夾下存放數據圖片

import os 
import tensorflow as tf 
from PIL import Image  #注意Image,後面會用到

cwd= os.getcwd()
classes={'ku','biao'} #人爲 設定 2 類
writer= tf.python_io.TFRecordWriter("example.tfrecords") #要生成的文件

for index,name in enumerate(classes):
    class_path=cwd+'/'+name
    for img_name in os.listdir(class_path): 
        img_path=class_path+'/'+img_name #每一個圖片的地址

        img=Image.open(img_path)
        img= img.resize((128,128))
        img_raw=img.tobytes()#將圖片轉化爲二進制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        })) #example對象對label和image數據進行封裝
        writer.write(example.SerializeToString())  #序列化爲字符串

writer.close()
其中enumerate 是對列表進行解析, 對於一個循環, index = 0, name = ku  ; 然後是 index = 1, name = biao

然後是依據上面所產生的example.tfrecords 文件 進行讀取,讓那後從新對 *.tfrecords 文件提取圖片 ,並且保存。

import os 
import tensorflow as tf 
from PIL import Image  

cwd = os.getcwd()
filename_tfrecords = cwd +'/' + 'example.tfrecords'
def read_and_decode(filename): # 讀入example.tfrecords
    filename_queue = tf.train.string_input_producer([filename])#生成一個queue隊列

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)#返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })#將image數據和label取出來

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [128, 128, 3])  #reshape爲128*128的3通道圖片
#    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中拋出img張量
    label = tf.cast(features['label'], tf.int32) #在流中拋出label張量
    return img, label


image , label = read_and_decode(filename_tfrecords)

init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
with tf.Session() as sess: #開始一個會話

    sess.run(init_op)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(16):
        example, l = sess.run([image,label])#在會話中取出image和label
        img=Image.fromarray(example, 'RGB')#這裏Image是之前提到的
        img.save(cwd+'/'+str(i)+'_Label_'+str(l)+'.jpg')#存下圖片;注意cwd後邊加上‘/’
        print(example, l)
    coord.request_stop()
    coord.join(threads)
    

得到的結果如下:






發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章