[源碼解析] PyTorch 分佈式之 ZeroRedundancyOptimizer

[源碼解析] PyTorch 分佈式之 ZeroRedundancyOptimizer

0x00 摘要

PyTorch Zero Redundancy Optimizer 是一類旨在解決數據並行訓練和模型並行訓練之間權衡問題的算法。Zero Redundacy Optimizer 的思想來源於微軟的ZeRO,具體實現是基於 Fairscale 的OSS。

Fairscale 實現了 ZeRO 的三個階段的算法,Fairscale 是 Facebook AI Research (FAIR) 開源的項目,個人理解爲是Facebook 大規模深度學習分佈式訓練的一個試驗田,如果其中某個模塊發展成熟,就會合併到 PyTorch 之中。

OSS 就是Fairscale實現的 ZeRO-1,其實現了優化器狀態分片(參見下圖紅色方框)。PyTorch 則是基於 FairScale 的 OSS 實現了 ZeroRedundancyOptimizer。

注:本文基於 PyTorch 1.9.0。

0x01 歷史

1.1 Github說明

ZeroRedundancyOptimizer 是在 https://github.com/pytorch/pytorch/pull/46750 引入的,我們看看其說明。

ZeroRedundancyOptimizer: an implementation of a standalone sharded optimizer wrapper #46750

Implement the first stage of ZeRO, sharding of the optimizer state, as described in this blog post and this paper. This implementation is completely independent from the DeepSpeed framework, and aims at providing ZeRO-compliant building blocks within the PyTorch scheme of things.

This works by:

  • acting as a wrapper to a pytorch optimizer. ZeROptimizer does not optimize anything by itself, it only shards optimizers for distributed jobs
  • each rank distributes parameters according to a given partitioning scheme (could be updated), and owns the update of a given shard only
  • the .step() is called on each rank as expected, the fact that the optimizer actually works on a shard of the model is not visible from the outside
  • when the update is completed, each rank broadcasts the updated model shard to all the other ranks

This can be used with DDP, although some communications are wasted in that case (gradients are all-reduced to all ranks). This implementation was initially developed in Fairscale, and can also be used with an optimized DDP which only reduces to the relevant ranks. More context on ZeRO and PyTorch can be found in this RFC

The API with respect to loading and saving the state is a known pain point and should probably be discussed an updated. Other possible follow ups include integrating more closely to a modularized DDP, making the checkpoints partition-agnostic, exposing a gradient clipping option and making sure that mixed precision states are properly handled.

original authors include @msbaines, @min-xu-ai and myself(blefaudeux )

1.2 解析

因此,我們可以知道如下信息:

  • Zero Redundacy Optimizer 的思想來源於微軟的ZeRO。
  • Fairscale 實現了 ZeRO 的三個階段的算法,Fairscale 是 Facebook AI Research (FAIR) 開源的項目,個人理解爲是Facebook 大規模深度學習分佈式訓練的一個試驗田,如果某個模塊發展成熟,就會合併到 PyTorch 之中。
  • OSS 是Fairscale實現的 ZeRO-1,其實現了優化器狀態分片。
  • PyTorch 就是基於 FairScale 的 OSS 實現了 ZeroRedundancyOptimizer。

我們有必要具體看一下。

0x02 背景知識

2.1 ZeRO

ZeRO(零冗餘優化器,Zero Redundacy Optimizer)是微軟開源的DeepSpeed(一種優化大規模訓練的框架)的一部分。ZeRO 是一種深度學習模型的內存優化方法,其尋求模型並行和數據並行的一箇中間點,以最大化模型的可擴展性。

ZeRO的優化涉及了深度學習模型內存使用的多個方面,包括激活內存、碎片內存和模型狀態內存。

  • 模型狀態內存(Model State Memory): 深度學習模型的狀態可歸爲:優化器狀態、梯度和參數這三個基本過程。
  • 激活內存(Activation Memory):在優化了模型狀態內存之後,人們發現激活函數也會導致瓶頸。激活函數計算位於前向傳播之中,用於支持後向傳播。
  • 碎片內存(Fragmented Memory):深度學習模型的低效有時是由於內存碎片所導致的。在模型之中,每個張量的生命週期不同,由於不同張量壽命的變化而會導致一些內存碎片。由於這些碎片的存在,會導致即使有足夠的可用內存,也會因爲缺少連續內存而使得內存分配失敗。ZeRO 根據張量的不同壽命主動管理內存,防止內存碎片。

比如優化可以參見下圖:

圖片來源 https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/。

2.2 Fairscale 的 ZeRO 實現

我們接下來看看 Fairscale 的使用指南。

