TorchAcc:基於 TorchXLA 的分佈式訓練框架

本文旨在探討阿里雲 TorchAcc,這是一個基於 PyTorch/XLA 的大模型分佈式訓練框架。

過去十年 AI 領域的顯著進步,關鍵在於訓練技術的革新和模型規模的快速攀升。儘管大模型展現了堪比人類的理解力,但其訓練卻對算力提出了極高的要求。唯有配備充足的計算資源,方能在海量數據上有效訓練大模型,確保其在有限時間內實現優質收斂。

圖片來源於 GTC 2024大會China AI Day 線上專場的演講《TorchAcc:基於TorchXLA的分佈式訓練框架》

根據上圖左側圖表顯示,過去五年,大模型規模的增長態勢尤爲突出,平均每兩年大小翻 15 倍;而對於 Transformer 爲代表的語言模型以及多模態模型而言,其規模膨脹速度更加驚人,每隔兩年以 750 倍劇增。對比之下,右側圖表揭示了一個明顯的矛盾點:不論是單個 GPU 的計算能力抑或是 GPU 顯存容量的發展速度,都無法跟上模型規模如此急劇的擴張步伐。這一現實狀況直接催生了對分佈式訓練的迫切需求。分佈式訓練不再侷限於以往單純的數據並行模式,而是在此基礎上,更加重視並採取模型並行策略,以彌補單個計算單元算力與存儲提升速度相對於模型規模增長的滯後性。

在分佈式訓練實踐中,開發人員普遍認同,構建模型並行的分佈式訓練系統相比數據並行更爲複雜。數據並行從分佈式角度來看,其邏輯相對直接和簡潔,因爲每個計算節點執行的任務本質上是對等且一致的。在這種情況下,只需在訓練過程末尾插入 AllReduce 步驟,將各個工作節點(worker)獨立計算出的梯度差異累加整合,然後求平均值,並將最終梯度結果廣播至所有參與工作的節點,用以同步更新全局模型參數。

這類簡單的分佈式訓練範式,確實呈現出類似單機計算的特點,主要涉及全局梯度同步的 AllReduce。然而步入大模型時代,由於模型規模過大,已無法容納於單個 GPU 之內,我們就必須採用模型並行策略,其開發難度也就陡然上升了。

原因是,模型並行需要根據模型的規模和結構來決定如何恰當地“分割”模型,即將其分割爲多個可以平衡計算負載的模塊。在不同的分割策略下,模型在各個節點上算子的算法實現方式會發生變化,同時,不同分割方法還會引起節點間通信原語的差異,需要精心選擇最優分割方案以及配套的通信原語。

在模型分割完成後,接下來的任務就是選用適合的通信原語,並精細地調度各個算子及其相關的通信操作,力求最大化計算與網絡通信的重疊(overlap),以充分發揮底層計算資源的效率。正是由於存在多種可能的分割選項與調度決策,尋求最優模型並行策略的複雜性明顯高於數據並行,對開發者的技巧和經驗提出了更高的要求。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

本文將圍繞四個核心方面展開。首個議題是如何在 TorchAcc 中實現多樣化的並行策略,涵蓋了常規的數據並行,以及當下備受關注的 FSDP(Fully Sharded Data Parallel,又稱 ZeRO (Zero Redundancy Optimizer)) 。此外,還包括了模型並行的各種形態,諸如算子並行,即 Tensor Parallelism,以及流水線並行(Pipeline Parallelism)等。

TorchAcc 的一大亮點在於其能夠自動探尋並有機整合各類並行策略,併爲用戶提供高度自動化的分佈式策略配置方案;與此同時,爲了滿足高級開發者的定製化需求,TorchAcc 還提供了半自動化的控制接口,允許用戶介入並調整自動探索並行策略的過程,從而在兼顧靈活性的同時,最大程度地提升訓練效率和資源利用率。

通過上述方式,TorchAcc 有效地助力算法開發者將精力集中於模型自身的結構設計、訓練方法的優化,以及追求模型收斂性能的提升上,而非花費精力在分佈式訓練的具體實現細節。TorchAcc 將智能化地協助開發者探尋並實現最佳的分佈式訓練方案,從而顯著提升計算資源利用效率和算法迭代效率。

