[源碼解析] 快手八卦 --- 機器學習分佈式訓練新思路(1)

[源碼解析] 快手八卦 --- 機器學習分佈式訓練新思路(1)

0x00 摘要

“Bagua“ 是快手和蘇黎世理工(ETH Zürich)聯合開發的分佈式訓練框架。其專門針對分佈式的場景設計特定的優化算法,實現算法和系統層面的聯合優化,力圖極致化分佈式訓練的效率。其特點是:

  • 並行性能顯著提高;

  • 對網絡環境更魯棒;

  • “一鍵式”使用;

  • 分佈式通訊算法易拓展性;

  • 可用於工業級場景大規模使用;

  • 安全、故障易排查;

本文以:

爲基礎來分析學習。本文學習“bagua"總體設計思路和負載均衡數據加載器。

0x01 設計思路

以下摘錄於快手官方帖子 快手八卦!突破 TensorFlow、PyTorch 並行瓶頸的開源分佈式訓練框架來了! 和 ETH PPT,按照自己理解有調整。

1.1 如何通信

在數據並行之中,從單機單卡的訓練到多機多卡訓練的核心,是每個卡把自己的計算結果進行累加和傳播,所以一個關鍵點是兩個worker之間如何進行通信。

這個過程好比每個人把自己知道的信息傳遞給他人,然後又從其他人那裏獲取信息,最後完成全局的信息同步。如果把計算單元之間的信息同步類比爲人與人之間的信息同步,那麼社會實踐經驗告訴我們,“八卦”可能是消息傳遞最高效的模式。“八卦”消息傳播具有去中心化、異步通訊、信息壓縮的特點,這與 Bagua 裏面實現的通訊算法剛好一一呼應。

1.2 通信模式分類

針對通信模式,有如下分類。

1.2.1 系統架構

按照系統架構來區分,是參數服務器和Allreduce。

下圖是參數服務器和Allreduce範式的圖例。

  • 參數服務器架構中,模型可以被分割成分片(shard)並分佈到多個節點(我們稱這些節點爲 "參數服務器")。在訓練階段,worker定期從參數服務器獲取模型,利用計算單元(如GPU)進行前向和後向傳播,並將梯度推送給參數服務器,而參數服務器彙總梯度並更新參數。
  • Allreduce範式之中,所有worker都與他們的鄰居合作進行模型/梯度交換。現有的系統通常採用環形拓撲結構進行兩階段的交流:首先,範式將模型/梯度劃分爲n個塊(其中n爲節點數),並使用不同起點和終點的n個環來聚合n個塊;其次,位於不同節點的每個塊的聚合結果會在環內進行廣播。

1.2.2 同步角度

從通信同步角度看可以分爲同步或是異步(Synchronous or Asynchronous):

  • 同步模式中,在每一次迭代過程中,所有工作節點都需要進行通信,並且下一步迭代必須等待當前迭代的通信完成才能開始。
  • 反之,異步式分佈算法 則不需要等待時間:當某個節點完成計算後就可直接傳遞本地梯度,進行模型更新。

1.2.3 通信拓撲

從通信拓撲角度看可以分成中心化或是去中心化(Centralized or Decentralized):

  • 在中心化的通訊模式中,梯度或模型的同步過程需要所有的工作節點進行參與,因此,較高的網絡延時往往會導致訓練效率的降低。
  • 去中心化的通信模式往往可以有效的解決上述問題:在該模式下,工作節點可以被連接成特定的拓撲結構(例如環),在通信過程中,每一個工作節點只與和它相鄰的節點進行通信。

1.2.4 壓縮

從通信壓縮與否角度看,有完整精度模式或信息壓縮模式(Full-Precision or Low-Precision)兩種:

  • 完整精度模式會使用與本地模型相同的 32 位浮點數(float32)進行傳輸。
  • 另一方面,在通訊存在瓶頸的情況下,基於大量已有研究通過量化 (quantization) 或稀疏化 (sparsification) 等方法壓縮梯度,再用壓縮後的梯度更新參數。在很多場景下,可以達到和完整精度相同的精度,同時提升通訊效率。

1.3 挑戰

快手在實現之中,遇到了三個挑戰:

  • 理論基礎:通信模式需要有理論的支撐,需要嚴格在理論上證明通信是有效的,收斂的。
  • 系統設計:現有分佈式學習系統都無法滿足所有的新的通信模式,所以需要設計新的系統結構,才能利用這種算法帶來的優勢。
    • 參數服務器基本操作put/get,無法實現去中心化和誤差補償。
    • Allreduce是全局性的,無法實現去中心化或者異步模式。
  • 評測:需要在大規模真實場景下對各種算法進行評測。

1.4 Bagua 實現

1.4.1 分層