這其實是分佈式/大規模機器學習方案的一個梳理,從中可以看到,其依據 ZeRO <https://arxiv.org/pdf/1910.02054.pdf>實現了三種不同的算法,分別對應了 ZeRO的三個階段:

  • Optimizer State Sharding (OSS) 實現了 Optimizer 分片,優化了分區優化器狀態的內存使用。
  • Sharded Data Parallel (SDP) 負責 Optimizer + Gradient State Sharding。
  • Fully Sharded Data Parallel (FSDP) 實現了 Optimizer + Gradient + Horizontal Model Sharding。

2.3 Optimizer State Sharding (OSS)

因爲OSS是ZeroRedundancyOptimizer的源頭,所以我們先看看其思路。OSS實現了與優化器內存相關的優化。像Adam這樣的優化器通常需要保持動量、方差。即便可以使用FP16精度的參數和梯度進行訓練,參數和梯度也需要保存爲FP32精度。當每個rank更新完整模型時,這意味着相當大一部分內存被優化器狀態的冗餘表示所佔用。爲了克服這種冗餘,優化器狀態分片需要將模型優化步驟劃分在不同的rank之間,以便每個rank只負責更新模型的對應分片。這反過來又確保優化器狀態在每個rank上小得多,並且它不包含跨rank的冗餘信息。

2.3.1 訓練流程

OSS 訓練流程可以從DDP的執行流程做如下修改:

  1. wrapped optimizer根據參數大小(而不是使用順序)以貪心算法方式來對優化器狀態進行分片。這是爲了確保每個rank具有幾乎相同大小的優化器內存。

  2. 訓練過程類似於PyTorch的分佈式數據並行(DDP)的過程。在每個rank上先完成前向傳播,然後是向後傳播。在後向傳播過程中,使用allreduce同步梯度

  3. 每個rank只更新它負責的優化器狀態參數,然後丟棄其餘的優化器參數

  4. 更新後,將執行broadcast或allgather操作,以確保所有rank都收到最新更新的參數值。

具體參見下圖。

2.3.2 最佳實踐

幾條最佳實踐如下:

  • OSS公開了一個broadcast_fp16 flag,您可能應該在多節點作業中使用它。在單節點實驗中通常不需要這樣做。
  • 如果您的模型在大小方面極不平衡(例如,存在一個巨大的張量),那麼這種方法將不會有很大幫助,而張量切分選項,如 fairscale.nn.FullyShardedDataParallel 將更可取。
  • OSS與大多數DDP功能保持兼容。
  • OSS應該是DDP環境中的一個臨時解決方案。

2.3.3 性能說明

以下是一些關於性能的說明。

  • 在單個節點上,OSS應該總是比vanilla PyTorch快,內存節省會因使用的優化器而異。

  • 當您使用具有附加狀態的優化器(如Adam)時,OSS非常有用。

  • 如果您使用的是SGD或任何內存佔用有限的優化器,那麼在使用多個節點時,由於上面流程之中步驟4中的額外通信,您可能會看到速度減慢。在第2步的allreduce過程中,也有一些用於存儲梯度的浪費內存,這些內存隨後被丟棄。

  • 當使用多個節點時,OSS也可以比vanilla PyTorch快或慢,具體取決於所使用的優化器和可選標誌(如上文提到的broadcast_fp16、梯度壓縮、梯度累積)

  • 如果您可以使用更大的batch size,最好是則採取更大的batch size並減少所涉及的rank數,或者使用梯度累積,因爲這樣可以降低通信成本。

我們接下來正式進入 ZeroRedundancyOptimizer。

0x03 如何使用

我們首先使用 https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html 來看看如何使用 ZeroRedundancyOptimizer。

3.1 背後思想

ZeroRedundancyOptimizer的思想來自 DeepSpeed/ZeRO projectMarian ,這兩個項目會跨分佈式數據並行進程對優化器狀態進行分片,以減少每個進程的內存佔用。ZeRO的優化策略主要是通過對模型狀態進行切分以優化顯存佔用,模型狀態主要包括優化器狀態,梯度和模型參數。

ZeroRedundancyOptimizer 則實現了對優化器狀態(optimizer states)的切分,優化器狀態就是優化器運行所需要的參數和本地狀態。例如,SGD需要和模型參數一樣大小的動量,Adam優化器對於用每個參數保存了exp_avgexp_avg_sq 狀態。因此,Adam優化器的內存消耗至少是模型大小的兩倍。所以,當模型較大時,優化器狀態是不小的顯存開銷。

