[源碼解析] PyTorch 分佈式(14) --使用 Distributed Autograd 和 Distributed Optimizer

[源碼解析] PyTorch 分佈式(14) --使用 Distributed Autograd 和 Distributed Optimizer

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分佈式的基本模塊,接下來我們通過幾篇文章來看看如何把這些模塊應用到實踐之中,順便把PyTorch分佈式邏輯整體梳理一下。本文介紹如何把分佈式自動微分和分佈式優化器結合起來訓練一個模型。

本文以 https://pytorch.org/tutorials/intermediate/rpc_tutorial.html 的部分翻譯爲基礎,加入了自己的理解。

PyTorch分佈式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[源碼解析]深度學習利器之自動微分(3) --- 示例解讀

[源碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[源碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[源碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[源碼解析] Pytorch 如何實現後向傳播 (1)---- 調用引擎

[源碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[源碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[源碼解析] PyTorch 如何實現後向傳播 (4)---- 具體算法

[源碼解析] PyTorch 分佈式(1)------歷史和概述

[源碼解析] PyTorch 分佈式(2) ----- DataParallel(上)

[源碼解析] PyTorch 分佈式(3) ----- DataParallel(下)

[源碼解析] PyTorch 分佈式(4)------分佈式應用基礎概念

[源碼解析] PyTorch分佈式(5) ------ DistributedDataParallel 總述&如何使用

[源碼解析] PyTorch分佈式(6) ---DistributedDataParallel -- 初始化&store

[源碼解析] PyTorch 分佈式(7) ----- DistributedDataParallel 之進程組

[源碼解析] PyTorch 分佈式(8) -------- DistributedDataParallel之論文篇

[源碼解析] PyTorch 分佈式(9) ----- DistributedDataParallel 之初始化

[源碼解析] PyTorch 分佈式(10)------DistributedDataParallel 之 Reducer靜態架構

[源碼解析] PyTorch 分佈式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

[源碼解析] PyTorch 分佈式(12) ----- DistributedDataParallel 之 前向傳播

[源碼解析] PyTorch 分佈式(13) ----- DistributedDataParallel 之 反向傳播

[源碼解析] PyTorch 分佈式 Autograd (1) ---- 設計

[源碼解析] PyTorch 分佈式 Autograd (2) ---- RPC基礎

[源碼解析] PyTorch 分佈式 Autograd (3) ---- 上下文相關

[源碼解析] PyTorch 分佈式 Autograd (4) ---- 如何切入引擎

[源碼解析] PyTorch 分佈式 Autograd (5) ---- 引擎(上)

[源碼解析] PyTorch 分佈式 Autograd (6) ---- 引擎(下)

[源碼解析] PyTorch分佈式優化器(1)----基石篇

[源碼解析] PyTorch分佈式優化器(2)----數據並行優化器

[源碼解析] PyTorch分佈式優化器(3)---- 模型並行

0x01 說明

首先要做一下說明,原文有兩部分:強化學習和RNN,本文只是翻譯了RNN部分。而且本文沒有完全按照原文順序進行翻譯,而是按照自己理解的思路重新組織了文章,用一種從上至下的角度來看這個系統。

本文使用RNN模型來展示如何使用RPC API構建分佈式模型並行訓練。示例RNN模型非常小,可以很容易地放入單個GPU中,但我們仍然將它的層分在兩個不同worker來之上來演示如何分佈式訓練。開發人員可以應用類似的技術在多個設備和機器上分發更大的模型。

注:在官方這些分佈式文章中,worker 有時指代分佈式系統之中所有進程,而實際訓練進程往往叫做 trainer,本文的worker 就包括一個 trainer 和 一個參數服務器。

0x02 啓動

在啓動階段,run_worker 方法會啓動一個 trainer 和 一個參數服務器,參數服務器在代碼之中沒有任何行爲。

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        # 啓動了trainer
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        # trainer 業務邏輯
        _run_trainer()
    else:
        # 啓動了參數服務器
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # parameter server do nothing
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)

具體如下圖:

           torch.multiprocessing.spawn
                      +
                      |
                      |
    +-----------------+--------------------+
    |                                      |
    |                                      |
    v                                      v
+---+---------------------+   +------------+-------------+
| "ps"          rank = 0  |   | "trainer"      rank = 1  |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
+-------------------------+   +--------------------------+

0x03 Trainer

我們接下來看看訓練循環。初始化模型參數後,我們創建"RNNModel"和"DistributedOptimizer"。分佈式優化器將獲取參數"RRefs"的列表,查找這些參數所有的不同的 owner workers,並使用給定參數(即"lr=0.05")在每個owner worker上創建給定的本地優化器(在本例中即"SGD",您也可以使用其他本地優化器)。

在訓練循環中,它做如下操作:

  • 首先創建分佈式autograd context,這將幫助分佈式autograd引擎查找梯度和涉及的RPC send/recv 函數。
  • 然後,它像本地模型一樣開始向前傳播,並且運行分佈式向後傳播。對於分佈式後向傳播,您只需要指定根的列表(list of roots),在本例中,它是loss 張量。分佈式autograd引擎將自動遍歷分佈式計算圖並正確寫入梯度。
  • 接下來,它在分佈式優化器上運行'step'函數,該函數將與所有相關的本地優化器聯繫以更新模型參數。與本地訓練相比,一個區別是用戶不需要運行 zero_grad() ,因爲每個autograd context 都有專用的空間來存儲梯度,這樣每次迭代創建一個上下文時,來自不同迭代的梯度不會累積到同一組張量之上。

具體代碼如下:

def run_trainer():
    batch = 5
    ntoken = 10
    ninp = 2
    nhid = 3
    nindices = 3
    nlayers = 4
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid)
    )

    model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # setup distributed optimizer
    opt = DistributedOptimizer( # 創建分佈式優化器
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch():
        for _ in range(5):
            data = torch.LongTensor(batch, nindices) % ntoken
            target = torch.LongTensor(batch, ntoken) % nindices
            yield data, target

    # train for 10 iterations
    for epoch in range(10):
        for data, target in get_next_batch():
            # create distributed autograd context
            with dist_autograd.context() as context_id: # 創建分佈式上下文
                hidden[0].detach_()
                hidden[1].detach_()
                output, hidden = model(data, hidden)
                loss = criterion(output, target)
                # run distributed backward pass
                dist_autograd.backward(context_id, [loss]) # 執行分佈式後向傳播
                # run distributed optimizer
                opt.step(context_id) # 分佈式優化器進行更新
                # not necessary to zero grads since they are
                # accumulated into the distributed autograd context
                # which is reset every iteration.
        print("Training epoch {}".format(epoch))

邏輯擴展爲:

           torch.multiprocessing.spawn
                      +
                      |
                      |
    +-----------------+--------------------+
    |                                      |
    |                                      |
    v                                      v
+---+---------------------+   +------------+-----------------------------------+
| "ps"          rank = 0  |   | "trainer"      rank = 1                        |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    model = rnn.RNNModel('ps')                  |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    dist_autograd.backward(context_id, [loss])  |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    DistributedOptimizer.step(context_id)       |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |                                                |
+-------------------------+   +------------------------------------------------+

0x04 模型

我們接下來看看具體模型。

4.1 組件

RNN模型設計借鑑了PyTorch示例庫 example中的word語言模型,該模型包含三個主要組件:嵌入表、LSTM層和解碼器。

4.1.1 參考代碼

我們有必要貼出原始參考代碼來比對,可以看到,Embedding 和 Linear 都是作爲 RNNModel 的成員變量存在,整個 RNNModel 耦合的非常緊密。

class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp) # 嵌入表成員變量
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken) # 解碼器成員變量

			  # 省略後部分代碼

4.1.2 分佈式修改

我們看看如何依據分佈式的特點來對上面模型進行修改。

下面的代碼將嵌入表(embedding table)和解碼器包裝到子模塊(sub-modules)中,以便將它們的構造函數傳遞給RPC API。在EmbeddingTable子模塊中,我們有意將嵌入層放在GPU上以做演示。在v1.4中,RPC總是在目標工作進程上創建CPU張量參數或返回值。如果函數採用GPU張量,則需要顯式地將其移動到適當的設備。

class EmbeddingTable(nn.Module):
    r"""
    Encoding layers of the RNNModel
    """
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        return self.drop(self.encoder(input.cuda()).cpu()


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, output):
        return self.decoder(self.drop(output))

4.2 RNN 模型

前面提到,爲了實現分佈式模型並行訓練,開發人員可以將模型劃分爲子模塊。有了上面的子模塊,我們現在可以使用RPC將它們組合在一起,創建一個RNN模型。我們將調用RPC遠程創建子模塊實例,並在必要時使用RRef查找它們。正如您在下面的代碼中所看到的,它看起來非常類似於單機模型並行訓練。主要區別在於用RPC函數替換 Tensor.to(device)

ps表示一個參數服務器,它承載嵌入表和解碼器的參數。構造函數使用remote API在參數服務器上創建EmbeddingTable對象和解碼器對象,並在本地創建LSTM子模塊。

在向前傳播過程中,trainer使用EmbeddingTable RRef查找遠程子模塊,並使用RPC將輸入數據傳遞給EmbeddingTable並獲取查找結果。然後,它通過本地LSTM層運行嵌入,最後使用另一個RPC將輸出發送到解碼器子模塊。

class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # setup embedding table remotely
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # setup LSTM locally
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # setup decoder remotely
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        # pass output to the rremote decoder and get the decoded output back
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden

因此,邏輯圖拓展如下:

                 torch.multiprocessing.spawn
                            +
                            |
                            |
          +-----------------+--------------------+
          |                                      |
          |                                      |
          v                                      v
+---------+------------+   +---------------------+-------------------------------------+
|"ps"        rank = 0  |   | "trainer"      rank = 1                                   |
|                      |   |                                                           |
|                      |   |   model = rnn.RNNModel('ps')                              |
|                      |   |                                                           |
| +---------------+    |   |   +---------------------------------------+               |
| |EmbeddingTable |    |   |   | RNNModel                              |               |
| |               |    |   |   |                                       |               |
| |               | <--------------+ self.emb_table_rref               |               |
| +---------------+    |   |   |                                       |               |
| +---------------+    |   |   |                                       |               |
| |Decoder        | <--------------+ self.decoder_rref                 |               |
| |               |    |   |   |                                       |               |
| |               |    |   |   |     self.rnn = LSTM                   |               |
| |               |    |   |   |                                       |               |
| +---------------+    |   |   +---------------------------------------+               |
|                      |   |                                                           |
|                      |   |                                                           |
|                      |   |   forward() {                                             |
|                      |   |       emb = _remote_method(EmbeddingTable.forward, input) |
|                      |   |       output, hidden = self.rnn(emb, hidden)              |
+----------------------+   |       decoded = _remote_method(Decoder.forward, output)   |
                           |   }                                                       |
                           |                                                           |
                           |                                                           |
                           |   dist_autograd.backward(context_id, [loss])              |
                           |                                                           |
                           |                                                           |
                           |   DistributedOptimizer.step(context_id)                   |
                           |                                                           |
                           +-----------------------------------------------------------+


4.3 分佈式優化器

在介紹分佈式優化器之前,讓我們添加一個helper函數,此函數用來生成模型參數的RRefs列表,分佈式優化器將使用該列表。在本地訓練中,應用程序可以調用 Module.parameters()來獲取對所有參數張量的引用,並將其傳遞給本地優化器進行後續更新。但是,由於某些參數存在於遠程機器上,因此同一API在分佈式訓練場景中不起作用。因此,分佈式優化器不採用參數"張量"列表,而是採用"RRef"列表,本地和遠程模型參數的每個模型參數都有一個"RRef"。helper函數非常簡單,只需調用Module.parameters() 並在每個參數上創建一個本地'RRef'。

def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

然後,由於RNNModel包含三個子模塊,我們需要調用 _parameter_rrefs 三次,並將其封裝到另一個helper函數中。

class RNNModel(nn.Module):
    ...
    def parameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))
        # get RRefs of decoder
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params