其次,模型並行技術的必要性是因爲大模型尺寸超出單個 GPU 顯存容量的限制。顯存容量對於模型訓練至關重要,如何打破顯存瓶頸,對於提升分佈式訓練的整體效率來說至關重要。因此,TorchAcc 提供了一種顯存智能分配器,通過對顯存資源的精細化調度與地址分配策略,最大限度地提高模型並行訓練時的效率,確保模型能充分利用現有的顯存地址空間。

再者,隨着模型結構日益複雜,且規模不斷增大,用戶對計算資源的需求也在持續攀升,因此,進一步優化模型在訓練過程中的計算密集度及減少訪存開銷也非常關鍵。

最後,考慮到當前數據中心基礎設施的發展趨勢,大模型訓練對網絡條件的要求日漸嚴苛。現代數據中心服務器間的互聯帶寬已達到 TB 級別,以滿足大規模模型並行訓練對高速數據交換的需求。然而,模型並行所帶來的複雜通信模式與高頻次的數據交互亦會對整體訓練效率構成挑戰。因此,如何有效利用網絡帶寬,減少通信過程在迭代計算中佔據的時間比例,也就成了訓練效率提升的另一重要因素。

在具體實現上,TorchAcc 通過一系列技術手段,成功地將用戶在前端,無論是基於 PyTorch 還是 TensorFlow 構建的模型訓練過程轉化爲統一的中間表示層(Model IR)的 graph。其中,對於 TensorFlow 而言,因其自身就是一種計算圖模型,轉化過程相對直接,而對於 PyTorch,我們採用了符號式追蹤(symbolic tracing)以及 LazyTensor 等技術捕獲計算圖,進而轉化爲 IR Graph。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

基於中間表示層(IR Graph)的構建,TorchAcc 實施了一系列多元化的優化策略,涵蓋計算優化、存儲優化、通信優化以及分佈式策略優化,IR Graph 以各類組合並反覆執行這些優化的 Pass 後,最終得到一個最優的執行 Plan。然後交由底層 Backend 執行,以實現模型訓練性能的最大化提升。

通過這一整套方案,TorchAcc 在多個模型的分佈式訓練場景中表現出了顯著的性能優勢。部分模型的訓練過程得以實現高達 3 倍的性能提速,充分證明了 TorchAcc 在解決分佈式訓練難題上的高效性和實用性。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

這張圖片主要展示了 TorchAcc 的框架總體架構。TorchAcc 以 Pytorch/XLA 爲基礎,並 TorchAcc 依託於 OpenXLA,構建了一套大模型訓練加速框架。TorchAcc 在處理使用不同前端構建的模型時,會靈活採用適宜的圖捕獲技術,如 Symbolic Trace 和 LazyTensor,進而生成兩種不同層級的圖表示:FX Graph 和 HLO Graph。其中,FX Graph 位於較高抽象層次,而 HLO Graph 則更爲底層。

基於捕獲到的模型計算圖,TorchAcc 即可進一步展開了四類優化工作,即前文提及的計算優化、存儲優化、通信優化以及分佈式策略優化。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

在分佈式策略優化層面,TorchAcc 支持業界廣泛使用的各種並行策略,並能夠靈活地結合這些策略對給定模型進行有效的並行化處理。具體而言,對於數據並行 DP(Data Parallelism)、流水並行 PP(Pipeline Parallelism)以及 FSDP(Fully Sharded Data Parallel, 也稱爲 ZeRO)這三種分佈式策略,其實現和優化都是在 FX Graph 這一較高抽象層次上完成的。

選擇在 FX Graph 層面對並行策略進行操作的原因在於,這一層級所包含的關於計算圖結構和操作的信息已足夠豐富,足以支撐開發人員設計出適應不同並行策略的優化方案。相較於在更低層的 HLO Graph 上直接進行優化,由於 FX Graph 具有更高的抽象性和概括性,在這一層面上進行優化的成本通常較低,更容易實施高效且針對性強的分佈式策略調整。