在分佈式數據並行入門教程(Getting Started With Distributed Data Parallel )中,我們展示瞭如何使用DistributedDataParallel(DDP)來訓練模型。在DDP中:

  • 每個worker進程(rank,node或者device)都保留優化器的專用副本。
  • 由於DDP已經在反向傳播中用all-reduce同步了梯度,因此所有優化器副本在每次迭代中都將在相同的參數和梯度值上運行。
  • 這些優化器用all-reduce後的gradients去更新模型參數,這就是DDP可以使各個模型副本(rank)保持相同參數狀態的原因。

根據這一觀察結果,我們可以通過在DDP進程之間分割優化器狀態來減少優化器內存佔用。更具體地說,就是:

  • 把優化器切分到不同worker之上,每個worker上的優化器實例只保留其模型參數分片所對應的那部分(1/world_size)優化器狀態,而不是爲所有參數創建對應的參數狀態。
  • 優化器 step() 函數只負責更新其分片中的參數,當worker完成參數更新之後,會將更新後的參數廣播給所有其他對等DDP進程,以便所有模型副本仍處於相同的狀態。

3.2 如何使用

ZeroRedundancyOptimizer可與torch.nn.parallel.DistributedDataParallel結合使用,以減少每個rank的內存峯值消耗。下面的代碼演示瞭如何使用ZeroRedundancyOptimizer. 大部分代碼類似於 Distributed Data Parallel notes中給出的簡單DDP示例。 主要區別在於example函數中的if else子句,這個語句包裝了優化器構造,可以在ZeroRedundancyOptimizer和Adam 之間進行切換。我們只要使用 ZeroRedundancyOptimizer對常規的optimizer進行warp即可。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer( # 這裏使用了ZeroRedundancyOptimizer
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam, # 包裝了Adam
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

輸出如下所示。

無論是否使用ZeroRedundancyOptimizer,在每個迭代之後,模型參數都使用了同樣內存,所以打印的輸出是一樣的。當啓用 ZeroRedundancyOptimizer 來封裝 Adam時,優化器 step() 的內存峯值消耗是 Adam內存消耗的一半。這與我們的預期相符,因爲我們把 Adam優化器狀態分片到了兩個進程之上。

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875

3.3 小結

經過上面的原理分析和使用說明,我們知道:ZeroRedundancyOptimizer類可以對任意一個optim.Optimizer 進行封裝,並可以在組中的ranks之中分割自己的狀態。每個rank中的本地優化器實例只負責更新大約 1 / world_size 的參數,因此只需要保持 1 / world_size 大小的優化器狀態。

所以我們下面分析的重點就是:

  • 如何將優化器參數進行分區?
  • 每個rank如何知道自己對應的參數?

0x04 初始化

我們首先從 __init__ 看看如何構建,其主要做了三步:

  • 初始化基類。
  • 初始化各種成員變量。
  • 使用 _update_trainable 內部同步&構建buffer,其內部會調用 _optim_constructor 來構建內部優化器。
    def __init__(
        self,
        params,
        optimizer_class: Type[Optimizer], # 就是被包裝的原生優化器類型
        group: Optional[Any] = None,
        parameters_as_bucket_view: bool = False,
        **default: Any,
    ):
        # Hold all the model params in the root .param_groups
        # NOTE: the default constructor uses `add_param_group` which is partially overloaded here
        # we introduce the `initialized` flag for be able to dissociate the behaviour of
        # `add_param_group` in between super() and ZeroRedundancyOptimizer
        self.initialized = False
        super().__init__(params, default) # 初始化基類

        # Partition information. lazy evaluation, computed if requested
        self._per_device_params_cache: "OrderedDict[torch.device, List[List[Parameter]]]" = (
            OrderedDict()
        )  # device, rank, params

        # Build the wrapped optimizer, responsible for a shard of the params
        self._param_rank_cache: Dict[torch.Tensor, int] = {} # 初始化各種成員變量
        self._param_to_index_cache: Dict[int, int] = {}
        self._partition_parameters_cache: List[List[Dict]] = []
        self._index_to_param_cache: Dict[int, torch.Tensor] = {}
        self._all_params = params
        self._reference_is_trainable_mask = list(map(_is_trainable, self._all_params))

        self.group = group if group is not None else dist.group.WORLD
        self.world_size = dist.get_world_size(self.group)
        self.rank = dist.get_rank(self.group) 
        # global是用來在進程之間同步
        self.global_rank = _get_global_rank(self.group, self.rank)
        self.parameters_as_bucket_view = parameters_as_bucket_view

        self._optim_defaults = default
        self._optim_constructor = optimizer_class # 如何生成原生優化器

        #  Optional consolidated optimizer state
        self._all_states: List[Dict[str, Any]] = []
        # Current default device is set by the parameters allocated to this rank
        self._device = list(self._per_device_params.keys())[0]
        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}

        self._update_trainable() # 內部同步&構建buffer,調用 _optim_constructor 來構建內部優化器
        self.initialized = True

