深度解析MegEngine亞線性顯存優化技術

基於梯度檢查點的亞線性顯存優化方法[1]由於較高的計算/顯存性價比受到關注。MegEngine經過工程擴展和優化,發展出一套行之有效的加強版亞線性顯存優化技術,既可在計算存儲資源受限的條件下,輕鬆訓練更深的模型,又可使用更大batch size,進一步提升模型性能,穩定batchwise算子。

使用MegEngine訓練ResNet18/ResNet50,顯存佔用分別最高降低23%/40%;在更大的Bert模型上,降幅更是高達75%,而額外的計算開銷幾乎不變。

該技術已在MegEngine開源,歡迎大家上手使用:https://github.com/MegEngine

深度神經網絡訓練是一件複雜的事情,它體現爲模型的時間複雜度和空間複雜度,分別對應着計算和內存;而訓練時內存佔用問題是漂浮在深度學習社區上空的一塊烏雲,如何撥雲見日,最大降低神經網絡訓練的內存佔用,是一個繞不開的課題。

GPU顯卡等硬件爲深度學習提供了必需的算力,但硬件自身有限的存儲,限制了可訓練模型的尺寸,尤其是大型深度網絡,由此誕生出一系列相關技術,比如亞線性顯存優化、梯度累加、混合精度訓練、分佈式訓練,進行GPU顯存優化。

其中,亞線性顯存優化方法[1]由於較高的計算/顯存性價比備受關注;曠視基於此,經過工程擴展和優化,發展出加強版的MegEngine亞線性顯存優化技術,輕鬆把大模型甚至超大模型裝進顯存,也可以毫無壓力使用大batch訓練模型。

這裏將圍繞着深度學習框架MegEngine亞線性顯存優化技術的工程實現和實驗數據,從技術背景、原理、使用、展望等多個方面進行首次深入解讀。

背 景

在深度學習領域中,隨着訓練數據的增加,需要相應增加模型的尺寸和複雜度,進行模型「擴容」;而ResNet [2] 等技術的出現在算法層面掃清了訓練深度模型的障礙。不斷增加的數據和持續創新的算法給深度學習框架帶來了新挑戰,能否在模型訓練時有效利用有限的計算存儲資源,尤其是減少GPU顯存佔用,是評估深度學習框架性能的重要指標。

在計算存儲資源一定的情況下,深度學習框架有幾種降低顯存佔用的常用方法,其示例如下:

  • 通過合適的梯度定義,讓算子的梯度計算不再依賴於前向計算作爲輸入,從而in-place地完成算子的前向計算,比如Sigmoid、Relu等;

  • 在生命週期沒有重疊的算子之間共享顯存;

  • 通過額外的計算減少顯存佔用,比如利用梯度檢查點重新計算中間結果的亞線性顯存優化方法[1];

  • 通過額外的數據傳輸減少顯存佔用,比如把暫時不用的數據從GPU交換到CPU,需要時再從CPU交換回來。

上述顯存優化技術在MegEngine中皆有不同程度的實現,這裏重點討論基於梯度檢查點的亞線性顯存優化技術。

原 理

一個神經網絡模型所佔用的顯存空間大體分爲兩個方面:1)模型本身的參數,2)模型訓練臨時佔用的空間,包括參數的梯度、特徵圖等。其中最大佔比是 2)中以特徵圖形式存在的中間結果,比如,從示例[1]可知,根據實現的不同,從70%到90%以上的顯存用來存儲特徵圖。

這裏的訓練過程又可分爲前向計算,反向計算和優化三個方面,其中前向計算的中間結果最佔顯存,還有反向計算的梯度。第 1)方面模型自身的參數內存佔用最小。

MegEngine加強版亞線性顯存優化技術借鑑了[1]的方法,尤其適用於計算存儲資源受限的情況,比如一張英偉達2080Ti,只有11G的顯存;而更貴的Tesla V100,最大顯存也只有32G。

