【小白學PyTorch】17 TFrec文件的創建與讀取

【新聞】:機器學習煉丹術的粉絲的人工智能交流羣已經建立,目前有目標檢測、醫學圖像、時間序列等多個目標爲技術學習的分羣和水羣嘮嗑的總羣,歡迎大家加煉丹兄爲好友,加入煉丹協會。微信:cyx645016617.

參考目錄:

本文的代碼已經上傳公衆號後臺,回覆【PyTorch】獲取。
第一次接觸到TFrec文件,我也是比較矇蔽的其實:

可以看到文件是.tfrec後綴的,而且先記住這個文件是186.72MB大小的。

1 爲什麼用tfrec文件

正常情況下我們用於訓練的文件夾內部往往會存着成千上萬的圖片或文本等文件,這些文件通常被散列存放。這種存儲方式有一些缺點:

  • 佔用磁盤空間;
  • 一個一個讀取文件消耗時間

而tfrec格式的文件存儲形式會很合理的幫我們存儲數據,核心就是tfrec內部使用Protocol Buffer的二進制數據編碼方案,這個方案可以極大的壓縮存儲空間

之前我們知道一個tfrec文件100多M,這是因爲這個tfrec文件內存儲了很多的圖片,類似於壓縮,對tfrec解壓縮後可以獲取到一部分的數據集,當我們把全部的rfrec文件都解壓縮後,可以獲取到全部的數據集。

值得一提的是,rfrec文件內除了可以存儲圖片,還可以存儲其他的數據,比方說圖片的label。字符串,float類型等都可以轉換成二進制的方法,所以什麼數據類型基本上都可以存儲到rfrec文件內,從而簡化讀取數據的過程。

2 tfrec文件的內部結構

tfrec文件時tensorflow的數據集存儲格式,tensorflow可以高效的讀取和處理這些數據集,因此我見過有的數據集因爲是tfrec文件,所以用TF讀取數據集,然後用pytorch訓練模型。

之前提到了tfrec文件裏面是有多個樣本的,所以tfrec可以爲是多個tf.train.Example文件組成的序列(每一個example是一個樣本),然後每一個tf.train.Example又是由若干個tf.train.Features字典組成。這個Features可以理解爲這個樣本的一些信息,如果是圖片樣本,那麼肯定有一個Features是圖片像素值數據,一個Features是圖片的標籤值;如果是預測任務,那麼這個Feature可能就是一些字符串類型的特徵

3 製作tfrec文件

import tensorflow as tf
import glob
# 先記錄一下要保存的tfrec文件的名字
tfrecord_file = './train.tfrec'
# 獲取指定目錄的所有以jpeg結尾的文件list
images = glob.glob('./*.jpeg')
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for filename in images:
        image = open(filename, 'rb').read()  # 讀取數據集圖片到內存,image 爲一個 Byte 類型的字符串
        feature = {  # 建立 tf.train.Feature 字典
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 圖片是一個 Bytes 對象
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            'float':tf.train.Feature(float_list=tf.train.FloatList(value=[1.0,2.0])),
            'name':tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode(filename)]))
        }
        # tf.train.Example 在 tf.train.Features 外面又多了一層封裝
        example = tf.train.Example(features=tf.train.Features(feature=feature))  # 通過字典建立 Example
        writer.write(example.SerializeToString())  # 將 Example 序列化並寫入 TFRecord 文件

代碼中我們需要注意的地方是:

  • 先讀取圖片,然後構建一個字典來作爲這個example的格式;
  • 上面代碼中,字典中有四個屬性,首先是image圖片本身的像素值,然後有一個標籤,標籤是int類型,然後有一個float浮點類型,name是一個字符串類型,這個string類型的需要轉換成byte字節類型的才能進行存儲,所以這裏使用str.encode來把字符串轉換成字節;
  • 然後這個features再經過Example的封裝,再然後把這個example寫進這個tfrec文件中。

這一段代碼建議保存下來,方便以後的直接參考和複製。構建tfrec文件對於tensorflow處理圖片來說,應該是繞不過的一個步驟。

4 讀取tfrec文件

現在,我們運行完上面的代碼,應該生成了一個./train.tfrec文件,下面我們再對這個文件進行讀取。

import tensorflow as tf

dataset = tf.data.TFRecordDataset('./train.tfrec')

def decode(example):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'float': tf.io.FixedLenFeature([1, 2], tf.float32),
        'name': tf.io.FixedLenFeature([], tf.string)
    }
    feature_dict = tf.io.parse_single_example(example, feature_description)
    feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])  # 解碼 JEPG 圖片
    return feature_dict

dataset = dataset.map(decode).batch(4)
for i in dataset.take(1):
    print(i['image'].shape)
    print(i['label'].shape)
    print(i['float'].shape)
    print(bytes.decode(i['name'][0].numpy()))
  • 首先使用專門用來讀取tfrec文件的方法tf.data.TFRecordDataset,進行讀取,創建了一個dataset,但是這個dataset並不能直接使用,需要對tfrec中的example進行一些解碼;
  • 自己寫一個解碼函數decode,首先寫一個特徵描述,我們知道在保存tfrec的時候每一個example有四個特徵,這裏需要對每一個特徵確定他的類型,是string還是int還是float這樣的。
  • 然後通過這個特徵描述和tf.io.parse_single_example方法,從example中提取到對應的特徵;
  • 因爲image是一個圖片張量,而我們讀取的時候是讀取的tf.string的類型,所以使用tf.io.decode_jpeg()來把字符串解碼成一個tensor張量。
  • 最後使用上節課講過的.batch(4)把數據集每一個batch包含四個樣本。

上面代碼輸出的結果爲:

需要注意的是這個如何把name轉換成string類型的,如果已經在本地跑完了上面的代碼,可以自己看看i['name']是一個什麼類型的,然後自己試試如何轉換成string類型的。上面的代碼是能成功轉換的。

下一次的內容就是如何構建模型,然後怎麼把數據集餵給模型。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章