因爲 Python 語言的特點,沒有專門的地方來初始化成員變量,而是在程序運行之中遇到了某個變量就即時初始化。所以,我們不會按照程序實際初始化的順序來分析,而是按照成員變量邏輯上初始化的順序來分析

以下分析的這些函數或者說成員變量都是在__init__方法之中被間接調用或者初始化

4.1 將參數分區

partition_parameters 方法會將參數進行分區,其返回 _partition_parameters_cache。

被包裝(wrapped)的optimizer根據參數大小(而不是使用順序)以排序貪婪(sorted-greedy)算法來對優化器狀態進行分片,在每個rank中打包一些參數,這樣每個參數都屬於一個rank,不在ranks之間劃分。分區是任意的,可能與參數註冊或使用順序不匹配。這是爲了確保每個rank具有幾乎相同大小的優化器內存

def partition_parameters(self) -> List[List[Dict]]:
    r"""
    Partitions parameters across distributed data parallel ranks.

    Returns:
        a list of ``param_groups`` (which is a list of dict) where each
        element of the list contains the param_groups for a rank. Element 0
        corresponds to rank 0, etc. We need all the ranks for the broadcast
        inside ``step()``.
    """
    if len(self._partition_parameters_cache) == 0:
        self._partition_parameters_cache = [list() for _ in range(self.world_size)]
        # 生成一個數組,用來記錄每個rank的大小,一共有world size個rank
        sizes = [0] * self.world_size 
        
        for param_group in self.param_groups: # 遍歷參數組
            param_lists: List[List] = [list() for _ in range(self.world_size)]
              
            for param in param_group["params"]:
                # Add this param to rank with smallest size.
                rank = sizes.index(min(sizes)) # 找到最小的那個rank
                param_lists[rank].append(param) # 把參數放到最小rank之中
                sizes[rank] += param.numel() # 增加rank的大小

            for rank, params in enumerate(param_lists): # 遍歷list
                param_group_rank = copy.copy(param_group)
                param_group_rank["params"] = params
                self._partition_parameters_cache[rank].append(param_group_rank)

    return self._partition_parameters_cache

這裏就分區好了,最終返回一個param_groups 的列表(這是一個dict列表),列表的每個元素都包含一個rank的param_groups,比如元素0對應於rank 0,每個rank的group的參數有差不多大小。在step()中,我們需要所有rank的信息來進行廣播。下圖給出了rank 0和 rank 5 對應的param_groups。

_partition_parameters_cache

          +
          |
          |
          v                +---------------+
  +-------+---------+      | param_group   |
  |       0         +----> |               |      <-------+  100 M   +------------->
  +-----------------+      +---------------+
  |       1         |      |               |     +--------+---------+------+--------+
  +-----------------+      |   "params" +------> |param 1 | param 2 | ...  | param 6|
  |       2         |      |               |     |        |         |      |        |
  +-----------------+      +---------------+     +--------+---------+------+--------+
  |                 |
  |                 |
  |     ......      |
  |                 |      +---------------+
  +-----------------+      | param_group   |      <-------+  105 M  +----------------->
  |       5         +----> |               |
  +-----------------+      +---------------+     +--------+---------+-------+---------+
                           |               |     |        |         |       |         |
                           |  "params"  +------> | param 7| param 8 | ...   | param 11|
                           |               |     |        |         |       |         |
                           +---------------+     +--------+---------+-------+---------+

4.2 將參數分給rank

現在,參數已經分成大小相近的group,接下來需要把這些group分到各個rank之上

_param_to_rank 方法生成一個表,裏面記錄每一個參數對應的rank,就是哪個參數在哪個rank之中。

