使用 tfrecord 製作自己的數據集 (附上源代碼)

相信很多剛入手深度學習的人,最早接觸的程序就是Mnist 手寫數字的識別。Mnist 數據集都已被事先整理好,我們只有拿來用即可。但是如何製作自己的數據集,相信很多剛入門的人還是會一團霧水。作爲剛入門不就的小白,我也花了很長時間才完整的製作了自己的數據集。製作自己的數據集,大概可以分爲這麼幾步:

Step1.首先要去收集自己的數據吧,可以是自己拍的圖片,也可以是那種網上爬蟲爬下來的圖片。

Step2.建議最好將趴下來的圖片重新命名,再用去訓練,這樣圖片數據看起來比較整齊。特別是對有強迫症的同學來說,這是很重要的,總感覺名字不統一會覺得怪怪的。命名可以採用 name1,name2,name3.......這種形式。具體如何命名,我在之前的博客中也有詳細介紹過,有需要的同學可以參考看下  點擊打開鏈接    當然不改名字的話,也沒什麼影響,只是讀取圖片時需要採用不同的方法就好。

Step3. 接下來就是讀取圖片,在讀取圖片時也有些需要注意的細節。我在另一片博客中給出了詳細的介紹  點擊打開鏈接 

並製作成tfrecord形式,具體代碼如下

import tensorflow as tf  
import numpy as np  
import os  
import cv2
from skimage import transform 
import skimage.io as io  

#%%
#def rename(file_dir,name):
#    '''將網上爬下來的圖片重命名(更好的觀看)'''
#    i=0
#    for file in os.listdir(file_dir):  #獲取該路徑文件下的所有圖片
#        src = os.path.join(os.path.abspath(file_dir), file) #目標文件夾+圖片的名稱
#        dst = os.path.join(os.path.abspath(file_dir),  name+str(i) + '.jpg')#目標文件夾+新的圖片的名稱
#        os.rename(src, dst)
#        i=i+1        
#rename(file_dir+'/roses','rose')
#rename(file_dir+'/sunflowers','sunflower')

'''要將圖片的路徑完整的保存下來''' 
def get_files(file_dir): 
    roses=[]
    label_roses=[]
    sunflowers=[]
    label_sunflowers=[]

    for file in os.listdir(file_dir+'/roses'):  #獲取該路徑文件下的所有圖片
        roses.append(file_dir+'/roses' +'/'+file)  #將圖片存入一個列表中
        label_roses.append(0) # 將roses的標籤設爲0
     
    for file in os.listdir(file_dir+'/sunflowers'):
       sunflowers.append(file_dir+'/sunflowers' +'/'+file)
       label_sunflowers.append(1)     # 將sunflower的標籤設爲1 
    print('There are %d roses \n There are %d sunflowers' %(len(roses), len(sunflowers)))  
    
#把cat和dog合起來組成一個list(img和lab)
    image_list = np.hstack((roses, sunflowers))
    label_list = np.hstack((label_roses, label_sunflowers))

    #利用shuffle打亂順序
    temp = np.array([image_list, label_list]) #轉換成2維矩陣
    temp = temp.transpose() #轉置
    np.random.shuffle(temp) #按行隨機打亂順序


    #從打亂的temp中再取出list(img和lab)
    image_list = list(temp[:, 0])  #取出第0列數據,即圖片路徑
    label_list = list(temp[:, 1]) #取出第0列數據,即圖片路徑
    label_list = [int(i) for i in label_list] #轉換成int數據類型
    return image_list, label_list  
  
    
#%%    
def int64_feature(value):  
  """Wrapper for inserting int64 features into Example proto."""  
  if not isinstance(value, list):   #標籤的轉化形式
    value = [value]  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))  
  
def bytes_feature(value):   #圖片的轉換格式
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
  

  
def convert_to_tfrecord(images, labels, save_dir, name):  
    '''''convert all images and labels to one tfrecord file. 
    Args: 
        images: list of image directories, string type 
        labels: list of labels, int type 
        save_dir: the directory to save tfrecord file, e.g.: '/home/folder1/' 
        name: the name of tfrecord file, string type, e.g.: 'train' 
    Return: 
        no return 
    Note: 
        converting needs some time, be patient... 
    '''  
    filename = os.path.join(save_dir, name + '.tfrecords')
    n_samples = len(labels)
    
    if np.shape(images)[0] != n_samples:
        raise ValueError('Images size %d does not match label size %d.' %(images.shape[0], n_samples))
    
    # wait some time here, transforming need some time based on the size of your data.
    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start......')
    for i in np.arange(0, n_samples):
        try:  
            '''因爲cv2讀出的圖片保存形式是BGR,要轉換成RGB形式'''
            image = cv2.imread(images[i])    
            image = cv2.resize(image, (208, 208))    
            b,g,r = cv2.split(image)    
            rgb_image = cv2.merge([r,g,b])  
#            image = io.imread(images[i]) # type(image) must be array!  #這邊是兩種讀取圖像的方法  
#            image =transform.resize(image, (208, 208))
#            img = image * 255 
#            img = img.astype(np.uint8)   
              
            image_raw =  rgb_image.tostring()
            
            label = int(labels[i])
            example = tf.train.Example(features=tf.train.Features(feature={
                            'label':int64_feature(label),
                            'image_raw': bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
        except IOError as e:
            print('Could not read:', images[i])
            print('error: %s' %e)
            print('Skip it!\n')
    writer.close()
    print('Transform done!')


Step4. 生成tfrecord後,接着便是tfrecord的讀取了  (只是在訓練過程中才調用這個函數,主要的作用就是講將tfrecord模式解碼)

def read_and_decode(tfrecords_file, batch_size):  
    '''''read and decode tfrecord file, generate (image, label) batches 
    Args: 
        tfrecords_file: the directory of tfrecord file 
        batch_size: number of images in each batch 
    Returns: 
        image: 4D tensor - [batch_size, width, height, channel] 
        label: 1D tensor - [batch_size] 
    '''  
    # make an input queue from the tfrecord file  
    filename_queue = tf.train.string_input_producer([tfrecords_file])  
      
    reader = tf.TFRecordReader()  
    _, serialized_example = reader.read(filename_queue)  
    img_features = tf.parse_single_example(  
                                        serialized_example,  
                                        features={  
                                               'label': tf.FixedLenFeature([], tf.int64),  
                                               'image_raw': tf.FixedLenFeature([], tf.string),  
                                               })  
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)  
      
    ##########################################################  
    # you can put data augmentation here, I didn't use it  
    ##########################################################  
    # all the images of notMNIST are 28*28, you need to change the image size if you use other dataset.  
      
    image = tf.reshape(image, [208, 208,3])  
    label = tf.cast(img_features['label'], tf.float32)      
    image = tf.image.per_image_standardization(image)  
    image_batch, label_batch = tf.train.batch([image, label],  
                                                batch_size= batch_size,  
                                                num_threads= 64,   
                                                capacity = 2000)  #線程的個數及最大存儲量
    return image_batch, tf.reshape(label_batch, [batch_size])  



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章