圖 1(a) 給出了卷積神經網絡的基本單元,它由Conv-BN-Relu組成。可以看到,反向計算梯度的過程依賴於前向計算獲取的中間結果,一個網絡需要保存的中間結果與其大小成正比,即顯存複雜度爲O(n)。

本質上,亞線性顯存優化方法是以時間換空間,以計算換顯存,如圖 1(b) 所示,它的算法原理如下:

  • 選取神經網絡中k個檢查點,從而把網絡分成k個block,需要注意的是,初始輸入也作爲一個檢查點;前向計算過程中只保存檢查點處的中間結果;

  • 反向計算梯度的過程中,首先從相應檢查點出發,重新計算單個block需要的中間結果,然後計算block內部各個block的梯度;不同block的中間結果計算共享顯存。這種方法有着明顯的優點,即大幅降低了模型的空間複雜度,同時缺點是增加了額外的計算:

  • 顯存佔用從O(n)變成O(n/k)+ O(k),O(n/k)代表計算單個節點需要的顯存,O(k)代表k個檢查點需要的顯存, 取k=sqrt(n),O(n/k)+ O(k)~O(sqrt(n)),可以看到顯存佔用從線性變成了亞線性;

  • 因爲在反向梯度的計算過程中需要從檢查點恢復中間結果,整體需要額外執行一次前向計算。

工 程

在[1]的基礎上,MegEngine結合自身實踐,做了工程擴展和優化,把亞線性顯存優化方法擴展至任意的計算圖,並結合其它常見的顯存優化方法,發展出一套行之有效的加強版亞線性顯存優化技術。

亞線性優化方法採用簡單的網格搜索(grid search)選擇檢查點,MegEngine在此基礎上增加遺傳算法,採用邊界移動、塊合併、塊分裂等策略,實現更細粒度的優化,進一步降低了顯存佔用。

如圖2所示,採用型號爲2080Ti的GPU訓練ResNet50,分別藉助基準、亞線性、亞線性+遺傳算法三種顯存優化策略,對比了可使用的最大batch size。僅使用亞線性優化,batch size從133增至211,是基準的1.6x;而使用亞線性+遺傳算法聯合優化,batch size進一步增至262,較基準提升2x。

圖2:三種顯存優化方法優化batch size的對比:ResNet50

通過選定同一模型、給定batch size,可以更好地觀察遺傳算法優化顯存佔用的情況。如圖3所示,隨着迭代次數的增加,遺傳算法逐漸收斂顯存佔用,並在第5次迭代之後達到一個較穩定的狀態。

圖3:遺傳算法收斂示意圖

此外,MegEngine亞線性優化技術通過工程改良,不再侷限於簡單的鏈狀結構和同質計算節點, 可用於任意的計算圖,計算節點也可異質,從而拓展了技術的適用場景;並可配合上述顯存優化方法,進一步降低模型的顯存佔用。

實 驗

MegEngine基於亞線性顯存技術開展了相關實驗,這裏固定batch size=64,在ResNet18和ResNet50兩個模型上,考察模型訓練時的顯存佔用和計算時間。

如圖4所示,相較於基準實現,使用MegEngine亞線性顯存技術訓練ResNet18時,顯存佔用降低32%, 計算時間增加24%;在較大的ReNet50上,顯存佔用降低40%,計算時間增加25%。同時經過理論分析可知,模型越大,亞線性顯存優化的效果越明顯,額外的計算時間則幾乎不變。

圖4:MegEngine亞線性優化技術實驗顯存/時間對比:ReNet18/ReNet50

在更大模型Bert上實驗數據表明,藉助MegEngine亞線性顯存技術,顯存佔用最高降低75%,而計算時間僅增加23%,這與理論分析相一致。有興趣的同學可前往MegEngine ModeHub試手更多模型實驗:https://megengine.org.cn/model-hub/

使 用