Bagua 具體分爲三層:

  • 算法層:在邏輯層基礎之上,實現了具體算法,比如某一個算法是去中心化,壓縮,異步的。
  • 邏輯通信層:在物理通信層基礎之上,實現了多種通信原語,比如去中心化,精度,同步等等,這些通信原語不是針對某一類算法特殊設計的,而對上層是統一的。
  • 物理通信層:在此層集成了一些常見通信庫,從而提供了基本的send,receive操作。

1.4.2 通信算法選項

針對通信模式分類,Bagua 相應將通信過程抽象成了如下的算法選項:

  • 中心化或是去中心化(Centralized or Decentralized)。

  • 同步或是異步(Synchronous or Asynchronous)。

  • 完整精度模式或信息壓縮模式(Full-Precision or Low-Precision)。

雖然爲了提升通訊效率,Bagua 沒有依照傳統的方式同步所有計算節點的結果,甚至每次同步的信息還有偏差,但是得益於最新理論上的進展,這幾種通訊策略以及他們的組合最終收斂解的正確性和效率仍然能得到充分保證,而且計算複雜度跟同步中心化和信息無損的方法相當,但是通訊效率更高。

img

Bagua 提供了一套詳盡的通信模式來支持用戶在上述模式中任意選擇組合,我們將這一分佈式訓練系統對於上述算法選項的支持情況總結在下表中:

img

從表格中不難看出,現有框架的優化只是針對較爲通用的算法(中心化同步完整精度),對於其他的算法組合,這些系統的支持非常有限。對於中心化同步進行信息壓縮,這些系統往往只能支持較爲簡單的 float32->float16 壓縮,相較而言,Bagua 則可以支持更爲複雜的 ByteGrad,QAdam 等算法。對於其他的算法組合,現有的框架通常無法支持,而 Bagua 則可以自由支持。

1.4.3 總體

BAGUA的核心是一個訓練算法,由開發者使用BAGUA提供的通信原語和抽象概念來實現。算法將最終用戶提供的神經網絡作爲輸入,併爲其配備一個特定於算法的通信功能。具體來說,算法的開發者會在執行的不同階段將這個通信功能註冊爲鉤子。

1.4.4 優化

然而,簡單地支持算法選項並不能直接在大規模集羣上帶來性能的提升。Bagua 的核心優勢在於,爲了追求極致化的性能,而實現算法和實現的聯合優化。具體來講,基於上述的通信層抽象,用戶既可以方便得選擇系統提供的各種算法組合從而獲得性能提升,又能靈活得實現新的分佈式 SGD 算法 —— Bagua 將自動爲這一算法實現提供系統層優化。這些系統優化包含:

  • 將通訊時間隱藏在計算時間中。
  • 參數分桶及其內存管理。
  • 分層化的通信實現。

想要強調的是,這些系統實現層面的優化是對於各種算法組合廣泛適用,而非侷限在某一特定的算法設置上。因此,所有的系統優化都可以被靈活的複用到各種算法實現中去,這在保證“端到端”的性能提升的同時,也爲開發新的分佈式算法提供了良好的平臺。

1.5 流程圖

我們使用官方號的圖例做一下總結

img

0x02 分析思路

通過官方文章我們可以發現對於分析學習來說有如下情況:

  • 通信方面的優化實現是八卦項目的一大特點。
  • 底層 Rust 語言筆者不熟悉。
  • 通盤研究整體代碼不現實。

因此我們決定以 中心化、異步通訊,分層化的通信實現 爲中心,再結合幾個特色實現來學習分析。本文學習負載均衡數據加載器。

0x03 Load Balanced Data Loader

在某些場景下當訓練數據中樣本的計算複雜度是不同的,比如在 NLP 和語音任務中每個樣本的長度就不同。這時,使用八卦的負載均衡數據加載器可以大大提高分佈式訓練吞吐量,在這種情況下,worker 的工作負載是相似的。我們接下來就從實例入手,看看如何實現數據加載的負載均衡

我們先看看負載均衡的需求,假如我們有兩個模型副本進行數據並行,有如下數據,假如這些數據代表的是數據複雜度(會影響計算時間)

[ 7,  1, 11,  5,  10,  2,  9, 4,  6,  0,  8,  3]

那麼第一個模型副本收到的數據爲:[7,11,10,9,6, 8]。第二個模型副本收到的數據爲:[1,5,2,4,0,3]。可以看出來兩個模型在每個batch收到數據的複雜度不同,會造成負載不均衡。

                         +  8                         + 3
                         |                            |
                         |  6                         | 0
                         |                            |
                         |  9                         | 4
                         |                            |
batch 3   +----------->  |  10                        | 2  <----------+  batch 3
                         |                            |
batch 2   +----------->  |  11                        | 5  <----------+  batch 2
                         |                            |
batch 1   +----------->  v  7                         v 1  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

理想狀態應該是兩個模型每個batch收到的數據複雜度都相仿,比如第一個模型收到 [1,3,5,7,9],第二個模型的數據是[2,4,6,8,10],在下圖的輸入下,可以看到每次batch數據複雜度相仿,從而達到負載均衡的效果:

                         +                            +
                         |  9                         | 10
                         |                            |
                         |  7                         | 8
                         |                            |