在 trainer 之中,使用如下來生成分佈式優化器,這樣就把遠端的一些參數作爲優化對象。

# setup distributed optimizer
opt = DistributedOptimizer(
    optim.SGD,
    model.parameter_rrefs(),
    lr=0.05,
)

我們最後拓展如下:

  • (1) RNNModel 的 emb_table_rref 成員變量指向參數服務器上的EmbeddingTable。
  • (2) RNNModel 的 decoder_rref 成員變量指向參數服務器上的Decoder。
  • (3) RNNModel 的 rnn 成員變量指向本地的LSTM。
  • DistributedOptimizer 內部的三個待優化變量分別指向:4) 參數服務器上的EmbeddingTable 的 參數,5) 參數服務器上的Decoder 的參數,6) 本地LSTM的參數。

分別對應下圖上的數字。

                 torch.multiprocessing.spawn
                            +
                            |
                            |
            +---------------+--------------------+
            |                                    |
            |                                    |
            v                                    v
  +---------+------------+ +---------------------+----------------------------------------+
  |"ps"        rank = 0  | | "trainer"                                         rank = 1   |
  |                      | |                                                              |
  |                      | |   model = rnn.RNNModel('ps')                                 |
  |                      | |                                                              |
  |  +---------------+   | |   +---------------------------------------+                  |
  |  |EmbeddingTable |   | |   | RNNModel                              |                  |