MegEngine官網提供了亞線性顯存優化技術的使用文檔。當你的GPU顯存有限,苦於無法訓練較深、較大的神經網絡模型,或者無法使用大batch進一步提升深度神經網絡的性能,抑或想要使batchwise算子更加穩定,那麼,MegEngine亞線性顯存優化技術正是你需要的解決方案。

上手MegEngine亞線性優化技術非常便捷,無需手動設定梯度檢查點,通過幾個簡單的參數,輕鬆控制遺傳算法的搜索策略。具體使用時,在MegEngine靜態圖接口中調用SublinearMemoryConfig設置trace的參數sublinear_memory_config,即可打開亞線性顯存優化:

from megengine.jit import trace, SublinearMemoryConfig
 
config = SublinearMemoryConfig()
 
@trace(symbolic=True, sublinear_memory_config=config)
def train_func(data, label, *, net, optimizer):
    ...

MegEngine在編譯計算圖和訓練模型時,雖有少量的額外時間開銷,但會顯著緩解顯存不足問題。下面以ResNet50爲例,說明MegEngine可有效突破顯存瓶頸,訓練batch size從100最高增至200:

import os
from multiprocessing import Process
 
 
def train_resnet_demo(batch_size, enable_sublinear, genetic_nr_iter=0):
    import megengine as mge
    import megengine.functional as F
    import megengine.hub as hub
    import megengine.optimizer as optim
    from megengine.jit import trace, SublinearMemoryConfig
    import numpy as np
 
    print(
        "Run with batch_size={}, enable_sublinear={}, genetic_nr_iter={}".format(
            batch_size, enable_sublinear, genetic_nr_iter
        )
    )
    # 使用GPU運行這個例子
    assert mge.is_cuda_available(), "Please run with GPU"
    try:
        # 我們從 megengine hub 中加載一個 resnet50 模型。
        resnet = hub.load("megengine/models", "resnet50")
 
        optimizer = optim.SGD(resnet.parameters(), lr=0.1,)
 
        config = None
        if enable_sublinear:
            config = SublinearMemoryConfig(genetic_nr_iter=genetic_nr_iter)
 
        @trace(symbolic=True, sublinear_memory_config=config)
        def train_func(data, label, *, net, optimizer):
            pred = net(data)
            loss = F.cross_entropy_with_softmax(pred, label)
            optimizer.backward(loss)
 
        resnet.train()
        for i in range(10):
            batch_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
            batch_label = np.random.randint(1000, size=(batch_size,)).astype(np.int32)
            optimizer.zero_grad()
            train_func(batch_data, batch_label, net=resnet, optimizer=optimizer)
            optimizer.step()
    except:
        print("Failed")
        return
 
    print("Sucess")
 
 
# 以下示例結果在2080Ti GPU運行得到,顯存容量爲 11 GB
 
# 不使用亞線性內存優化,允許的batch_size最大爲 100 左右
p = Process(target=train_resnet_demo, args=(100, False))
p.start()
p.join()
# 報錯顯存不足
p = Process(target=train_resnet_demo, args=(200, False))
p.start()
p.join()
 
# 使用亞線性內存優化,允許的batch_size最大爲 200 左右
p = Process(target=train_resnet_demo, args=(200, True, 20))
p.start()
p.join()

展 望

如上所述,MegEngine的亞線性顯存優化技術通過額外做一次前向計算,即可達到O(sqrt(n))的空間複雜度。如果允許做更多次的前向計算,對整個網絡遞歸地調用亞線性顯存算法,有望在時間複雜度爲O(n log n)的情況下,達到 O(log n)的空間複雜度。

更進一步,MegEngine還將探索亞線性顯存優化技術與數據並行/模型並行、混合精度訓練的組合使用問題,以期獲得更佳的集成效果。最後,在RNN以及GNN、Transformer等其他類型網絡上的使用問題,也是MegEngine未來的一個探索方向。

瞭解更多信息請查詢:

參考文獻

1. Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174.

2. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章