爲了增強人工智能的趣味性,我們開展了你畫我猜這一競賽。你畫我猜數據集中包括了40種生活中常見的類別如飛機、蘋果、籃球等。圖片信息以json格式存儲,在每個json中記錄了用戶每一筆簡筆畫對應的橫座標集合和縱座標集合。其中訓練集、驗證集和測試集劃分比例爲6:2:2.
以下爲給的json示例:
{"drawing": [[[18, 21, 15, 17, 23], [255, 185, 106, 97, 89]], [[17, 7, 3, 0, 9, 19, 29, 40, 41, 30, 28], [70, 60, 50, 26, 4, 0, 12, 39, 49, 75, 88]], [[28, 25, 22, 13, 11, 14, 15, 7, 11], [63, 10, 67, 53, 30, 17, 28, 63, 58]]]}
可以看到數據給的格式是用戶繪畫是的點的順序。由於本人對nlp不是很熟悉,因此將此問題轉化爲圖像分類問題。具體就是將這些點在(256, 256, 3)大小的圖像上面顯示。可視化結果如下:
數據預處理
- 將給的json數據格式轉化爲圖片格式,然後以圖片分類的思路進行分類。
數據增強
- 訓練的時候使用隨機翻轉、隨機尺度以及mixup來進行數據增強。
- mixup是一種非常規的數據增強方法,一個和數據無關的簡單數據增強原則,其以線性插值的方式來構建新的訓練樣本和標籤。最終對標籤的處理如下公式所示,這很簡單但對於增強策略來說又很不一般。
實現mixup數據增強很簡單,其實我個人認爲這就是一種抑制過擬合的策略,增加了一些擾動,從而提升了模型的泛化能力。
def get_batch(x, y, step, batch_size, alpha=0.2):
candidates_data, candidates_label = x, y
offset = (step * batch_size) % (candidates_data.shape[0] - batch_size)
train_features_batch = candidates_data[offset:(offset + batch_size)]
train_labels_batch = candidates_label[offset:(offset + batch_size)]
if alpha == 0:
return train_features_batch, train_labels_batch
if alpha > 0:
weight = np.random.beta(alpha, alpha, batch_size)
x_weight = weight.reshape(batch_size, 1, 1, 1)
y_weight = weight.reshape(batch_size, 1)
index = np.random.permutation(batch_size)
x1, x2 = train_features_batch, train_features_batch[index]
x = x1 * x_weight + x2 * (1 - x_weight)
y1, y2 = train_labels_batch, train_labels_batch[index]
y = y1 * y_weight + y2 * (1 - y_weight)
return x, y
而模型增前後的效果如下:
模型選擇
- 模型選擇上面採用先小模型驗證代碼流程,後用大模型漲分的策略。
- 首先使用resnet18模型可以快速驗證想法的正確性,將模型表達能力增強放在最後一步提升。使用resnet18可以達到84的準確率。
- 最終使用的senet154則可以將準確率提升到91.5左右。
模型優化
- 模型優化使用SGD算法,採用動量優化,權重衰減0.0001.具體代碼如下:
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, momentum=0.9, weight_decay=0.0001, nesterov=True)
- 學習率使用帶有warmup的的餘弦退火動態調整。其示意如下:
最後的總結
首先要使用一個簡單的基礎模型將流程跑通,得到一個baseline。之後在baseline的基礎上添加測試訓練技巧,這樣可以快速漲分!!
本文爲作者在FlyAI平臺發佈的原創內容,採用知識共享署名-非商業性使用-禁止演繹 4.0 國際許可協議進行許可,轉載請附上原文出處鏈接和本聲明。
本文鏈接地址(視頻講解直達):https://www.flyai.com/n/110736