batch 3   +----------->  |  5                         | 6  <----------+  batch 3
                         |                            |
batch 2   +----------->  |  3                         | 4  <----------+  batch 2
                         |                            |
batch 1   +----------->  v  1                         v 2  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

3.1 使用

我們直接使用源碼中的例子修改學習一下。

import torch
from load_balancing_data_loader import LoadBalancingDistributedSampler
from torch.utils.data import TensorDataset, DataLoader

def test_load_balancing_distributed_batch_sampler():
    num_replicas = 2 # 分成兩個副本
    total_batch = 3 

    n = sum([i + 1 for i in range(total_batch)]) * num_replicas
    dataset = TensorDataset(torch.randn(n, 2), torch.randperm(n))

    sampler = LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=lambda x: x[1],
        num_replicas=num_replicas,
        rank=0,
        shuffle=True, # 需要shuffle
        random_level=0.5, # 加入隨機
    )

    dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler)

    cur_idx = 0
    for i, data in enumerate(dataloader):
        batch_size = data[0].shape[0]
        cur_idx += batch_size * num_replicas
        print(cur_idx)

test_load_balancing_distributed_batch_sampler()

因爲此處代碼十分繞,所以我們逐次解析。

3.2 生成數據集

首先是生成數據集部分。torch.randn(n, 2) 生成了隨機張量,torch.randperm(n) 生成了 n 的隨機排序。這裏假定 n 是12。

# 生成了數據集
n = sum([i + 1 for i in range(total_batch)]) * num_replicas
dataset = TensorDataset(torch.randn(n, 2), torch.randperm(n))

TensorDataset 類似 zip 命令,生成了tuple列表。

dataset = {TensorDataset: 12} 
 tensors = {tuple: 2} (
   
  0 = {Tensor: 12} tensor([[-1.5556,  0.6848],\n        [ 2.0811,  1.5011],\n        [ 0.7434, -0.4990],\n        [-0.2706,  1.7227],\n        [ 0.2179,  0.0622],\n        [-0.3014, -0.6435],\n        [-0.1773, -1.3405],\n        [-1.8212,  0.3702],\n        [-0.5526, -0.2077],\n        [-1.6543,  0.3109],\n        [ 0.3265,  0.5987],\n        [-1.5566,  0.2854]])
   
   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])

得出目前的TensorDataset如下 ,0 是實際數據,1 是數據複雜度,後續處理的目的就是按照數據複雜度對這些張量排序。我們可以設想下,最終排序應該就是一個複雜度均勻的排序結果。

+-----------------------------------------------------------------------------+
| TensorDataset                                                               |
|                                                                             |
|   0 = {Tensor: 12} tensor([[-1.5556,  0.6848],......                        |
|                                                                             |
|   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3]) |
|                                                                             |
+-----------------------------------------------------------------------------+

3.3 初始化

我們來到了 LoadBalancingDistributedSampler 的初始化。

def __init__(
    self,
    dataset: Dataset,
    complexity_fn: Callable[..., int],
    num_replicas: Optional[int] = None,
    rank: Optional[int] = None,
    shuffle: bool = True,
    seed: int = 0,
    drop_last: bool = False,
    random_level: float = 0,
) -> None:
    if num_replicas is None:
        num_replicas = dist.get_world_size()
    if rank is None:
        rank = dist.get_rank()

    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    self.drop_last = drop_last

    # If the dataset length is evenly divisible by # of replicas, then there
    # is no need to drop any data, since the dataset will be split equally.
    dataset_len = len(self.dataset)  # type: ignore
    if self.drop_last and dataset_len % self.num_replicas != 0:  # type: ignore
        # Split to nearest available length that is evenly divisible.
        # This is to ensure each rank receives the same amount of data when
        # using this Sampler.
        self.num_samples = math.ceil(
            # `type:ignore` is required because Dataset cannot provide a default __len__
            # see NOTE in pytorch/torch/utils/data/sampler.py
            (dataset_len - self.num_replicas)
            / self.num_replicas
        )
    else:
        self.num_samples = math.ceil(dataset_len / self.num_replicas)  # type: ignore
    self.total_size = self.num_samples * self.num_replicas
    self.shuffle = shuffle
    self.seed = seed

""" 
此時變量爲
self = {LoadBalancingDistributedSampler: 6} 
 dataset = {TensorDataset: 12} <torch.utils.data.dataset.TensorDataset object at 0x7ff7385aecf8>
 drop_last = {bool} False
 epoch = {int} 0
 num_replicas = {int} 2
 num_samples = {int} 6
 rank = {int} 0
 seed = {int} 0
 shuffle = {bool} True
 total_size = {int} 12 
"""       
    
    # 以下是與PyTorch原生的主要不同之處
    self.item_complexity_map = dict()
    for item_index in range(dataset_len):
        # 每一個item都有一個complexity
        self.item_complexity_map[item_index] = complexity_fn(
            self.dataset[item_index]
        )