+--->+               |   | | 1 |                                       |                  |
| |  |               +<------------+ self.emb_table_rref               |    +------+      |
| |  +---------------+   | |   |                            3          |    |LSTM  |  6   |
| |                      | |   |     self.rnn +---------------------------->+      +<---+ |
| |  +---------------+   | | 2 |                                       |    |      |    | |
| |  |Decoder        +<------------+ self.decoder_rref                 |    +------+    | |
| |  |               |   | |   |                                       |                | |
| |  |               |   | |   +---------------------------------------+                | |
| |  |               |   | |                                                            | |
| |  +------+--------+   | |   forward() {                                              | |
| |         ^            | |       emb = _remote_method(EmbeddingTable.forward, input)  | |
| |         |            | |       output, hidden = self.rnn(emb, hidden)               | |
| |         |            | |       decoded = _remote_method(Decoder.forward, output)    | |
| |         |            | |   }                                                        | |
| +----------------------+ |                                                            | |
|           |              |   dist_autograd.backward(context_id, [loss])               | |
|           |              |                                                            | |
| 5         | 4            |  +------------------------------------------------------+  | |
|           |              |  | DistributedOptimizer                                 |  | |
|           |              |  |                                                      |  | |
|           |              |  |     remote_optimizers = [                            |  | |
+-------------------------------------------------------+ optim_rref1,               |  | |
            |              |  |                           optim_rref2+------------------+ |
            +-------------------------------------------+ optim_rref3                |    |
                           |  |                                                      |    |
                           |  |                          ]                           |    |
                           |  |     step(context_id)                                 |    |
                           |  +------------------------------------------------------+    |
                           +--------------------------------------------------------------+