@property
def _param_to_rank(self) -> Dict[torch.Tensor, int]:
    r"""Look up table to match a given param with a data parallel rank"""
    if len(self._param_rank_cache) == 0:
        for rank, param_groups in enumerate(self.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    self._param_rank_cache[param] = rank
    return self._param_rank_cache

依據上圖例子,我們知道param 1,param 2,param 6 在rank 0之中,param 8,param 11 在 rank 5 之中.....,具體如下:

_param_rank_cache

      +
      |
      |
      |
      v
 +----+--------------+------------+
 |                   |            |
 |   param 1         |     0      |
 +--------------------------------+
 |                   |            |
 |   param 2         |     0      |
 +--------------------------------+
 |                   |            |
 |   param 6         |     0      |
 +--------------------------------+
 |                   |            |
 |   param 8         |     5      |
 +--------------------------------+
 |                   |            |
 |   param 11        |     5      |
 +--------------------------------+
 |                   |            |
 |   param n         |     n      |
 |                   |            |
 +-------------------+------------+

4.3 _per_device_params

現在,參數已經分配給各個rank,接下來就要具體分配到設備之上,每個設備上可能包含多個rank的參數組_per_device_params 方法就是把優化器的param_groups在各個設備之間進行分配,其返回_per_device_params_cache

請注意,_per_device_params 這裏包括全部的模型參數,雖然已經按照設備進行了分類。即,在每個ZeRO優化器之中都是相同的。這樣ZeRO優化器之間可以廣播同步這些參數。

@property
def _per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
    r"""
    Sorted list of all the params, first per device then per rank.

    Within a list params are sorted per number of elements to allow for an easy bucketing.
    """
    if len(self._per_device_params_cache) == 0:
        # Go through all params, log them per device
        # The ordering is important here, needs to be the same on all ranks
        # So that ulterior broadcast calls are matching
        for param_group in self.param_groups: # 遍歷參數
            for param in param_group["params"]:
                device = param.device # 找到其設備
                if self._per_device_params_cache.get(device) is None:
                    self._per_device_params_cache[device] = [[] for _ in range(self.world_size)]
                # 每個設備內部還需要按照rank來分開    
                self._per_device_params_cache[device][self._param_to_rank[param]] += [param]

        # Sort param_lists by size
        for k in self._per_device_params_cache.keys():
            for r in self._per_device_params_cache[k]:
                r.sort(key=lambda x: x.numel())

    return self._per_device_params_cache

比如,下面 CPU,GPU 1(忽略),GPU 2 都有自己的參數列表,每個列表之內都是按照參數大小排列。

_per_device_params_cache

      +
      |                                      +--------+--------+-------+--------+
      |                                      |        |        |       |        |
      |                     +---------+      | param1 | param3 |param5 | param6 |
      v                     |         |      |        |        |       |        |
 +----+--------------+      | rank 0  +----> |  1k    |  2k    |  3k   |   7k   |
 |                   |      |         |      +--------+--------+-------+--------+
 |     "CPU"         +----> +---------+
 |                   |      |         |
 +-------------------+      | rank 1  |      +--------+--------+-------+--------+
 |                   |      |         +----> |        |        |       |        |
 |     "GPU 1"       |      +---------+      | param9 | param2 | param4| param8 |
 |                   |                       |        |        |       |        |
 +-------------------+                       |  0.5k  |  1k    |  4k   |   8k   |
 |                   |                       +--------+--------+-------+--------+
 |     "GPU 2"       |      +---------+
 |                   +----> |         |      +---------+------------+-----------+
 +-------------------+      |         |      |         |            |           |
                            | rank 5  +----> | param 11|  param 13  | param 15  |
                            |         |      |         |            |           |
                            +---------+      +---------+------------+-----------+
                            |         |
                            | rank 6  |      +---------+------------+-----------+
                            |         +----> |         |            |           |
                            |         |      | param 19|  param 12  | param 14  |
                            +---------+      |         |            |           |
                                             +---------+------------+-----------+

4.4 _update_trainable

因爲某些參數會變化,所以需要在本地優化器和ZeroRedundancyOptimizer 之間彼此同步。

  • 首先得到 self._default_device 爲 "CPU" 或者 "GPU #"。
  • 然後調用 _optim_constructor 來構建內部優化器。注意,這裏就是告訴本地優化器,你就負責優化這些參數即可,不用管其他的shard。partition_parameters 方法前面提到,其會將參數進行分區,其返回 _partition_parameters_cache。
# 只是選取自己rank對應的參數進行優化
self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)

# 運行時變量如下:
#_optim_constructor = {type} <class 'torch.optim.adam.Adam'>
#_optim_defaults = {dict: 1} {'lr': 0.01}
  • 接着,調用 _sync_param_groups 同步參數。

  • 最後,建立 flat buffer。

具體代碼如下:

def _update_trainable(self) -> None:
    r"""
    Updates the partitioning and communication patterns if the trainability
    (``requires_grad``) of some parameters changed.
    """

    # Create the optim which will work on the param shard
    if not hasattr(self, "optim"):
        self._clear_cache()
        # 獲得缺省設備
        self._default_device = list(self._per_device_params.keys())[0]
        # 構建本地優化器,只是選取本rank對應的參數
        self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
        # 調用 _sync_param_groups 同步參數,self.optim 是被包裝的優化器
        self._sync_param_groups(self.optim.param_groups, self.param_groups)

    if self.parameters_as_bucket_view:
        self._setup_flat_buffers() # 建立 flat buffer