"""
complexity_fn 是選取 tuple 的第二個元素作爲複雜度,我們回憶一下數據集的複雜度
{Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])

所以得到了複雜度map如下:
item_complexity_map = {dict: 12} {0: tensor(7), 1: tensor(8), 2: tensor(11), 3: tensor(4), 4: tensor(5), 5: tensor(2), 6: tensor(9), 7: tensor(10), 8: tensor(0), 9: tensor(6), 10: tensor(1), 11: tensor(3)}
 0 = {Tensor} tensor(7) # 第 0 個元素複雜度是 7
 1 = {Tensor} tensor(8) # 第 1 個元素複雜度是 8
 2 = {Tensor} tensor(11)
 3 = {Tensor} tensor(4)
 4 = {Tensor} tensor(5)
 5 = {Tensor} tensor(2)
 6 = {Tensor} tensor(9)
 7 = {Tensor} tensor(10)
 8 = {Tensor} tensor(0)
 9 = {Tensor} tensor(6)
 10 = {Tensor} tensor(1)
 11 = {Tensor} tensor(3)
"""        
        
    # 按照複雜度排序    
    self.ordered_item_complexity_map = OrderedDict(
        sorted(self.item_complexity_map.items(), key=lambda t: t[1])
    )
    
"""
排序之後如下:
ordered_item_complexity_map = {OrderedDict: 12} OrderedDict([(8, tensor(0)), (10, tensor(1)), (5, tensor(2)), (11, tensor(3)), (3, tensor(4)), (4, tensor(5)), (9, tensor(6)), (0, tensor(7)), (1, tensor(8)), (6, tensor(9)), (7, tensor(10)), (2, tensor(11))])
 8 = {Tensor} tensor(0) 第8個元素複雜度最低,是0
 10 = {Tensor} tensor(1) # 第10個元素複雜度次低,是1
 5 = {Tensor} tensor(2)
 11 = {Tensor} tensor(3)
 3 = {Tensor} tensor(4)
 4 = {Tensor} tensor(5)
 9 = {Tensor} tensor(6)
 0 = {Tensor} tensor(7)
 1 = {Tensor} tensor(8)
 6 = {Tensor} tensor(9)
 7 = {Tensor} tensor(10)
 2 = {Tensor} tensor(11)
"""    
    
    max_complexity = max(self.item_complexity_map.values()) # 11
    min_complexity = min(self.item_complexity_map.values()) # 0
    self.random_number = int((max_complexity - min_complexity) * random_level + 1) # 6
    
# random_number = {int} 1
  

拓展如下:

  • TensorDataset ,0 = ... 是實際數據,1 = ... 是數據複雜度,後續就是按照複雜度排序,而且所有排序或者打亂都沒有對原始數據進行移動,而是通過額外空間完成。
  • 初始化內部會對複雜度進行排序,
    • item_complexity_map 是得到每個元素的原始複雜度,比如 0: 7 表示第 0 個元素複雜度是 7。
    • ordered_item_complexity_map 就是排序之後的結構,其中 (8, 0) 表示第8個元素複雜度最低,是0,整個map是升序排列。

TensorDataset 的邏輯圖拓展如下,現在數據集 ordered_item_complexity_map 之中按照複雜度從低到高進行排序了。

+-----------------------------------------------------------------------------+
| TensorDataset                                                               |
|                                                                             |
|   0 = {Tensor: 12} tensor([[-1.5556,  0.6848],......                        |
|                                                                             |
|   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3]) |
|                                                                             |
+-------------------------------------------+---------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
| LoadBalancingDistributedSampler.__init__                                             |
|                                                                                      |
|                                                                                      |
|  item_complexity_map = {dict: 12} {0: 7, 1: 8, 2: 11, 3: 4, 4: 5, 5: 2,              |
|                                                                                      |
|                                    6: 9, 7: 10, 8: 0, 9: 6, 10: 1, 11: 3}            |
|                                           +                                          |
|                                           |                                          |
|                                           |  sorted                                  |
|                                           |                                          |
|                                           v                                          |
|  ordered_item_complexity_map = {OrderedDict: 12} [(8, 0), (10, 1), (5, 2), (11, 3),  |
|                                                                                      |
|                    (3, 4), (4, 5), (9, 6), (0, 7), (1, 8), (6, 9), (7, 10), (2, 11)] |
|                                                                                      |
+--------------------------------------------------------------------------------------+

3.4 使用

示例代碼之中接下來是使用數據:

dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler)

cur_idx = 0
for i, data in enumerate(dataloader):
    batch_size = data[0].shape[0]
    cur_idx += batch_size * num_replicas
    print(cur_idx)

3.4.1 獲取數據

