使用TensorFlow的DataSet解決樣本量過大,讀入導致的內存爆炸問題

問題描述

作爲一個優秀的菜鳥,內存爆炸這個坑一定會踩一下的:
以前的訓練數據都是幾千的樣本,直接用numpy讀取到內存中,不會出什麼問題

今天突然讀取一個4萬的數據集,我算了一下,我的樣本是227 * 227 * 3的,所以一個樣本大小大約是 (227 * 227 * 3 * 8)byte = 151KB 左右,4萬張圖片就是5898MB,相當於6個G了(不知道計算的對不對,難怪我8個G的內存一下就跑滿了),4萬數據集在模型訓練裏面並不算多,所以以前的方法一定有問題,所以,今天打算用半天的時間去TensorFlow官網查一下資料研究一下解決方案,順便記錄下來,方便其他菜鳥。

解決方案用TensorFlow的DataSet

原理
  • 既然一次性讀取這麼多數據會內存爆炸,那麼我就分批次讀取,當訓練數據是,纔將圖片讀取到內存中
  • 分批次讀取,數據就不能很好的打亂了,所以先讀取全部圖片的地址和對應的標籤,然後打亂,然後根據圖片地址數據集進行分批次讀取
  • 新問題,每次讀取新圖片需要很長的時間,完全滿足不了GPU的速度,就是每次迭代就會產生很大的延遲,勢必會影響訓練速度,所以就是預讀取(上次還沒訓練還沒結束,就提前先讀取下次迭代需要的數據)

代碼

  • 圖片地址一樣全部加入內存中,只是圖片地址,加入後打亂,爲了代碼簡潔,突出重點,圖片地址加入內存中,並打亂的方法就不再描述了
img_path = get_all_file(path) # 獲得所有的圖片地址
image_labels = get_lable_by_file_path(path,img_path) #獲得圖片地址對應的標籤
  • 新建一個處理圖片加載圖片的方法,這2個方法是在訓練模型時,對每一輪的樣本分別處理
# 處理圖片
def preprocess_image(image):
  image = tf.image.decode_jpeg(image, channels=3)
  image = tf.image.resize(image, [227, 227])
  image /= 255.0  # normalize to [0,1] range
  return image
# 加載圖片
def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)
關鍵代碼來了
  • 1.將圖片路徑處理爲path_ds (path的DataSet)
  • 2.圖片處理方法和圖片路徑(path的DataSet)綁定
  • 3.將標籤處理爲label_ds (標籤的DataSet)
  • 4.將圖片的DateSet和標籤的DataSet綁定
AUTOTUNE = tf.data.experimental.AUTOTUNE
path_ds = tf.data.Dataset.from_tensor_slices(img_path) #步驟1
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)#步驟2
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(y, tf.int64)) #步驟3
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
  • 這需要處理一下ds,打亂順序、重啓repeat、樣本切片
BATCH_SIZE = 200
image_count = len(img_path)
ds = image_label_ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=int(image_count/200)))
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)
  • 下面就是新建模型和訓練模型,就不多說了,我將訓練代碼貼上
model.fit(ds,
          epochs=200,
          steps_per_epoch=200
         )
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章