大規模訓練數據的shuffle

必要性12

以貓狗分類爲例, 假如數據集是

Dog,Dog,Dog,… ,Dog,Dog,Dog,Cat,Cat,Cat,Cat,… ,Cat,Cat

所有的狗都在貓前面,如果不shuffle,模型訓練一段時間內只看到了Dog,必然會過擬合於Dog,一段時間內又只能看到Cat,必然又過擬合於Cat,這樣的模型泛化能力必然很差。 那如果Dog和Cat一直交替,會不會就不過擬合了呢?

Dog,Cat,Dog,Cat,Dog ,Cat,Dog,…

依然會過擬合,模型是會記住訓練數據路線的,爲什麼呢?

當用隨機梯度下降法訓練神經網絡時,通常的做法是洗牌數據。在糾結細節的情況下,讓我們用一個極端的例子來解釋爲什麼shuffle是有用的。假設你正在訓練一個分類器來區分貓和狗,你的訓練集是50,000只貓後面跟着50,000只狗。如果你不洗牌,你的訓練成績就會很差。
嚴格地說,這個問題是由梯度噪聲中的序列相關性和參數更新的不可交換性引起的。首先我們需要明白固定的數據集順序,意味着給定迭代步,對應此迭代步的訓練數據是固定的。 假如目標函數是J=f(w,b)J=f(w, b),使用梯度下降優化JJ。給定權重取值wbw、b和迭代步step的情況下,固定的數據集順序意味着固定的訓練樣本,也就意味着權值更新的方向是固定的,而無順序的數據集,意味着更新方向是隨機的。所以固定的數據集順序,嚴重限制了梯度優化方向的可選擇性,導致收斂點選擇空間嚴重變少,容易導致過擬合。所以模型是會記住數據路線的,所以shuffle很重要,一定shuffle。

2-pass-shuffle算法

我們假設一個數據集XmX^m包含樣本數目爲mm, 大小爲SXmS_{X^m}, 計算內存RAM大小爲SRAMS_{RAM}.
SX<SRAMS_X \lt S_{RAM}的時,我們完全可以使用訓練框架中的Dataset shuffle函數進行處理,如Fisher Yates Shuffle。但我們實際應用場景中,SXSRAMS_X \ggg S_{RAM}. 本節將針對這種業務場景進行討論。

# https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
def fisher_yates_shuffle(dataset=list()):
	size = len(dataset)
	for i in range(size-1):
		j = random.randint(i, size-1)
		dataset[i], dataset[j] = dataset[j], dataset[i]

分塊是一種很普遍的想法,但是如何分塊,以及分塊後如何隨機地寫回到文件中才是最終目標。而且要注意的是,數據集XmX^m的每一次訪問都存在大量的IO,將非常耗時。因此,設計隨機算法的過程中,IO也要考慮在內。

2-pass-shuffle算法過程中包括塊id的shuffle和塊內部的shuffle. Fisher Yates算法和 twice pass shuffle算法如下。

需要自己設置一個超參數MM, 直觀上需要滿足的條件:
MSXmSRAMM \ge \frac{S_{X^m}}{S_{RAM}}

python代碼模擬實現如下:

import os
import random

# https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
def fisher_yates_shuffle(dataset=list()):
	size = len(dataset)
	for i in range(size-1):
		j = random.randint(i, size-1)
		dataset[i], dataset[j] = dataset[j], dataset[i]

def twice_pass_shuffle(dataset, total_size, M):
	# first pass
	p = [[] for _ in range(M)]
	for i in range(total_size):
		j = random.randint(0, M-1)
		p[j].append(dataset[i])
	# second pass
	result = []
	for j in range(M):
		fisher_yates_shuffle(p[j])
		result.extend(p[j])
	return result

if __name__ == '__main__':
	l = [i for i in range(1,101)]
	print("befor shuffle:\n", l)
	result = twice_pass_shuffle(l, total_size=100, M=10)
	print("\nshuffle result:\n", result)

當我面對這個問題的時候,第一次並沒有給出這個答案,第二次纔給出接近這個算法的答案。
之前的算法分M塊,然後M塊之間兩兩洗牌,進行M(M1)2\frac{M(M-1)}{2}次。這個方法看上去好像可以,但是存在以下問題:

  • IO次數太多,性能應該不好
  • 兩塊之間怎麼洗牌的算法決定這shuffle的結果是否隨機,而且當時我並沒有給出比較好的洗牌策略。

2-pass-shuffle的其他問題處理

還有可能遇到的問題就是第一次pass過程中,每個分塊的數據並不是相等的,很有可能有那麼一兩塊的大小比SRAMS_{RAM}大,導致後面不能進行內存內shuffle. 這個問題在how to shuffle a big dataset3這篇文章中有一個解決方案。其實還有簡單粗暴的方案就是針對這個特殊的分塊進行單獨處理,再進行一次類似的2-pass-shuffle就是了。

如何訓練過程中,隨機從一個超大數據集合中獲取訓練數據4

The Dataset.shuffle() implementation is designed for data that could be shuffled in memory; we’re considering whether to add support for external-memory shuffles, but this is in the early stages. In case it works for you, here’s the usual approach we use when the data are too large to fit in memory:

  1. Randomly shuffle the entire data once using a MapReduce/Spark/Beam/etc. job to create a set of roughly equal-sized files (“shards”).
  2. In each epoch:
  • Randomly shuffle the list of shard filenames, using Dataset.list_files(...).shuffle(num_shards).
  • Use dataset.interleave(lambda filename: tf.data.TextLineDataset(filename), cycle_length=N) to mix together records from N different shards.
  • Use dataset.shuffle(B) to shuffle the resulting dataset. Setting B might require some experimentation, but you will probably want to set it to some value larger than the number of records in a single shard.

參考地址


  1. https://juejin.im/post/5c6b989bf265da2ddd4a5261 “數據集shuffle的重要性” ↩︎

  2. https://blog.janestreet.com/how-to-shuffle-a-big-dataset/#whyshuffle “why shuffle” ↩︎

  3. https://blog.janestreet.com/how-to-shuffle-a-big-dataset “how to shuffle a big dataset” ↩︎

  4. https://github.com/tensorflow/tensorflow/issues/14857 ↩︎

發佈了44 篇原創文章 · 獲贊 17 · 訪問量 5萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章