我們接下來看看如何獲取數據,就是如何從loader拿到sample。

  • 首先會調用 shuffle_chunks 來打亂數據。
  • 然後得到自己rank對應的index。
def __iter__(self) -> Iterator:
    index_chunks, chunk_indices = self.shuffle_chunks() # 打亂數據
    # subsample
    indices = [index_chunks[i][self.rank] for i in chunk_indices] # 用 rank來提取數據

"""
得到數據如下:
chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3] 把 index_chunks 順序打亂,chunk_indices 是打亂之後的結果
index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] 均勻分成兩組
indices = {list: 6} [8, 7, 6, 5, 4, 0] 得到自己rank對應的index
"""    
    return iter(indices)

3.4.2 shuffle

我們看看shuffle 具體代碼如下,這裏最終要分成 6 = 12(數據數目) / 2( num_replicas ) 組數據。

def shuffle_chunks(self):
    def chunks_wrap_padding(lst, n):
        """Yield successive n-sized chunks from lst."""
        num_chunks = max(1, self.num_samples)
        num_elements = num_chunks * n
        current_lst = []
        for i in range(num_elements):
            current_lst.append(lst[i % len(lst)])
            if len(current_lst) == n:
                yield current_lst
                current_lst = []

    if self.shuffle: # 需要再次打亂
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        if self.random_number > 0:
            # 這裏的打亂機制很巧妙,就是隨機再生成複雜度,然後加到原先複雜度map上
            item_complexity_map = self.item_complexity_map.copy() # 原來map做個拷貝
            complexity_random_ints = torch.randint( # 新生成了一些複雜度變化值
                self.random_number, (len(item_complexity_map),), generator=g
            ).tolist()
"""
complexity_random_ints = {list: 12} [2, 3, 5, 0, 1, 3, 1, 1, 1, 3, 5, 2]

item_complexity_map = {dict: 12} {0: tensor(7), 1: tensor(8), 2: tensor(11), 3: tensor(4), 4: tensor(5), 5: tensor(2), 6: tensor(9), 7: tensor(10), 8: tensor(0), 9: tensor(6), 10: tensor(1), 11: tensor(3)}
"""
            
            # 原來複雜度map + 複雜度變化值
            for k, v in zip(item_complexity_map, complexity_random_ints):
                item_complexity_map[k] += v
"""
生成新的複雜度
item_complexity_map = {0: tensor(9), 1: tensor(11), 2: tensor(16), 3: tensor(4), 4: tensor(6), 5: tensor(5), 6: tensor(10), 7: tensor(11), 8: tensor(1), 9: tensor(9), 10: tensor(6), 11: tensor(5)}
"""
        
            # 再次對新複雜度排序
            ordered_item_complexity_map = OrderedDict(
                sorted(item_complexity_map.items(), key=lambda t: t[1])
            )

"""
ordered_item_complexity_map = {OrderedDict: 12} OrderedDict([(8, tensor(1)), (3, tensor(4)), (5, tensor(5)), (11, tensor(5)), (4, tensor(6)), (10, tensor(6)), (0, tensor(9)), (9, tensor(9)), (6, tensor(10)), (1, tensor(11)), (7, tensor(11)), (2, tensor(16))])
 8 = {Tensor} tensor(1)
 3 = {Tensor} tensor(4)
 5 = {Tensor} tensor(5)
 11 = {Tensor} tensor(5)
 4 = {Tensor} tensor(6)
 10 = {Tensor} tensor(6)
 0 = {Tensor} tensor(9)
 9 = {Tensor} tensor(9)
 6 = {Tensor} tensor(10)
 1 = {Tensor} tensor(11)
 7 = {Tensor} tensor(11)
 2 = {Tensor} tensor(16)
 __len__ = {int} 12
"""
        else:
            ordered_item_complexity_map = self.ordered_item_complexity_map

        index_chunks = list( # 按照 num_replicas 進行分片
            chunks_wrap_padding(
                list(ordered_item_complexity_map.keys()), self.num_replicas
            )
        )

"""
被均勻分配成兩組,每組中兩個元素的複雜度接近
index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]
 0 = {list: 2} [8, 3]
 1 = {list: 2} [5, 11]
 2 = {list: 2} [4, 10]
 3 = {list: 2} [0, 9]
 4 = {list: 2} [6, 1]
 5 = {list: 2} [7, 2]
 __len__ = {int} 6
"""        
        # 再次打亂 index_chunks
        chunk_indices = torch.randperm(len(index_chunks), generator=g).tolist()  # type: ignore
    
"""
chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]
"""    
    
    else:
        index_chunks = list(
            chunks_wrap_padding(
                list(self.ordered_item_complexity_map.keys()), self.num_replicas
            )
        )
        chunk_indices = list(range(len(index_chunks)))  # type: ignore

    if not self.drop_last:
        # add extra samples to make it evenly divisible
        padding_size = self.num_samples - len(chunk_indices)
        if padding_size <= len(chunk_indices):
            chunk_indices += chunk_indices[:padding_size]
        else:
            chunk_indices += (
                chunk_indices * math.ceil(padding_size / len(chunk_indices))
            )[:padding_size]
    else:
        # remove tail of data to make it evenly divisible.
        chunk_indices = chunk_indices[: self.num_samples]
    assert len(chunk_indices) == self.num_samples
    return index_chunks, chunk_indices

