How to Train on Large Batches of Data in PyTorch?

Most individual learners do not own a high-end deep-learning workstation, let alone a powerful server for training models. But training still has to happen, and that is where the conflict arises!

In deep-learning training we constantly run into GPUs whose memory is too small: if the data is large, never mind large-batch training, sometimes even a single training sample does not fit. Yet with stochastic gradient descent (SGD), training with a larger batch size generally gives better results.

So the question is: when GPU memory is not enough, how do we train a neural network with a large batch size?

This article uses PyTorch as the example to cover the following points:

  1. When GPU memory is smaller than a batch of training samples, or cannot even hold a single sample, how do we train on one or more GPUs?
  2. How do we use multiple GPUs as efficiently as possible?

1. Large-batch training on a single GPU or multiple GPUs

If you have ever seen the error CUDA RuntimeError: out of memory, you have run into this problem.

The PyTorch developers must be rolling their eyes: buddy, this is not a bug, your GPU just ran out of memory…

There is a method that solves this problem: gradient accumulation (accumulating gradients).

Normally in PyTorch we update parameters like this:

predictions = model(inputs)               # forward pass
loss = loss_function(predictions, labels) # compute the loss
loss.backward()                           # backward pass: compute gradients
optimizer.step()                          # optimizer updates the parameters
predictions = model(inputs)               # next forward pass with the updated parameters

As the comments above indicate, during loss.backward() the gradient of each parameter is computed and stored in a tensor attached to that parameter: parameter.grad. The optimizer then uses these gradients to update each parameter's value in optimizer.step().

The basic idea of gradient accumulation is to run several backward passes before the optimizer updates the parameters, i.e., before optimizer.step(), letting the gradients accumulate in parameter.grad, and only then perform the update. This is especially easy to implement in PyTorch, because gradient values are retained by default unless we call model.zero_grad() or optimizer.zero_grad().
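
To see this accumulation behaviour concretely, here is a minimal sketch (the toy linear model and shapes are made up for illustration) showing that a second backward() call without zeroing doubles the stored gradient:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)                  # toy model, illustrative only
x = torch.randn(2, 4)

model(x).sum().backward()                # first backward pass
g1 = model.weight.grad.clone()           # gradient after one pass

model(x).sum().backward()                # second backward pass, grads are NOT reset
g2 = model.weight.grad

print(torch.allclose(g2, 2 * g1))        # True: gradients accumulated in parameter.grad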

Below is a gradient-accumulation example, where accumulation_steps is the number of iterations over which gradients are accumulated:

model.zero_grad()                                   # reset the stored gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # forward pass
    loss = loss_function(predictions, labels)       # compute the loss
    loss = loss / accumulation_steps                # normalize the loss (to average over the accumulated steps)
    loss.backward()                                 # accumulate gradients
    if (i + 1) % accumulation_steps == 0:           # every accumulation_steps iterations
        optimizer.step()                            # update the parameters
        model.zero_grad()                           # reset the gradients
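
The effective batch size here is accumulation_steps times the per-iteration batch size. One caveat about the loop above: if the number of batches is not divisible by accumulation_steps, the gradients accumulated in the last partial window are never applied. A minimal sketch of one way to flush them (the final if-block is my addition, assuming training_set supports len(), e.g. a DataLoader):

model.zero_grad()
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)
    loss = loss_function(predictions, labels) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        model.zero_grad()

if len(training_set) % accumulation_steps != 0:     # flush the leftover gradients
    optimizer.step()
    model.zero_grad()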

2. What if even a single sample does not fit?

If the samples are especially large, never mind batch training: what do we do when GPU memory cannot even hold a single sample?

The answer is gradient checkpointing: trading compute for memory.

The basic idea is to cut the model's computation graph into segments: the intermediate activations inside each segment are not stored during the forward pass, and are instead recomputed segment by segment during the backward pass before the corresponding gradients are applied (see the figure below). The method is slower because of the extra computation, but it is very useful in some cases, such as training RNN models on long sequences.

[Figure: illustration of gradient checkpointing]

See the PyTorch official documentation on checkpointing: https://pytorch.org/docs/stable/checkpoint.html

TORCH.UTILS.CHECKPOINT

NOTE:

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be more advanced than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.

The stashing logic saves and restores the RNG state for the current device and the device of all cuda Tensor arguments to the run_fn. However, the logic has no way to anticipate if the user will move Tensors to a new device within the run_fn itself. Therefore, if you move Tensors to a new device (“new” meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn, deterministic output compared to non-checkpointed passes is never guaranteed.

  • torch.utils.checkpoint.checkpoint(function, *args, **kwargs)

    Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, and instead recomputes them in backward pass. It can be applied on any part of a model.

    Specifically, in the forward pass, function will run in torch.no_grad() manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the function parameter. In the backwards pass, the saved inputs and function is retrieved, and the forward pass is computed on function again, now tracking the intermediate activations, and then the gradients are calculated using these activation values.

    WARNING:
    Checkpointing doesn’t work with torch.autograd.grad(), but only with torch.autograd.backward().

    WARNING:
    If function invocation during backward does anything different than the one during forward, e.g., due to some global variable, the checkpointed version won’t be equivalent, and unfortunately it can’t be detected.

    Parameters:

    • function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in LSTM, if user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden
    • preserve_rng_state (bool, optional, default=True) – Omit stashing and restoring the RNG state during each checkpoint.
    • args – tuple containing inputs to the function

    Returns: Output of running function on *args

  • torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)

    A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model in various segments and checkpoint each segment. All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
    See checkpoint() on how checkpointing works.

    WARNING:
    Checkpointing doesn’t work with torch.autograd.grad(), but only with torch.autograd.backward().

    Parameters:

    • functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.
    • segments – Number of chunks to create in the model
    • input – A Tensor that is input to functions
    • preserve_rng_state (bool, optional, default=True) – Omit stashing and restoring the RNG state during each checkpoint.

    Returns: Output of running functions sequentially on *inputs

    Example:

    >>> model = nn.Sequential(...)
    >>> input_var = checkpoint_sequential(model, chunks, input_var)
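
To make the excerpt above concrete, here is a minimal, self-contained sketch of checkpoint_sequential in a training step (the model, sizes, and segment count are made up for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = []
for _ in range(20):                       # a deep sequential stack
    layers += [nn.Linear(512, 512), nn.ReLU()]
model = nn.Sequential(*layers)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The input must require grad; otherwise the checkpointed segments run
# under no_grad and no gradients flow back to their parameters.
inputs = torch.randn(8, 512, requires_grad=True)
labels = torch.randn(8, 512)

segments = 4                              # number of checkpointed chunks
outputs = checkpoint_sequential(model, segments, inputs)
loss = nn.functional.mse_loss(outputs, labels)
loss.backward()                           # activations are recomputed chunk by chunk here
optimizer.step()

Only the inputs of each segment are kept in memory during the forward pass; the activations inside a segment are recomputed during backward, which is exactly the compute-for-memory trade described above.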
    

3. Multi-GPU training

The simplest, most brute-force, most lavish solution is to train on multiple GPUs. The standard way to do multi-GPU training in PyTorch is torch.nn.DataParallel.

It is very simple and needs only one line of code:

torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

This class implements data parallelism at the module level; note that the batch size must be larger than the number of GPUs.

Parameters:

  • module: the network model to be trained on multiple GPUs
  • device_ids: the GPU indices to use (defaults to all GPUs)
  • output_device: the device on which the outputs are gathered (defaults to device_ids[0])
  • dim: the dimension along which input tensors are scattered (defaults to 0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Here we define the device. Note that "cuda:0" means the starting device_id is 0; plain "cuda" likewise starts from 0 by default. You can change the starting device as needed, e.g. "cuda:1".

model = Model()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 1, 2])

model.to(device)

Note: with a single GPU, model.to(device) alone is enough to train on it; with multiple GPUs you must wrap the model in nn.DataParallel first and then call to(device).

Important: the first entry of device_ids must match the "cuda:0" in the device defined earlier, otherwise an error is raised.

If device_ids is not given, as in model = nn.DataParallel(model), all GPUs are used by default. Passing device_ids selects specific GPUs, but they must stay consistent with the device defined at the start, as in the sketch below.
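
For example, here is a hedged sketch of pairing a non-zero starting device with matching device_ids (assuming at least three visible GPUs; Model is the same placeholder as above):

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

model = Model()
if torch.cuda.device_count() > 2:
    model = nn.DataParallel(model, device_ids=[1, 2])  # device_ids[0] matches "cuda:1"

model.to(device)

Here the model replicas live on GPUs 1 and 2, and the outputs are gathered on cuda:1, the first entry of device_ids.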

With the code above, the network can be trained on multiple GPUs.

parallel_model = torch.nn.DataParallel(model) # this is the key line!

predictions = parallel_model(inputs)          # forward pass
loss = loss_function(predictions, labels)     # compute the loss
loss.mean().backward()                        # average the per-GPU losses and compute gradients
optimizer.step()                              # update the parameters
predictions = parallel_model(inputs)

When using torch.nn.DataParallel we often run into one problem: the first GPU carries a noticeably heavier load. Let's first look at how multi-GPU training works:

[Figure: forward/backward steps of torch.nn.DataParallel across multiple GPUs]

In the fourth step of the first row of the figure above, GPU-1 gathers the results from all the GPUs. That is fine for ordinary multi-class classification, but for natural-language-processing models the outputs gathered on GPU-1 can be so large that its memory simply blows up.

So we need to balance the load across the GPUs: instead of gathering everything on GPU-1, keep the results on their own GPUs. The key is to parallelize the loss function itself, so that the loss and its gradients are computed and back-propagated separately on each GPU. There is an open-source implementation at https://github.com/zhanghang1989/PyTorch-Encoding; a modified version (the parallel module imported below) can be called directly from our code.

Example:

from parallel import DataParallelModel, DataParallelCriterion

parallel_model = DataParallelModel(model)               # parallelize the model
parallel_loss  = DataParallelCriterion(loss_function)   # parallelize the loss function

predictions = parallel_model(inputs)      # parallel forward pass
                                          # "predictions" is a tuple of results, one per GPU
loss = parallel_loss(predictions, labels) # compute the losses in parallel
loss.backward()                           # backward pass: compute gradients
optimizer.step()                          # update the parameters
predictions = parallel_model(inputs)

If your network has multiple outputs, you can unpack them like this:

output_1, output_2 = zip(*predictions)

If you occasionally do not want the distributed loss computation, you can gather all the results manually:

gathered_predictions = parallel.gather(predictions)

The figure below shows how things work after load balancing:

[Figure: balanced multi-GPU training with a parallelized loss function]