手機如下:

4.4 比對

因爲前面提到:分佈式模型並行訓練看起來非常類似於單機模型並行訓練。主要區別在於用RPC函數替換 Tensor.to(device)。我們用GPU替代參數服務器,把上圖大致修改下做一下對比,可能不是非常確切,但是大家可以看出來分佈式訓練的關鍵點。

  +----------------------+ +-------------------------------------------------------------+
  | GPU                  | | CPU                                                rank = 0 |
  |                      | |                                                             |
  |                      | |   model = rnn.RNNModel()                                    |
  |                      | |                                                             |
  |  +---------------+   | |   +---------------------------------------+                 |
  |  |EmbeddingTable |   | |   | RNNModel                              |                 |
+--->+               |   | | 1 |                                       |                 |
| |  |               +<------------+ self.emb_table_rref               |   +------+      |
| |  +---------------+   | |   |                            3          |   |LSTM  |  6   |
| |                      | |   |     self.rnn +--------------------------->+      +<---+ |
| |  +---------------+   | | 2 |                                       |   |      |    | |
| |  |Decoder        +<------------+ self.decoder_rref                 |   +------+    | |
| |  |               |   | |   |                                       |               | |
| |  |               |   | |   +---------------------------------------+               | |
| |  |               |   | |                                                           | |
| |  +------+--------+   | |   forward() {                                             | |
| |         ^            | |       emb = EmbeddingTable.forward(input)                 | |
| |         |            | |       output, hidden = self.rnn(emb, hidden)              | |
| |         |            | |       decoded = Decoder.forward(output)                   | |
| |         |            | |   }                                                       | |
| +----------------------+ |                                                           | |
|           |              |   loss.backward()                                         | |
|           |              |                                                           | |
| 5         | 4            |  +----------------------------------------+               | |
|           |              |  | Optimizer                              |               | |
|           |              |  |                                        |               | |
|           |              |  |          param_groups = [              |               | |
+-------------------------------------------------------+ optim_rref1, |               | |
            |              |  |                                        |               | |
            |              |  |                           optim_rref2+-----------------+ |
            |              |  |                                        |                 |
            +-------------------------------------------+ optim_rref3  |                 |
                           |  |                          ]             |                 |
                           |  |          step()                        |                 |
                           |  |                                        |                 |
                           |  +----------------------------------------+                 |
                           +-------------------------------------------------------------+

手機如下:

0xFF 參考

GETTING STARTED WITH DISTRIBUTED RPC FRAMEWORK

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章