總體拓展如下:

  • TensorDataset ,0 = ... 是實際數據,1 = ... 是數據複雜度,後續就是按照複雜度排序:
  • LoadBalancingDistributedSampler.__init__ 初始化內部會對複雜度進行排序,
    • item_complexity_map 是得到每個元素的複雜度,比如 0: 7 表示第 0 個元素複雜度是 7。
    • ordered_item_complexity_map 就是按照複雜度排序之後的結構,其中 (8, 0) 表示第8個元素複雜度最低,是0。
  • shuffle_chunks 內部繼續處理,這裏的打亂機制很巧妙,沒有移動數據,而是隨機再生成複雜度,然後加到原先複雜度map上,這樣就打亂了
    • complexity_random_ints 新生成了一些複雜度變化值。
    • item_complexity_map 把原來map做個拷貝。
    • item_complexity_map 繼續操作,即:新複雜度 = 原來複雜度map + 複雜度變化值。
    • ordered_item_complexity_map 對新複雜度排序。
    • 對 ordered_item_complexity_map 按照 num_replicas 進行分片,得到 index_chunks,ordered_item_complexity_map 被均勻分配成六組,每組中兩個元素的複雜度接近
    • 然後再次打亂 index_chunks,得到 chunk_indices,就是爲了把index順序打亂而已。
+--------------------------------------------------------------------------------------+
| TensorDataset                                                                        |
|                                                                                      |
|   0 = {Tensor: 12} tensor([[-1.5556,  0.6848],......                                 |
|                                                                                      |
|   1 = {Tensor: 12} tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])          |
|                                                                                      |
+-------------------------------------------+------------------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
| LoadBalancingDistributedSampler.__init__                                             |
|                                                                                      |
|                                                                                      |
|  item_complexity_map = {dict: 12} {0: 7, 1: 8, 2: 11, 3: 4, 4: 5, 5: 2,              |
|                                                                                      |
|                                    6: 9, 7: 10, 8: 0, 9: 6, 10: 1, 11: 3}            |
|                                           +                                          |
|                                           |                                          |
|                                           |  sorted                                  |
|                                           |                                          |
|                                           v                                          |
|  ordered_item_complexity_map = {OrderedDict: 12} [(8, 0), (10, 1), (5, 2), (11, 3),  |
|                                                                                      |
|                    (3, 4), (4, 5), (9, 6), (0, 7), (1, 8), (6, 9), (7, 10), (2, 11)] |
|                                                                                      |
+-------------------------------------------+------------------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
| __iter__                                                                             |
|                                                                                      |
+-------------------------------------------+------------------------------------------+
                                            |
                                            |
                                            v
+-------------------------------------------+------------------------------------------+
|                                                                                      |
| shuffle_chunks()                                                                     |
|                                                                                      |
|                                                                                      |
|   complexity_random_ints = {list: 12} [2, 3, 5, 0, 1, 3, 1, 1, 1, 3, 5, 2]           |
|                                                                                      |
|                                                                                      |
|                                                                                      |
|   item_complexity_map = {0: 9, 1: 11, 2: 16, 3: 4, 4: 6, 5: 5, 6: 10, 7: 11, 8: 1,   |
|                                                                                      |
|                                                                9: 9, 10: 6, 11: 5}   |
|                                                                                      |
|                                                                                      |
|                                                                                      |
|   ordered_item_complexity_map = {OrderedDict: 12} [(8, 1), (3, 4), (5, 5), (11, 5),  |
|                                                                                      |
|                                                    (4, 6), (10, 6), (0, 9), (9, 9),  |
|                                                                                      |
|                                                (6, 10), (1, 11), (7, 11), (2, 16)])  |
|                                                                                      |
|                                           +                                          |
|                                           |                                          |
|                                           |                                          |
|                                           v                                          |
|                                                                                      |
|     index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]      |
|                                                                                      |
|                                                                                      |
|     chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]                                     |
|                                                                                      |
|                                                                                      |
+--------------------------------------------------------------------------------------+

3.4.3 梳理

shuffle 細化

看到這裏讀者可能有點暈,所以我們需要具體梳理一下。

ordered_item_complexity_map 就是按照複雜度排序之後的結構,其中 (8, 0) 表示第8個元素複雜度最低,是0。ordered_item_complexity_map 擁有 12個元素,按照兩個副本分配,所以 ordered_item_complexity_map 應該被均勻分配成六組,每組中兩個元素的複雜度接近