我們用 rank 5 爲例,其本地優化器就只是指向 _partition_parameters_cache[5] 對應的那部分待優化參數,本地優化器只優化這些參數即可

這樣就實現了優化器參數分區。_partition_parameters_cache[5] 這樣的參數可以在後續被放置到 GPU 之上,這樣每個GPU就只包括 優化器的部分分區

需要注意的是:模型參數,梯度都沒有變化,只是本地 ZeroRedundancyOptimizer 指向了部分需要優化的參數,所以 ZeroRedundancyOptimizer 的優化器狀態也相應減少了

就下圖來說,原先優化器需要優化全部的參數,可能有 100 M + 105 M + ....,現在ZeroRedundancyOptimizer只需要優化 105 M。

 _partition_parameters_cache

        +
        |
        |
        v                +---------------+
+-------+---------+      | param_group   |
|       0         +----> |               |      <-------+  100 M   +------------->
+-----------------+      +---------------+
|       1         |      |               |     +--------+---------+------+--------+
+-----------------+      |   "params" +------> |param 1 | param 2 | ...  | param 6|
|       2         |      |               |     |        |         |      |        |
+-----------------+      +---------------+     +--------+---------+------+--------+
|                 |
|                 |
|     ......      |
|                 |      +---------------+
+-----------------+      | param_group   |      <-------+  105 M  +----------------->
|       5         +----> |               |
+-----------------+      +---------------+     +--------+---------+-------+---------+
                         |               |     |        |         |       |         |
                    +--> |  "params"  +------> | param 7| param 8 | ...   | param 11|
                    |    |               |     |        |         |       |         |
                    |    +---------------+     +--------+---------+-------+---------+
                    |
                    |
                    |
+-----------------------+
| Local Optimizer   |   |
|                   |   |
|                   |   |
|                   +   |
|                       |
|                       |
|                       |
|                       |
+-----------------------+

我們還需要再細化一下,看看 _sync_param_groups 和 _setup_flat_buffers 這兩個函數。

4.4.1 同步參數組

_sync_param_groups 用來把內部優化器的參數組同步到本Zero優化器的參數組

    @staticmethod
    def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
        r"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""

        for source_group, destination_group in zip(source, destination):
            # Sync everything but the parameters
            for k in filter(lambda x: x != "params", source_group.keys()):
                destination_group[k] = source_group[k]

4.4.2 建立single buffer

如果設置了parameters_as_bucket_view,則調用_setup_flat_buffers 建立若干buffer。同樣設備上同樣rank的張量被視爲一個buffer。就是處理 _per_device_params。

def _setup_flat_buffers(self) -> None:
    r"""
    Make all params which are on the same device and tied to the same rank
    views of a single buffer. This is used at construction time, and anytime
    parameter trainability is changed (frozen or unfrozen) and
    ``_update_trainable`` is called.
    """

    for device, per_rank_params in self._per_device_params.items():
        # Only wipe the existing buckets if there are none
        # (could be that this is called twice, when trainability changes)
        if device not in self.buckets.keys():
            self.buckets[device] = []

        # Make parameters a view of the bucket
        for dst_rank, params in enumerate(per_rank_params):
            if len(params) > 0:

                # Clone the non-trainable params, if in a bucket it will get destroyed
                for param in filter(lambda x: not x.requires_grad, params):
                    param.data = param.data.detach().clone()

                # Merge all the trainable params in a single bucket
                trainable_params = list(filter(_is_trainable, params))
                buffer_size = sum(map(lambda x: x.numel(), trainable_params))
                bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
                offset = 0

                for param in trainable_params:
                    offset_next = offset + param.numel()
                    bucket[offset:offset_next].copy_(param.data.flatten())
                    param.data = bucket[offset:offset_next].view_as(param.data)
                    offset = offset_next

                # Either replace the existing bucket, or create it
                if len(self.buckets[device]) == dst_rank:
                    self.buckets[device].append(bucket)
                else:
                    self.buckets[device][dst_rank] = bucket
            else:
                self.buckets[device].append(torch.zeros(1, device=device))

具體可以看看如下圖例,同樣設備上同樣rank的張量被視爲一個buffer。

buckets
     +
     |
     |               +---------------------------------------+
     v               | Tensor                                |
+----+-------+       | +-----------------------------------+ |
|            |       | |                                   | |
|  "CPU"     +-----> | | Param 1, param 2,  Param 3......  | |
|            |       | +-----------------------------------+ |
+------------+       +---------------------------------------+
|            |
|  "GPU 1"   +-----> +---------------------------------------+
|            |       | Tensor                                |
+------------+       | +-----------------------------------+ |
|            |       | |                                   | |
|            |       | | Param 6, Param 7,  Param 8......  | |
|            |       | +-----------------------------------+ |
|            |       +---------------------------------------+
|            |
+------------+

