[源碼分析] Facebook如何訓練超大模型--- (5)

[源碼分析] Facebook如何訓練超大模型--- (5)

0x00 摘要

我們在前文介紹過,微軟 ZeRO 可以對一個萬億參數模型可以使用 8 路模型並行、64 路管道並行和 8 路數據並行在 4,096 個 NVIDIA A100 GPU 上進行擴展。而FSDP(Fully Sharded Data Parallel)是Facebook 深度借鑑微軟ZeRO之後提出的PyTorch DDP升級版本,可以認爲是對標微軟 ZeRO,其本質是 parameter sharding。Parameter sharding 就是把模型參數等切分到各個GPU之上。我們會以 Google,微軟和 Facebook 的論文,博客以及代碼來進行學習分析。

之前文章之中我們談到了FSDP支持混合精度訓練,本篇來看看 Activation recomputation。

本系列其他文章如下:

[源碼解析] PyTorch 分佈式之 ZeroRedundancyOptimizer

[論文翻譯] 分佈式訓練 Parameter sharding 之 ZeRO

[論文翻譯] 分佈式訓練 Parameter Sharding 之 Google Weight Sharding

[源碼分析] Facebook如何訓練超大模型---(1)

[源碼分析] Facebook如何訓練超大模型 --- (2)

[源碼分析] Facebook如何訓練超大模型 --- (3)

[源碼分析] Facebook如何訓練超大模型---(4)

0x01 背景