index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] 是最終的結果,這裏[8, 3]是一組,複雜度接近,[5, 11]是一組,複雜度接近,比如結合 ordered_item_complexity_map 來看:

  • (8, 1), (3, 4) 就是說,第 8 個元素複雜度是1,第3個元素複雜度是4,所以 index 8,index 3 被分成一組。

  • (5, 5), (11, 5) 就是說,第 5 個元素複雜度是5,第11個元素複雜度是5,所以 index 5,index 11 被分成一組。

shuffle_chunks 的演示如下:

+--------------------------------------------------------------------------------------+
| shuffle_chunks                                                                       |
|                                                                                      |
|                                                                                      |
|                                      +--------------+     +---------------+          |
|   ordered_item_complexity_map = [ +--+(8, 1), (3, 4)|   +-+(5, 5), (11, 5)|          |
|                                   |  +--------------+   | +---------------+          |
|                                   |                     |                            |
|                                   |  +---------------+  | +---------------+          |
|                              +-------+(4, 6), (10, 6)|  | |(0, 9), (9, 9) +-------+  |
|                              |    |  +---------------+  | +---------------+       |  |
|                              |    |                     |                         |  |
|                              |    |  +----------------+ | +----------------+      |  |
|                              |    |  |(6, 10), (1, 11)| | |(7, 11), (2, 16)|  ]   |  |
|                              |    |  +-------------+--+ | +----------+-----+      |  |
|                              |    |                |    |            |            |  |
|                              +------------------+  +-------------+   +----+       |  |
|                                   |             |       |        |        |       |  |
|                                   |        +------------+   +---------------------+  |
|                                   |        |    |           |    |        |          |
|                                   v        v    v           v    v        v          |
|     index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]      |
|                                                                                      |
|                                      +                                               |
|                                      |                                               |
|                                      |                                               |
|                                      v                                               |
|                                                                                      |
|     chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]                                     |
|                                                                                      |
+--------------------------------------------------------------------------------------+
二次打亂

我們結合原始數據再來分析,先回頭看看 獲取數據。

def __iter__(self) -> Iterator:
    index_chunks, chunk_indices = self.shuffle_chunks()
    # subsample
    indices = [index_chunks[i][self.rank] for i in chunk_indices]

"""
得到數據如下:
chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3] 把 index_chunks 順序打亂
index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] 均勻分成兩組
indices = {list: 6} [8, 7, 6, 5, 4, 0] 得到自己rank對應的index
"""    
    
    assert len(indices) == self.num_samples

    return iter(indices)

原始數據爲 :[ 7, 8, 11, 4, 5, 2, 9, 10, 0, 6, 1, 3],後續會按照原始數據的index 來排序

按照複雜度排序/shuffle之後,rank 0 就是 [8, 5, 4, 0, 6, 7]。rank 1 就是 [3, 11, 10, 9, 1, 2]。

rank 0 和 rank 1 的batch 是 [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]] ,兩兩一組。

但是,還需要再次打亂順序,因爲目前這個batch是按照複雜度從小到大排序,這樣會影響訓練效果,所以需要打亂這個順序。所以就按照 chunk_indices [0, 5, 4, 1, 2, 3] 這個順序來打亂。

打亂之後的順序是:[[8, 3], [7, 2], [6, 1], [5, 11], [4, 10], [0, 9]]。

  • 假如本worker 是 rank 0,則會獲取 index_chunks 這六組數據中和自己對應的,得到 [8, 7, 6, 5, 4, 0]。

  • 假如本worker rank 1,則是 [3,2,1,11,10,9]。注意,這些還都是原始數據的index。

具體演示如下圖(這裏只給出 rank 0 的效果):

+--------------------------------------------------------------------------------------+
| shuffle_chunks                                                                       |
|                                                                                      |
|                                      +--------------+     +---------------+          |
|   ordered_item_complexity_map = [ +--+(8, 1), (3, 4)|   +-+(5, 5), (11, 5)|          |
|                                   |  +--------------+   | +---------------+          |
|                                   |                     |                            |
|                                   |  +---------------+  | +---------------+          |
|                               +------+(4, 6), (10, 6)|  | |(0, 9), (9, 9) +------+   |
|                               |   |  +---------------+  | +---------------+      |   |
|                               |   |                     |                        |   |
|                               |   |  +----------------+ | +----------------+     |   |
|                               |   |  |(6, 10), (1, 11)| | |(7, 11), (2, 16)|  ]  |   |
|                               |   |  +-------------+--+ | +----------+-----+     |   |
|                               |   |                |    |            |           |   |
|                               +-----------------+  +-------------+   +----+      |   |
|                                   |             |       |        |        |      |   |
|                                   |        +------------+   +--------------------+   |
|                                   |        |    |           |    |        |          |
|                                   v        v    v           v    v        v          |
|     index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]      |
|                                      +                                               |
|                                      |                                               |
|                                      |                                               |
|                                      v                                               |
|     chunk_indices = {list: 6} [0, 5, 4, 1, 2, 3]                                     |
|                                                                                      |
+--------------------------------------+-----------------------------------------------+
                                       |
                                       |
                                       v