0x05 更新參數

我們接下來看看優化器如何更新參數,其邏輯如下:

  • 如果計算圖有變化,則需要重新處理。
  • 調用 _sync_param_groups 將本地優化器參數同步給 ZeRO優化器,防止其被 scheduler 已經修改。
  • 調用 self.optim.step,讓本地優化器在本地參數之上進行更新。
  • 調用 dist.broadcast 在ranks 之間同步參數。
  • 再次調用 _sync_param_groups 將本地優化器參數同步給 ZeRO優化器,因爲其已經被更新了。
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
    r"""
    Performs a single optimization step (parameter update).

    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.
    Returns:
        optional loss, depends on the underlying optimizer

    .. note: Any extra parameter is passed to the base optimizer as-is
    """

    # Check whether the model trainability graph changed
    # 如果計算圖有變化,則需要重新處理
    trainable_mask = list(map(_is_trainable, self._all_params))
    if trainable_mask != self._reference_is_trainable_mask:
        self._update_trainable()
        self._reference_is_trainable_mask = trainable_mask

    # Sync oss param_groups attributes in case they've been updated by a scheduler.
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only:
    # 更新本地參數
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all the updated shards in between the ranks
    handles = []
    if self.parameters_as_bucket_view:
        for device in self.buckets.keys():
            for src_rank, bucket in enumerate(self.buckets[device]):
                global_src_rank = _get_global_rank(self.group, src_rank)
                handles.append(dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True))
    else:        
        for device, per_rank_params in self._per_device_params.items(): # 遍歷設備+其參數
            for dst_rank, params in enumerate(per_rank_params): # 遍歷rank
                global_dst_rank = _get_global_rank(self.group, dst_rank)
                for param in params: # 對於每一個參數,都進行broadcast
                    handles.append(
                        dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True)
                    )

    _ = list(map(lambda x: x.wait(), handles))

    # Sync hypothethical new results from the wrapped optimizer to the exposed param_groups
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss

5.1 更新

首先是本地更新模型參數。

# 更新本地參數
if closure is not None:
    loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
else:
    loss = self.optim.step(**kwargs)

假設模型一共有8個參數,分成上下兩個節點,每個節點有一個優化器。這裏爲了更好說明,在上下兩個優化器中,把參數和rank序號大的放在上面。

再次強調:模型參數,梯度都沒有變化,只是本地 ZeroRedundancyOptimizer 指向了部分需要優化的參數,所以 ZeroRedundancyOptimizer 的優化器狀態也相應減少了

所以,上下兩個優化器之中,模型(需要優化的參數)大小都一樣,但是:

  • ZeroRedundancyOptimizer 0 之中,優化的是 rank 0,參數 0 ~ 3 是本地優化的,對應兩個節點來說,這部分參數是全局最新的。

  • ZeroRedundancyOptimizer 1 之中,優化的是 rank 1,參數 4 ~ 7 是本地優化的,對應兩個節點來說,這部分參數是全局最新的。

+--------------------------------------------------------------------------------+
|                                                     ZeroRedundancyOptimizer 0  |
|                                                                                |
|   _per_device_params_cache                                                     |
|       +                                                                        |
|       |                                                                        |
|       v          +--------+           +--------+--------+-------+--------+     |
|   +---+-----+    | rank 1 |           |        |        |       |        |     |
|   |         |    |        +---------> | param4 | param5 | param6| param7 |     |
|   | "GPU"1" +--> +--------+           |        |        |       |        |     |
|   |         |    |        |           +--------+--------+-------+--------+     |
|   +---------+    | rank 0 |                                                    |
|                  |        |           +--------+--------+-------+--------+     |
|                  |        +---------> |        |        |       |        |     |
|                  +--------+           | param0 | param1 |param2 | param3 | NEW |
|                               +---->  |        |        |       |        |     |
|   +----------------+          |       +--------+--------+-------+--------+     |
|   |Local Optimizer |          |                                                |
|   |                +----------+                                                |
|   |                |                                                           |
|   +----------------+                                                           |
|                                                                                |  Node 0
+--------------------------------------------------------------------------------+