以流水並行作爲例子,系統能夠自動檢測 FX Graph 層級上的不同階段,並確定合適的分割點,從而有效地將模型分割爲多個連續執行的階段,實現流水線並行化。在此過程中,我們可以利用 FX Graph 提供的詳細計算結構信息來進行智能分割。

至於 Tensor Parallelism (張量並行)和 Sequence Parallelism (序列並行)這兩種更爲複雜的並行策略,它們要求更爲細緻精確的信息以便進行決策。爲了實現這一點,系統需要對前向傳播和反向傳播的整個計算圖的執行計劃來進行分析。這時的工作主要在 HLO 這一低級別表示層面上進行。

通過利用 PyTorch/XLA 提供的 mark sharding 接口,系統能夠在模型參數上添加相應的拆分標記,然後將這些拆分信息傳遞給 OpenXLA 的 SPMD 優化 Pass,進而觸發計算圖的拆分、優化、推導和重寫過程,最終實現自動的 Tensor Parallelism 和 Sequence Parallelism 功能。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

在算子優化層面,TorchAcc 引入 FlashAttention 技術來提升 Attention 模塊的執行效率。首先,通過 XLA 的 custom call 功能,將 FlashAttention 的實現無縫地融入到了 OpenXLA 編譯器和運行時框架中。這意味着 FlashAttention 可以直接在 XLA 內核層級被執行,從而充分利用硬件加速能力。

在整合過程中,要處理好在 PyTorch 與 XLA 之間 Tensor 數據的傳遞問題,確保在兩個系統間轉換時的數據一致性與性能優化,同時,還要妥善處理 FlashAttention內部參數傳遞等細節問題,保證在並行計算和優化的過程中,這些關鍵參數能夠正確且高效地應用到計算中,進一步提升模型在執行注意力機制部分的運算速度和資源利用率。

爲了用戶能便捷地使用 FlashAttention 優化功能,我們提供了兩種接口,用戶也可以直接通過 Python 接口調用預先寫好的 FlashAttention 算子,第三種方法是用戶可以使用我們在 OpenXLA 上寫好的 Pattern Match Pass,該 Pass 能夠自動識別計算圖中的 Attention Block,並將這部分計算結構提取出來,替換爲FlashAttention 的 custom call。這樣設計的優勢在於,既能充分利用 XLA 原本就十分出色的 Kernel fusion 等算子優化功能,又能結合 FlashAttention 帶來的先進計算優化技術。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

在 Llama 2-7B 模型的性能測試中,我們能夠明顯觀察到上述計算優化帶來的效果。通過利用 XLA 自身的優化技術,尤其是 kernel fusion,我們將大量的訪存密集型算子做了有效合併,從而大幅減少其數量,在疊加 FlashAttention 後,優化性能進一步提升。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

在通信優化層面,我們主要完成了三項核心任務以提升分佈式訓練效率:首先,我們合併了一些零散的 collective 通訊算子,通過減少算子數量來降低通訊開銷和調度複雜度,其次,我們將合併的 collective 通訊算子移至獨立的 CUDA Stream 上執行,這樣一來,就能夠異步實現計算與通訊的重疊執行。最後,我們充分利用了 OpenXLA 的 Latency Hiding Scheduler 功能,對通訊算子的調度進行了精細優化,使其儘早啓動和執行,從而增強通訊與計算之間的重疊效果。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

通過在 Llama2 -7B 模型上進行的端到端多機性能測試,我們發現,應用了通訊優化策略後,在 128 張 GPU 卡上進行分佈式訓練,優化後的加速比從原來的 88 提升到了 116,通過 timeline 圖我們也可以直觀地看到,優化後的通訊算子更加有序,並且能夠更好地和計算重疊執行。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

本文最後一個章節紹 TorchAcc 的顯存優化功能,該功能通過優化計算圖中算子的執行順序以及 Tensor 在顯存中的地址分配,來降低顯存開銷。

如圖舉例說明,假設有一個包含四個算子 V0、V1、V2、V3 的計算圖,如果不控制算子執行順序,如左圖所示按照 V0-V1-V2-V3 的順序執行,若每個 Tensor 按照默認方式進行顯存地址申請,則可能出現如 B 圖左半部分所示的情況,即顯存容量不足以容納所有 Tensor,導致 out of memory 錯誤。