+--------------------------------------------------------------------------------------+
| __iter__                                                                             |
|                                    0       1        2        3       4       5       |
|        index_chunks = {list: 6} [[8, 3], [5, 11], [4, 10], [0, 9], [6, 1], [7, 2]]   |
|                                   +       +        +        +       +       +        |
|                                   |       |        |        |       |       |        |
|                                   +----+  +-----+  |  +-----+       |       |        |
|                                        |        |  |  |             |       |        |
|                                        |        |  |  |             |       |        |
|                                        v        v  v  v             |       |        |
|                   indices = {list: 6} [8, 7, 6, 5, 4, 0]            |       |        |
|                                           ^  ^                      |       |        |
|                                           |  |                      |       |        |
|                                           |  +----------------------+       |        |
|                                           |                                 |        |
|                                           +---------------------------------+        |
|                                                                                      |
+--------------------------------------------------------------------------------------+
最終效果

我們看看最終效果是什麼:

  • 原始數據爲 :[ 7, 8, 11, 4, 5, 2, 9, 10, 0, 6, 1, 3]。

  • 最終shuffle/二次打亂之後的數據爲:rank 0 是 [8, 7, 6, 5, 4, 0],rank 1 則是 [3,2,1,11,10,9]。這裏數值是原始數據的index。

  • 最終結果是:

    • batch如下,rank 0 和 rank 1 的batch 是 [[8, 3], [7, 2], [6, 1], [5, 11], [4, 10], [0, 9]],兩兩一組。這裏數值是原始數據的index。
    • rank 0 的數據是 [0, 10, 9, 2, 5, 7],rank 1的數據是[4, 11, 7, 3, 1, 6],這裏數值就是原始數據的數值了。

具體如下圖,可以看到,因爲過程之中引入了隨機值,所以不是理想均衡狀態,但已經比較均衡了:

                         + 7                          + 6
                         |                            |
                         | 5                          | 1
                         |                            |
                         | 2                          | 3
                         |                            |
batch 3   +----------->  | 9                          | 7  <----------+  batch 3
                         |                            |
batch 2   +----------->  | 10                         | 11 <----------+  batch 2
                         |                            |
batch 1   +----------->  v 0                          v 4  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

0xFF 參考

PyTorch internals

快手八卦!突破 TensorFlow、PyTorch 並行瓶頸的開源分佈式訓練框架來了!

https://arxiv.org/pdf/2107.01499.pdf

[1] Dean, Jeffrey, Greg S. Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Quoc V. Le, Mark Z. Mao et al. “Large scale distributed deep networks.” (2012).

[2] Zhengyuan Zhou, Panayotis Mertikopoulos, Nicholas Bambos, Peter Glynn, Yinyu Ye, Li-Jia Li, and Li Fei-Fei. 2018. Distributed asynchronous optimization with unbounded delays: How slow can you go?. In International Conference on Machine Learning. PMLR, 5970–5979.

[3] DanAlistarh, DemjanGrubic, JerryLi, RyotaTomioka, and MilanVojnovic. 2016. QSGD: Communication-efficient SGD via gradient quantization and encoding. arXiv preprint arXiv:1610.02132 (2016).

[4] Dan Alistarh, Torsten Hoefler, Mikael Johansson, Sarit Khirirat, Nikola Konstanti- nov, and Cédric Renggli. 2018. The convergence of sparsified gradient methods. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 5977–5987.

[5] Anastasia Koloskova, Sebastian Stich, and Martin Jaggi. 2019. Decentralized stochastic optimization and gossip algorithms with compressed communication. In International Conference on Machine Learning. PMLR, 3478–3487.

[6] Xiangru Lian, Ce Zhang, Huan Zhang, Cho-Jui Hsieh, Wei Zhang, and Ji Liu. 2017. Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent. In Proceedings of the 31st International Conference on Neural Information Processing Systems. 5336–5346.

[7] Christopher De Sa, Matthew Feldman, Christopher Ré, and Kunle Olukotun. 2017. Understanding and optimizing asynchronous low-precision stochastic gradient descent. In Proceedings of the 44th Annual International Symposium on Computer Architecture. 561–574.

[8] Xiangru Lian, Wei Zhang, Ce Zhang, and Ji Liu. 2018. Asynchronous decentral- ized parallel stochastic gradient descent. In International Conference on Machine Learning. PMLR, 3043–3052.

[9] Hanlin Tang, Shaoduo Gan, Ce Zhang, Tong Zhang, and Ji Liu. 2018. Com- munication compression for decentralized training. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 7663–7673.

[10] Ji Liu, Ce Zhang, et al. 2020. Distributed Learning Systems with First-Order Methods. Foundations and Trends® in Databases 9, 1 (2020), 1–100.![]

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章