+--------------------------------------------------------------------------------+
|                                                                                |  Node 1
|                                                                                |
|   _per_device_params_cache                                                     |
|       +                                                                        |
|       |                               +--------+--------+-------+--------+     |
|       v          +--------+     +---> |        |        |       |        |     |
|   +---+-----+    | rank 1 |     |     | param4 | param5 | param6| param7 | NEW |
|   |         |    |        +---------> |        |        |       |        |     |
|   | "GPU"1" +--> +--------+     |     +--------+--------+-------+--------+     |
|   |         |    |        |     |                                              |
|   +---------+    | rank 0 |     |     +--------+--------+-------+--------+     |
|                  |        +---------> |        |        |       |        |     |
|                  |        |     |     | param0 | param1 |param2 | param3 |     |
|                  +--------+     |     |        |        |       |        |     |
|                                 |     +--------+--------+-------+--------+     |
|   +----------------+            |                                              |
|   |Local Optimizer |            |                                              |
|   |                +------------+                                              |
|   |                |                                                           |
|   +----------------+                                 ZeroRedundancyOptimizer 1 |
|                                                                                |
+--------------------------------------------------------------------------------+

5.2 廣播

首先需要注意,_per_device_params 這裏包括全部的模型參數,雖然已經按照設備進行了分類。

現在狀態是,本rank的優化器參數(本分區)已經更新了,就是模型的部分得到了更新。爲了維持模型的最新,需要彼此進行廣播。

在本地更新參數後,每個rank將向所有其他對等方廣播其參數,以保持所有模型副本處於相同狀態。

+--------------------------------------------------------------------------------+
|                                                     ZeroRedundancyOptimizer 0  |
|                                                                                |
|   _per_device_params_cache                                                     |
|       +                                                                        |
|       |                                                                        |
|       v          +--------+           +--------+--------+-------+--------+     |
|   +---+-----+    | rank 1 |           |        |        |       |        |     |
|   |         |    |        +---------> | param4 | param5 | param6| param7 |     |
|   | "GPU"1" +--> +--------+           |        |        |       |        |     |
|   |         |    |        |           +--------+--------+-------+--------+     |
|   +---------+    | rank 0 |                                                    |
|                  |        |           +--------+--------+-------+--------+     |
|                  |        +---------> |        |        |       |        |     |
|                  +--------+           | param0 | param1 |param2 | param3 | NEW |
|                               +---->  |        |        |       |        |     |
|   +----------------+          |       +---+----+---+----+-+-----+--+-----+     |
|   |Local Optimizer |          |           |        |      |        |           |
|   |                +----------+           |        |      |        |           |
|   |                |                      |  ^     |  ^   |  ^     |   ^       |
|   +----------------+                      |  |     |  |   |  |     |   |       |
|                                           |  |     |  |   |  |     |   |       | Node 0
+--------------------------------------------------------------------------------+
                                            |  |     |  |   |  |     |   |
                                            |  |     |  |   |  |     |   |
                                            |  |     |  |   |  |     |   |
+--------------------------------------------------------------------------------+
|                                           |  |     |  |   |  |     |   |       | Node 1
|                                           v  |     v  |   v  |     v   |       |
|   _per_device_params_cache                   |        |      |         |       |
|       +                                      |        |      |         |       |
|       |                               +------+-+------+-+----+--+------+-+     |
|       v          +--------+     +---> |        |        |       |        |     |
|   +---+-----+    | rank 1 |     |     | param4 | param5 | param6| param7 | NEW |
|   |         |    |        +---------> |        |        |       |        |     |
|   | "GPU"1" +--> +--------+     |     +--------+--------+-------+--------+     |
|   |         |    |        |     |                                              |
|   +---------+    | rank 0 |     |     +--------+--------+-------+--------+     |
|                  |        +---------> |        |        |       |        |     |
|                  |        |     |     | param0 | param1 |param2 | param3 |     |
|                  +--------+     |     |        |        |       |        |     |
|                                 |     +--------+--------+-------+--------+     |
|   +----------------+            |                                              |
|   |Local Optimizer |            |                                              |
|   |                +------------+                                              |
|   |                |                                                           |
|   +----------------+                                 ZeroRedundancyOptimizer 1 |
|                                                                                |
+--------------------------------------------------------------------------------+

5.3 同步本地參數

最後,需要再次調用 _sync_param_groups 將本地優化器參數同步給 ZeRO優化器,因爲其已經被更新了。

# Sync hypothethical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(self.optim.param_groups, self.param_groups)

具體函數我們再揪出來溫習一下。

@staticmethod
def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
    r"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""

    for source_group, destination_group in zip(source, destination):
        # Sync everything but the parameters
        for k in filter(lambda x: x != "params", source_group.keys()):
            destination_group[k] = source_group[k]

0xFF 參考

談談torch1.10中的ZeroRedundancyOptimizer和Join

https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html

https://pytorch.org/docs/master/distributed.optim.html

https://medium.com/swlh/inside-microsofts-new-frameworks-to-enable-large-scale-ai-953e9a977912

https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章