然而,如果我們能夠預判並精細管理內存分配,即在分配地址時預知後續執行的算子序列,即可如 B 圖右半部分所示進行更優的顯存佈局,使得整體計算可在有限顯存內順利完成。更進一步,通過精確控制執行順序,比如按照 V0-V2-V1-V3 的方式執行,可以進一步壓縮顯存需求至原始需求的 70% 左右。

這一理念是基於 XLA 中間表示層已有的 scheduler 和 buffer 管理機制,我們在此基礎上提出了更先進的顯存優化方法。目前業界存在多種優化顯存分配的方法,如啓發式算法、約束求解等,但這些方法往往難以兼顧時效性和高效性,在實際生產環境的集羣中應用時可能存在侷限性。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

在訓練場景中實現有效且高效的顯存優化是一項極具挑戰的任務,原因主要包括以下幾個方面:

  1. NP-Hard 問題本質:由於模型的規模、算子的種類繁多,以及算子間顯存分配的複雜性,顯存優化問題成爲一個典型的 NP-hard 問題,即找到全局最優解在計算上通常是不可行的。
  2. 算子執行靈活性:訓練過程中,前向傳播、反向傳播和權重更新等操作具有很高的靈活性,特別是在權重更新方面,梯度產生後隨時可以被用於權重更新,但不同的執行時機會影響顯存的申請和釋放,增加了優化難度。
  3. 顯存複用複雜性:在訓練過程中,前向和反向傳播可以通過複用顯存減少重新計算,但 Tensor 生命週期的多樣性和尺寸的變化使得顯存複用變得極爲複雜,這對啓發式算法等傳統優化手段構成了嚴峻挑戰。

爲了解決上述難題,我們採取了一種分治策略:

  1. Memory-aware Weight Update Scheduler:引入了顯存感知的權重更新調度器,它會根據梯度產生的時機、使用的優化器類型以及當前顯存資源狀況,選擇合適的權重更新時間點,避免即時更新加重顯存壓力,特別是對於複雜的優化器如 Adam,需考慮動量和其他變量的存儲。
  2. Graph 分割與局部優化:將大計算圖根據關鍵節點 (memory insensitive operator) 分割成多個內存無關性的子圖,子圖間執行順序固定,而子圖內部的執行順序則可以多樣化。通過這種方式,可以將複雜的全局線性規劃問題分解成多個局部問題,在子圖範圍內採用高效的優化方法,如線性規劃求解最優執行順序。

通過上述分治策略,最終我們能夠聚合這些子圖的求解結果,這也就是我們提出的 ROAM (Reorder Operators and Arrange Tensors Address to Reduce Memory Usage) 這一內存優化探索方式。

上述方法可以成功實現對顯存優化問題的高效處理。實驗結果顯示,與原生 PyTorch、啓發式算法以及 Facebook 近期基於整數線性規劃的優化方法等 baseline 相比,ROAM 分別節省了約 16%、13% 和 27% 的顯存開銷,且在優化時長和可擴展性方面表現出色,證實了這種方法的有效性。

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

圖片來源於 GTC 2024 大會 China AI Day 線上專場的演講《TorchAcc:基於 TorchXLA 的分佈式訓練框架》

從另一個維度衡量效果,我們考察了算法求解的時間開銷。實驗證明,在常見的深度學習場景中,我們的優化算法能夠在短短几分鐘內得出優化結果。從右圖所示對比中可以看出,相較於 Facebook 最近提出的 MODeL(一種基於線性規劃的優化方法),我們的方法在求解時間上實現了顯著的縮減。原因在於,MODeL 在處理大規模圖時並未對其進行有效分割,而我們的方法通過引入 memory-aware weight update scheduler 和子圖劃分策略,有效地降低了優化問題的空間複雜度,從而提高了求解效率。

綜上所述,TorchAcc 在顯存優化、計算優化、通信優化以及並行策略優化等方面均取得顯著成效,全方位提升了分佈式訓練的效率與性能。

演講人:林偉,阿里雲研究員,阿里雲人工智能平臺 PAI 技術負責人

原文鏈接

本文爲阿里雲原創內容,未經允許不得轉載。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章