激活重新計算(Activation recomputation),也稱“激活檢查點(activation checkpointing)”或“梯度檢查點(gradient checkpointing)”(Chen et al,2016 https://arvix.org/abs/1604.06174),其思路是用時間換空間,即,犧牲計算時間來換取內存空間。其減少了深度神經網絡訓練層的內存開銷,代價是每個batch會消耗額外的前向傳播計算。

比方說,該方法將m層網絡平均劃分爲d個分區,只保存分區邊界的激活,並在workers之間交換這些激活。因爲後向傳播之中依然需要分區內層間激活值(Intermediate activations at intra-partition layers)來計算梯度,所以在後向傳播過程中會在分區內部重新計算激活。

下圖爲論文之中的示意圖。

我們在之前文章之中介紹過重計算 [源碼解析] 深度學習流水線並行 GPipe(3) ----重計算。本文會看看 FairScale 是如何對其進行進一步封裝和改進。

0x02 思路

2.1 學習建議

在看思路之前,我們先來講講如何更好的分析一個開源框架或者說如何學習源碼。個人的意見是按照:論文 --> 文檔 --> 用戶手冊 --> 註釋 --> 源碼 這個順序來學習。

爲什麼按照這個順序?因爲這個順序是:

  • 從抽象邏輯(或者說是體系架構)到具體細節。
    • 論文是把作者的思想提煉,邏輯化,體系化的結果,文檔次之。而且重讀經典論文,其收穫是多維度的。
    • 手冊則會從使用或者注意點方面幫你完成對這個框架整體的認識。
    • 源碼則給你呈現了大量的細節。
  • 從人的思想到機器的思想。
    • 註釋是作者給閱讀者看的,代碼是作者給機器看的。
    • 註釋會告訴你爲什麼這樣實現(Why),代碼告訴你怎麼實現(How)。

對於我們來說,應該首先尋求一種思維的改變,知識框架的更新與整理,然後纔是用代碼來分析驗證(畢竟紙上得來終覺淺)。當然,很多時候我們只有源碼,那麼就只能從源碼之中根據細節來探尋,重建作者的思路,提煉其精華,爭取和作者達到一個跨越空間和時間的共鳴,共鳴越多,你就越接近作者了 _

2.2 具體思路

我們接下來就看看源碼文檔之中的思路介紹。

激活檢查點是一種用於減少訓練期間GPU內存使用的技術。具體做法是:

  • 在向前傳播過程中避免存儲中間激活張量。
  • 在後向傳播過程中依靠跟蹤原始輸入來重新進行前向傳播計算。

其結果是:以略有增加(約33%)的計算成本來減少了存儲大型激活張量的必要,因此允許我們增加batch size,從而增加模型的淨吞吐量。

激活檢查點是通過重載 torch.autograd.Function 來完成的。

  • 通過在前向函數之中使用no_grad,我們可以在很長一段時間內(即直到反向傳播開始)避免前向計算圖的創建和中間激活張量的具化(materialization)。
  • 在向後傳播期間內,會先再次執行向前傳播,然後執行向後傳播。
    • 向前傳播的輸入已經保存在上下文對象之中,所以在向後傳播之中可以通過該上下文對象拿到原始輸入。
    • 因爲在某些情況下(Dropout layers)會用到,所以還保存了前向和後向傳播的Random Number Generator(RNG) 狀態。

上述功能在torch.utils.checkpoint.checkpoint_wrapper 之中可以看到其具體實現,可以在前向傳播之中使用這個API來對模塊進行封裝。FairScale中的包裝器提供的功能比PyTorch API提供的功能更多,比如用戶可以使用 fairscale.nn.checkpoint.checkpoint_wrapper 來包裝一個 nn.Module,這樣就可以在正向傳遞中處理kwargs,將中間激活卸載(offload)到CPU,並處理從前向函數返回的非張量輸出。

2.3 最佳實踐

我們接下來看看 fairscale.nn.checkpoint.checkpoint_wrapper 的最佳實踐。

  • 內存節省效果取決於模型和checkpoint wrapping如何進行分段。即,內存節省收益取決於層激活的內存佔用情況。
  • 使用BatchNormalization時,您可能需要凍結統計數據的計算,因爲在這種情況下會運行兩次前向傳播。
  • 確保輸入張量的requires_grad 屬性設置爲True。通過將輸入張量的requires_grad 屬性設置爲True,我們確保輸入可以傳播到輸出,並觸發 backward 函數。

0x03 具體實現

3.1 Wrapper

checkpoint_wrapper 是具體的wrapper,其內部就是調用了其他函數。但是我們發現其註釋可以讓我們進一步學習,所以翻譯如下:

checkpoint_wrapper 是執行激活檢查點的包裝器,其比PyTorch版本更加用戶友好,具備如下特點:

  • 包裝一個nn.Module,以便所有後續調用都將使用checkpointing。

  • 處理前向過程中的關鍵字參數(keyword arguments)。

  • 處理來自正向過程中的非張量輸出。

  • 支持將激活卸載到CPU。

爲了更好的瞭解checkpointing和"offload_to_cpu"帶來的好處,我們將激活分爲兩種類型:

  • 內部激活。其依靠 activation checkpointing 來保存。
  • 外部激活,也就是檢查點模塊。其依靠offload_to_cpu來保存。

就GPU內存節約效果而言:

  • 當內部激活很大而外部激活很小時,檢查點會帶來很大收穫,offload_to_cpu可能只帶來很小的收益。

  • 當內部激活小而外部激活很大時,檢查點幫助很小,offload_to_cpu會帶來很大收益。

  • 當內部激活和外部激活都很大時,檢查點和offload_to_cpu帶來的益處是疊加的。

另外,第一層和最後一層不太可能受益於offload_to_cpu標誌,因爲:

  • 第一層的輸入通常有其他引用,因此GPU內存不會被釋放;
  • 最後一層的輸入會立即被向後傳播使用,不會節省內存。
def checkpoint_wrapper(
    module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:

        - wraps an nn.Module, so that all subsequent calls will use checkpointing
        - handles keyword arguments in the forward
        - handles non-Tensor outputs from the forward
        - supports offloading activations to CPU

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

    To understand the benefits of checkpointing and the `offload_to_cpu` flag,
    let's divide activations into 2 types: inner activations and outer
    activations w.r.t. the checkpointed modules. The inner ones are saved
    by activation checkpointing, the outer ones are saved by offload_to_cpu.

    In terms of GPU memory savings:

        - When inner ones are large in size and outer ones are small,
          checkpointing helps a lot, offload_to_cpu may help a little.
        - When inner ones are small and outer ones are large,
          checkpointing helps little, offload_to_cpu helps a lot.
        - When both inner and outer are large, both help and the
          benefit is additive.

    ..Note::

        The first and last layers are not likely to benefit from the `offload_to_cpu` flag
        because (1) there are typically other references to the first layer's input, so
        the GPU memory won't be freed; (2) the input to the last layer is immediately
        used by the backward pass and won't result in memory savings.

    Args:
        module (nn.Module):
            The module to be wrapped
        offload_to_cpu (bool):
            Whether to offload activations to CPU.
        maintain_forward_counter (bool):
            If True, maintain a forward counter per inner module. The counter will first
            increases in forward calls of outer forward pass and then decreases in the
            forward calls of outer backward pass. It is used by FullyShardedDataParallel.

    Returns:
        (nn.Module):
            Wrapped module
    """
    # Patch the batchnorm layers in case there are any in this module.
    patch_batchnorm(module)

    if maintain_forward_counter:
        init_counter(module)

    # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
    # When such cycle exists, gc won't collect the module when the module is freed.
    # That causes GPU memory to be leaked. See the unit test for how we catch that.
    #
    # We prefer this over a class wrapper since the class wrapper would have to
    # proxy a lot of fields and methods.
    module.forward = functools.partial(  # type: ignore
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module # 包裝一個nn.Module,以便所有後續調用都將使用checkpointing

3.2 如何使用

我們從源碼之中找出一些代碼,大家可以看看。

self.layers = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8)),
    nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4)),
    nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2)),
)

if enable_checkpoint:
    for i, layer in enumerate(self.layers):
        # Only middle layer needs to have offloading
        self.layers[i] = checkpoint_wrapper(layer, cpu_offload if i == 1 else False)

3.2 _checkpointed_forward

前面提到對比PyTorch版本,FairScale有幾點益處,此處就對應了以下有下劃線的兩點:

  • 包裝一個nn.Module,以便所有後續調用都將使用checkpointing。

  • 處理前向過程中的關鍵字參數(keyword arguments)。

  • 處理來自正向過程中的非張量輸出。

  • 支持將激活卸載到CPU。

代碼邏輯如下:

  • 如果禁用了disabled,則直接使用 .forward() 。這樣做還可以確保內部fwd counter在前向過程中不會增加,但是這在eval過程中會是一個問題,因爲不會有相應的後向過程來減少fwd counter。
  • 因爲後向傳播必須爲每個輸入參數返回一個梯度(或None),所以PyTorch中的Autograd函數在帶有位置信息參數下工作最佳。將關鍵字參數扁平化可以讓這種處理更加方便。
  • 調用 CheckpointFunction 完成 activation checkpointing。這裏需要注意的是:當original_forward的輸入爲非張量(即一個元組)時,因此 CheckpointFunction 傳入了一個帶有grad的 dummy tensor 參數來確保向後傳播被調用。
    • 在輸入爲元組類型的情況下,即便設置張量的requires_grad標誌也不會觸發後向傳播。
    • 使用這個 dummy tensor 可以避免要求用戶設置輸入張量的requires_grad標誌。
  • 處理來自正向過程中的輸出爲tuple,就是把張量和非張量打包在一起。

具體代碼如下:

def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()

    # If gradients are disabled, just use original `.forward()` method directly.
    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
    # which would be an issue during eval since there wouldn't be a corresponding backward pass
    # to decrement the fwd counter.
    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled():
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) # 處理輸入
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's input are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors's requires_grad flag. In the case
    # of tuple type inputs, setting the flag won't even trigger the backward pass.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    
    # 處理非張量輸出
    if not isinstance(output, torch.Tensor):
        # parent_ctx_dict["packed_non_tensor_outputs"] 是 CheckpointFunction 返回的
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            # 統一處理成tuple
            output = unpack_non_tensors(output, packed_non_tensor_outputs) # 處理輸出
    return output

3.2.1 處理輸入

在處理前向過程中的關鍵字參數(keyword arguments)之中,使用了pack_kwargs,其作用就是把參數的key,value整理成爲兩個list,具體可以參見示例。

def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn argument list into separate key list and value list (unpack_kwargs does the opposite)
    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return tuple(kwarg_keys), tuple(flat_args)

3.2.2 非張量輸出

3.2.2.1 壓縮非張量

把一個tuple分割爲一個張量列表和後續重建所需要的信息。

def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Split a tuple into a list of tensors and the rest with information
    for later reconstruction.

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors: List[torch.Tensor] = []
    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []}
    for o in mixed:
        if isinstance(o, torch.Tensor):
            packed_non_tensors["is_tensor"].append(True)
            tensors.append(o)
        else:
            packed_non_tensors["is_tensor"].append(False)
            packed_non_tensors["objects"].append(o)
    return tuple(tensors), packed_non_tensors
3.2.2.2 解壓非張量

unpack_non_tensors 用來把非張量列表恢復成tuple。

def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """See split_non_tensors."""
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)
    mixed: List[Any] = []
    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]

    obj_i = tnsr_i = 0
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)

3.3 CheckpointFunction

我們接下來分析 CheckpointFunction,就是具體 activation checkpointing 的業務函數。關於 PyTorch 的 CheckpointFunction 版本,可以參見 [源碼解析] 深度學習流水線並行 GPipe(3) ----重計算

這裏對應了優點之中的:支持將激活卸載到CPU

3.3.1 前向傳播

其前向傳播的邏輯如下:

  • 分割非張量參數列表,得到張量輸入和非張量輸入。
    • 如果設置了"offload",在上下文記錄設備,梯度需求情況,並且把輸入張量放到cpu上。
  • 爲後向傳播保存輸入。
  • 如果設置了activation checkpointing,則處理參數,進行前向計算。
  • 如果輸出不是張量,因爲Autograd Functions不喜歡非張量輸出。我們可以拆分爲非張量和張量輸出,通過parent_ctx_dict引用返回前者,然後直接返回後者。
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """

    @staticmethod
    def forward(  # type: ignore
        ctx: Any,
        dummy_tensor_requires_grad: torch.Tensor,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        torch_checkpoint.check_backward_validity(args)

        ctx.run_function = run_function # 在上下文之中存儲前向傳播函數
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state() # 在上下文之中存儲前向傳播狀態
        ctx.had_autocast_in_fwd = is_autocast_enabled()

        # 分割非張量參數列表,得到張量輸入和非張量輸入
        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) 
        if parent_ctx_dict["offload"]:
            # 在上下文記錄設備,梯度需求情況,並且把輸入張量放到cpu上
            ctx.fwd_device = tuple(x.device for x in tensor_inputs) # 在上下文存儲前向傳播設備
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        # 爲後向傳播保存輸入
        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad(), enable_checkpointing(): # 如果設置了activation checkpointing
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) # 處理參數
            outputs = run_function(*unpacked_args, **unpacked_kwargs) # 前向計算
            the_module = unpacked_args[0]
            inc_counter(the_module)

        if not isinstance(outputs, torch.Tensor): # 如果輸出不是張量
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            # Autograd Functions不喜歡非張量輸出。我們可以拆分爲非張量和張量輸出,
            # 通過parent_ctx_dict引用返回前者,然後直接返回後者。
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs


3.3.2 後向傳播

後向傳播邏輯如下:

  • 拿到存儲在上下文的張量輸入。
  • 如果設置了在設備上計算,則:
    • 把 offlad 的張量再移到 GPU之上。
    • 找到需要計算的梯度。
  • 處理非張量輸入,最終和張量輸入組合在一起。
  • 保存當前狀態。
  • 從上下文加載前向傳播時候的狀態。
  • 重新做前向傳播。
  • 處理前向傳播輸出。
  • 恢復後向傳播的狀態。
  • 從前向傳播輸出找到需要梯度的張量,在後向傳播的輸入之中找到對應的張量。
  • 進行後向傳播。
  • 返回梯度。
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """

    @staticmethod
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

        # 拿到存儲在上下文的張量輸入
        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None: # 如果設置了在設備上計算
            # 把 offload 的張量再移到 GPU之上
            tensor_inputs = tuple(t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements): # 找到需要計算的梯度
                tensor_inputs[i].requires_grad = need_grad
        # 處理非張量輸入,最終和張量輸入組合在一起        
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = get_rng_state() # 拿到之前保存的當前狀態

        # Set the states to what it used to be before the forward pass.
        set_rng_state(ctx.fwd_rng_state) # 從上下文加載前向傳播時候的狀態

        with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) # 重新做前向傳播
            tensor_outputs, _ = split_non_tensors(outputs) # 處理前向傳播輸出
            the_module = unpacked_args[0]
            dec_counter(the_module)

        # Set the states back to what it was at the start of this function.
        set_rng_state(bwd_rng_state) # 恢復後向傳播的狀態

        # Run backward() with only Tensors that require grad
        outputs_with_grad = [] 
        args_with_grad = []
        # 從前向傳播輸出找到需要梯度的張量
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i]) # 在後向傳播的輸入之中找到對應的張量
        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary")

        # 進行後向傳播     
        torch.autograd.backward(outputs_with_grad, args_with_grad)

        # 從inputs裏面得到梯度
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None, None, None, None) + grads # 返回梯度

我們整理邏輯如下:

0x04 OffloadFunction

前文在 OffloadModel 的 forward 方法之中,如果設置了 _checkpoint_activation,則調用 OffloadFunction 把激活檢查點卸載到CPU之上,直接返回。我們接下來看看 OffloadFunction 如何實現與activation相關的操作。

此函數通過覆蓋nn.Module的向前和向後傳播,在分片邊界啓用中間激活的檢查點。這樣只保存分區邊界的激活,並在workers之間交換這些激活。

本節與上節的主要區別是:

  • CheckpointFunction只是把輸入張量在GPU和CPU之間移動,丟棄了內部激活
  • OffloadFunction 把激活(沒有丟棄)與模型都在在GPU和CPU之間移動,而且因爲分區是一層或者多層layers,所以只是在worker之間交換這些分區邊界的激活。

4.1 前向傳播

在FW過程中,它遍歷每一個分區,針對每一個分區,刪除前一個分片中的參數,並加載下一個分片的參數,然後進行這個分區的前向計算。FW過程中未構造任何計算圖。這使我們能夠卸載分片邊界上的中間激活。

這裏有幾點說明:

  • model_instance.model_slices 是模型的分片,每個分片裏面包含一個或者多個層。
  • 除了之後一個分區的激活,其餘分區之間的激活都存在CPU之上。這裏假設目標張量也位於執行計算的GPU上,那麼對於最後一層計算來說,其輸出激活也應該位於這個GPU之上。如果輸出激活移動到CPU之上,反向傳播就可能找不到其梯度函數了。

具體代碼如下:

class OffloadFunction(torch.autograd.Function):
    """
     This Function enables checkpointing of intermediate activations at
     shard boundaries by overriding the forward and backward pass of the nn.Module.

     - In the FW pass, it drops parameters in the previous shard and
     loads parameters for the next shard. No graph is constructed in the FW pass.
     This enables us to offload intermediate activations present at the shard
     boundaries.

     - In the BW pass, it does the reverse. We run the forward pass using the
     saved intermediate activations and calculate gradients as needed.
     The trade-off is latency vs memory when using activation checkpointing.

     - Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.

     NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
     """

    @staticmethod
    @_conditional_amp_fwd_decorator  # type: ignore
    def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
        inputs = inputs if isinstance(inputs, tuple) else (inputs,)

        # 把後向傳播所需要的信息存儲在上下文。
        ctx.inputs = inputs
        ctx.model_instance = model_instance
        # TODO(anj-s): We might need to store this for each boundary activation.
        # Currently we assume all boundary activation inputs require
        ctx.grad_requirements = tuple(x.requires_grad for x in inputs)
        ctx.fwd_rng_state = torch.get_rng_state()

        # List of input activations starting with the given input.
        model_instance._activations = [inputs]
        # Enumerate through layer shards and apply activations from the previous shard.
        for index, layer_shard in enumerate(model_instance.model_slices): # 遍歷模型的分區
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
                # Bring in the current activations onto the device.
                # 把當前激活拷貝到設備之上
                model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
                # Bring in the current layer shard onto the device.
                # 把當前層加載到設備之上
                layer_shard.forward_load()

            # Apply the FP and store the activations on the CPU.
            inputs = model_instance._activations[index]
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
                with torch.no_grad(): # 不會跟蹤下面的梯度,只是計算激活
                    output_list: List[Any] = []
                    for given_input in inputs:
                        given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
                        given_output_list = []
                        for inputs in given_input_list:
                            output = layer_shard(inputs) # 前向操作
                            given_output_list.append(output)
                        given_output = torch.cat(given_output_list).squeeze(-1)
                        output_list.append(given_output)
                    output = tuple(output_list) # 得到輸出

            output = output if isinstance(output, tuple) else (output,)
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
                # Move the activation used back for the curent shard back to the CPU.
                # 把激活移動到CPU
                model_instance._activations[index] = tuple([a.cpu() for a in list(model_instance._activations[index])])
                # The newly computed activations remain on the GPU ready for the next shard computation.
                model_instance._activations.append(output)
                # Move the layer shard back to the CPU.
                layer_shard.forward_drop() # 把層移動到CPU

        # The last instance will lose the gradient function if we move it to the CPU.
        # This is because all grad function are present on the device that ran the FW pass.
        # The last activation remains on the GPU and is the return value of this function.
        # Note that this assumes that the target is also on the GPU which is required for calculating
        # the loss.
        
        result = model_instance._activations[-1] # 最後一層的激活
        result = [r.cuda() for r in result] # 把最後一層的激活移動到設備上,其餘的已經移動到CPU之上
        for r in result:
            r.requires_grad = True
        return result[0] if len(result) == 1 else result

4.2 後向傳播

在BW過程中,它執行相反的操作。我們使用保存的中間激活運行前向傳播,並根據需要計算梯度。在使用激活檢查點時,需要權衡延遲和內存。因爲這裏會用到幾個PyTorch的內置方法,所以我們需要首先來看看其用法和原理。

4.2.1 no_grad

torch.no_grad() 是一個上下文管理器,被 no_grad 包括起來的代碼不會跟蹤其梯度。我們用一個例子來看看。

import torch

x = torch.tensor([2.2], requires_grad=True)
y = x * 3
print(y)
y.add_(2)
print(y)

with torch.no_grad():
    y.div_(3)
    print(y)

輸出爲:

tensor([6.6000], grad_fn=<MulBackward0>) # 這裏記錄了梯度操作
tensor([8.6000], grad_fn=<AddBackward0>) # add操作被跟蹤
tensor([2.8667], grad_fn=<AddBackward0>) # 用了no_grad,所以div沒有被跟蹤

4.2.2 chunk

torch.chunk(tensor, chunk_num, dim) 將張量按dimension(行或列)分割得到 chunk_num 個張量塊,此函數將返回一個元組,比如下面例子。

x = torch.Tensor([[1,2,3]])
y = torch.Tensor([[4,5,6], [7,8,9], [10,11,12]])
z = torch.cat((x,y), dim=0)
print(z)
print(z.size())
c = torch.chunk(z,4,dim=0)
print(c)
print(len(c))

輸出爲:

# cat之後的輸出
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])
torch.Size([4, 3])

# chunk之後的輸出
(tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]]), tensor([[7., 8., 9.]]), tensor([[10., 11., 12.]]))
4

4.2.3 反向傳播

OffloadFunction 的反向傳播如下,這裏有個reverse操作需要注意。

  • 在代碼初期,會把模型分片和激活進行reverse(注意,沒有把原始分配和激活進行reverse,這裏是reverse之後的結果返回,不影響原始數據),因爲計算梯度是從後向前,所以把-1放到第一個位置,依次類推,這樣可以方便使用backward_load和backward_drop。
  • 在代碼最後,因爲之前的reverse沒有對 model_instance._activations 做修改,所以可以直接返回輸入之中的梯度。

具體代碼如下:

class OffloadFunction(torch.autograd.Function):

    # Ignore the following function for code coverage since the backward pass
    # is triggered by C++ code and cannot be calculated when overriding
    # autograd.Function
    @staticmethod
    @_conditional_amp_bwd_decorator
    def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
        inputs = ctx.inputs
        model_instance = ctx.model_instance

        # 遍歷上下文存儲的信息,給輸入設定是否需要梯度
        for i, need_grad in enumerate(ctx.grad_requirements):
            inputs[i].requires_grad = need_grad

        # 得到反向傳播的輸入
        all_grads = [grad_outputs]

        # 把模型分片和激活進行reverse(注意,沒有把原始分配和激活進行reverse,這裏是reverse之後的結果返回,不影響原始數據),因爲計算梯度是從後向前,所以把-1放到第一個位置,依次類推,這樣可以方便使用backward_load和backward_drop。
        
        # 然後遍歷模型分片,針對每一個分片進行處理
        for model_shard, activation in zip(
            reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
        ):
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
                # Move the activation to the GPU.
                # 把當前分片的激活移動到GPU
                activation = tuple([a.cuda() for a in list(activation)])

                # 把當前分片的模型移動到GPU
                # Move the model shard to the GPU.
                model_shard.backward_load()

            # Store the BW pass state.
            # 暫存反向傳播狀態
            bwd_rng_state = torch.get_rng_state()

            # TODO(anj-s): Why detach inputs?
            activation = torch.utils.checkpoint.detach_variable(activation)
            # Get the last gradient calculation.
            final_grads = all_grads[-1] # 這將會是最終生成的梯度

            if isinstance(activation, torch.Tensor):
                activation = (activation,)
            if isinstance(final_grads, torch.Tensor):
                final_grads = (final_grads,)
            # Iterate through all the inputs/outputs of a shard (there could be multiple).
            chunked_grad_list: List[Any] = []
            # Chunk the activation and grad based on the number of microbatches that are set.
            # 因爲可能有多個微批次,所以需要把梯度和激活分別做chunk操作
            for chunked_activation, chunked_grad in zip(
                torch.chunk(*activation, model_instance._num_microbatches),  # type: ignore
                torch.chunk(*final_grads, model_instance._num_microbatches),  # type: ignore
            ):
                # Set the states to what it used to be before the forward pass.
                torch.set_rng_state(ctx.fwd_rng_state) # 暫時使用前向傳播狀態

                # 構建爲list
                if isinstance(chunked_activation, torch.Tensor):
                    chunked_activation = (chunked_activation,)  # type: ignore
                if isinstance(chunked_grad, torch.Tensor):
                    chunked_grad = (chunked_grad,)  # type: ignore

                # Since we need a grad value of a non leaf element we need to set these properties.
                for a in chunked_activation:
                    if a.dtype == torch.long:
                        continue
                    a.requires_grad = True # 因爲需要計算非葉子結點,所以將其設置爲需要梯度
                    a.retain_grad()

                with torch.autograd.profiler.record_function(
                    "fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
                ):
                    with torch.enable_grad():
                        # calculate the output of the last shard wrt to the stored activation at the slice boundary.
                        outputs = model_shard(*chunked_activation) # 前向傳播

                # Set the states back to what it was at the start of this function.
                torch.set_rng_state(bwd_rng_state) # 恢復狀態
                
                with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
                    torch.autograd.backward(outputs, chunked_grad) # 反向傳播
                    
                intermediate_grads = []
                for a in chunked_activation:
                    if a.grad is not None:
                        intermediate_grads.append(a.grad)
                if None not in intermediate_grads:
                    chunked_grad_list += intermediate_grads
             
            # 把梯度列表添加到all_grads之上
            if chunked_grad_list:
                # Append the list of grads to the all_grads list and this should be on the GPU.
                all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
                
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
                # Move the shard back to the CPU. This should move all the grad tensors to CPU as well.
                # We don't need to move activations since we are using a copy of the tensors on the GPU.
                model_shard.backward_drop() # 分區移動到CPU
           
        # 之前的reverse沒有對 model_instance._activations 做修改
        detached_inputs = model_instance._activations[0]
        # 從輸入之中拿到其梯度
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None, None) + grads # 返回梯度

邏輯拓展如下:

至此,FSDP 分析完畢,我們下一個系列將會通過 NVIDIA Megatron 來介紹模型並行,敬請期待。

0xFF

https://arxiv.org/pdf/2101.06840.pdf

https://www.deepspeed.ai/tutorials/zero-offload/

DeepSpeed: Extreme-scale model training for everyone

https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/

https://www.marktechpost.com/2021/02/01/microsoft-and-the-university-of-california-merced-introduces-zero-offload-a-novel-heterogeneous-deeplearning-training-technology-to-train-multi-billion-parameter-models-on-a-single-gpu/

[1] Li et al. “PyTorch Distributed: Experiences on Accelerating Data Parallel Training” VLDB 2020.

[2] Cui et al. “GeePS: Scalable deep learning on distributed GPUs with a GPU-specialized parameter server” EuroSys 2016

[3] Shoeybi et al. “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.” arXiv preprint arXiv:1909.08053 (2019).

[4] Narayanan et al. “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.” arXiv preprint arXiv:2104.04473 (2021).

[5] Huang et al. “GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism.” arXiv preprint arXiv:1811.06965 (2018).

[6] Narayanan et al. “PipeDream: Generalized Pipeline Parallelism for DNN Training.” SOSP 2019.

[7] Narayanan et al. “Memory-Efficient Pipeline-Parallel DNN Training.” ICML 2021.

[8] Shazeer et al. “The Sparsely-Gated Mixture-of-Experts Layer Noam.” arXiv preprint arXiv:1701.06538 (2017).

[9] Lepikhin et al. “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.” arXiv preprint arXiv:2006.16668 (2020).

[10] Fedus et al. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.” arXiv preprint arXiv:2101.03961 (2021).

[11] Narang & Micikevicius, et al. “Mixed precision training.” ICLR 2018.

[12] Chen et al. 2016 “Training Deep Nets with Sublinear Memory Cost.” arXiv preprint arXiv:1604.06174 (2016).

[13] Jain et al. “Gist: Efficient data encoding for deep neural network training.” ISCA 2018.

[14] Shazeer & Stern. “Adafactor: Adaptive learning rates with sublinear memory cost.” arXiv preprint arXiv:1804.04235 (2018).

[15] Anil et al. “Memory-Efficient Adaptive Optimization.” arXiv preprint arXiv:1901.11150 (2019).

[16] Rajbhandari et al. “ZeRO: Memory Optimization Towards Training A Trillion Parameter Models Samyam.” arXiv preprint arXiv:1910.02054 (2019).

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章