利用Tensorflow構建自己的圖片數據集TFrecords

相信很多初學者和我一樣,雖然用了那麼久的tensorflow,也嘗試了很多的實例,但那些實例基本都是直接利用官方文檔現成的MNIST和cifar_10數據庫,而一旦需要自己構建數據集時,完全不知道該如何製作並輸入自己改的數據。另外,雖然也有一些人提供了相關的操作,但是總是或多或少存在各種各樣的問題。今天給大家分享我的Tensorflow製作數據集的學習歷程。 TensorFlow提供了標準的TFRecord 格式,而關於 tensorflow 讀取數據, 官網也提供了3中方法 :
1 Feeding: 在tensorflow程序運行的每一步, 用python代碼在線提供數據
2 Reader : 在一個計算圖(tf.graph)的開始前,將文件讀入到流(queue)中
3 在聲明tf.variable變量或numpy數組時保存數據。受限於內存大小,適用於數據較小的情況

特此聲明:初次寫博客,如有問題,如有問題多體諒;另外文本參考了下面的博客(提供鏈接如下),因而讀者可結合兩者取齊所需。

點擊打開鏈接http://blog.csdn.net/miaomiaoyuan/article/details/56865361

在本文,主要介紹第二種方法,利用tf.record標準接口來讀入文件

第一步,準備數據

先在網上下載一些不同類的圖片集,例如貓、狗等,也可以是同一種類,不同類型的,例如哈士奇、吉娃娃等都屬於狗類;此處筆者預先下載了哈士奇、吉娃娃兩種狗的照片各20張,並分別將其放置在不同文件夾下。如下:


第二步,製作TFRecord文件

注意:tfrecord會根據你選擇輸入文件的類,自動給每一類打上同樣的標籤 如在本例中,只有0,1 兩類

#-----------------------------------------------------------------------------
#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = 'E:/train_data/picture_dog//' 
classes = {'husky','jiwawa'}


#製作TFRecords數據
def create_record():
    writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = cwd +"/"+ name+"/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))
            img_raw = img.tobytes() #將圖片轉化爲原生bytes
            print (index,img_raw)
            example = tf.train.Example(
               features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
               }))
            writer.write(example.SerializeToString())
    writer.close()
#-------------------------------------------------------------------------

將上面的代碼編輯完成後,點擊運行,就會生成一個dog_train.TFRecords文件,如下圖所示:


TFRecords文件包含了tf.train.Example 協議內存塊(protocol buffer)(協議內存塊包含了字段 Features)。我們可以寫一段代碼獲取你的數據, 將數據填入到Example協議內存塊(protocol buffer),將協議內存塊序列化爲一個字符串, 並且通過tf.python_io.TFRecordWriter 寫入到TFRecords文件。

第三步,讀取TFRecord文件

#-------------------------------------------------------------------------
cwd = 'E:/train_data/picture_dog//' 
#讀取二進制數據

def read_and_decode(filename):
    # 創建文件隊列,不限讀取的數量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader從文件隊列中讀入一個序列化的樣本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符號化的樣本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        })
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64, 64, 3])
    #img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    return img, label
#--------------------------------------------------------------------------    
一個Example中包含Features,Features裏包含Feature(這裏沒s)的字典。最後,Feature裏包含有一個 FloatList, 或者ByteList,或者Int64List。另外,需要我們注意的是:feature的屬性“label”和“img_raw”名稱要和製作時統一 ,返回的img數據和label數據一一對應。

第四步,TFRecord的顯示操作

如果想要檢查分類是否有誤,或者在之後的網絡訓練過程中可以監視,輸出圖片,來觀察分類等操作的結果,那麼我們就可以session回話中,將tfrecord的圖片從流中讀取出來,再保存。因而自然少不了主程序的存在。

#---------主程序----------------------------------------------------------
if __name__ == '__main__':
    create_record()
    batch = read_and_decode('dog_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    with tf.Session() as sess: #開始一個會話  
        sess.run(init_op)  
        coord=tf.train.Coordinator()  
        threads= tf.train.start_queue_runners(coord=coord)  
        for i in range(40):  
            example, lab = sess.run(batch)#在會話中取出image和label  
            img=Image.fromarray(example, 'RGB')#這裏Image是之前提到的  
            img.save(cwd+'/'+str(i)+'_Label_'+str(lab)+'.jpg')#存下圖片;注意cwd後邊加上‘/’  
            print(example, lab)  
        coord.request_stop()  
        coord.join(threads) 
        sess.close()
#-----------------------------------------------------------------------------
進過上面的一通操作之後,我們便可以得到和tensorflow官方的二進制數據集一樣的數據集了,並且可以按照自己的設計來進行。

下面附上該程序的完整代碼,僅供參考。

#-----------------------------------------------------------------------------
#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = 'E:/train_data/picture_dog//' 
classes = {'husky','jiwawa'}


#製作TFRecords數據
def create_record():
    writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = cwd +"/"+ name+"/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))
            img_raw = img.tobytes() #將圖片轉化爲原生bytes
            print (index,img_raw)
            example = tf.train.Example(
               features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
               }))
            writer.write(example.SerializeToString())
    writer.close()
#-------------------------------------------------------------------------

#讀取二進制數據

def read_and_decode(filename):
    # 創建文件隊列,不限讀取的數量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader從文件隊列中讀入一個序列化的樣本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符號化的樣本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        })
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64, 64, 3])
    #img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    return img, label
#--------------------------------------------------------------------------    
#---------主程序----------------------------------------------------------
if __name__ == '__main__':
    create_record()
    batch = read_and_decode('dog_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    with tf.Session() as sess: #開始一個會話  
        sess.run(init_op)  
        coord=tf.train.Coordinator()  
        threads= tf.train.start_queue_runners(coord=coord)  
        for i in range(40):  
            example, lab = sess.run(batch)#在會話中取出image和label  
            img=Image.fromarray(example, 'RGB')#這裏Image是之前提到的  
            img.save(cwd+'/'+str(i)+'_Label_'+str(lab)+'.jpg')#存下圖片;注意cwd後邊加上‘/’  
            print(example, lab)  
        coord.request_stop()  
        coord.join(threads) 
        sess.close()
#-----------------------------------------------------------------------------
運行上述的完整代碼,便可以 將從TFRecord中取出的文件保存下來了。如下圖:


每一幅圖片的命名中,第二個數字則是 label,吉娃娃都爲1,哈士奇都爲0;通過對照圖片,可以發現圖片分類正確。

發佈了31 篇原創文章 · 獲贊 111 · 訪問量 25萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章