MXNet Study Notes 5: Data Parallelism with Multi-devices

Overview

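The previous post covered Mixed Programming, mainly describing how to combine Symbol and NDArray to build and train a network from scratch. This post follows on from it and shows how to design a program in MXNet that trains in parallel. For the official code, see here.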

Main Text

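MXNet can parallelize computation automatically when multiple devices are available. This makes writing a parallel program in MXNet as simple as writing a serial one.
Below we show how to develop a data-parallel training program using multiple devices (such as CPUs and GPUs). Here a device means a computational resource with its own memory. It can be a single GPU, or all CPUs together: note that even if a machine has several CPUs, mx.cpu() treats their resources as one device, and this is transparent to the layers above.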

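The figure below, from NVIDIA, shows the memory structure of the devices and how data is communicated between them.
[Figure: device memory structure and inter-device communication (NVIDIA)]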

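Suppose each iteration trains on a minibatch of n samples. In data-parallel mode we split this batch across all devices in proportion to each device's performance. Each device computes the gradient on its own part of the batch, and then all the gradients are merged.
We extend the code from the previous post so that it can train on multiple devices. The new function takes a network, the data shape, a function that generates data batches, a list of devices, and the corresponding performance values.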
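The proportional split described above can be tried in isolation. The following is a minimal sketch in plain Python that mirrors the rounding rule used in the train function below; the helper names partition and index_ranges are hypothetical and do not appear in the original code.

```python
def partition(batch_size, devs_power):
    """Split batch_size across devices in proportion to devs_power."""
    total = float(sum(devs_power))
    return [int(round(batch_size / total * p)) for p in devs_power]


def index_ranges(workloads):
    """Turn per-device sample counts into consecutive index ranges."""
    ranges, start = [], 0
    for w in workloads:
        ranges.append(range(start, start + w))
        start += w
    return ranges


workloads = partition(100, [1, 5])
print(workloads)                                             # [17, 83]
print([(r.start, r.stop) for r in index_ranges(workloads)])  # [(0, 17), (17, 100)]

# Rounding each share independently does not always add up to the batch size.
uneven = partition(100, [1, 1, 1])
print(uneven, sum(uneven))                                   # [33, 33, 33] 99
```

The last case suggests picking a batch size and performance values that divide evenly, since independent rounding can leave a sample unassigned.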

"""
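train trains the network parameters and returns the final accuracy.
data is a lambda function that returns one batch (x, y).
Parameters are copied from devs[0] to the other devices.
Gradients computed on the different devices are summed on devs[0].
data_shape = [batch_size, num_features]
devs is a list of devices; devs_power holds the corresponding performance values.
workloads is the number of samples assigned to each device; the total should equal batch_size.
round rounds to the nearest integer.
zip pairs up the elements of two lists as tuples (wrap it in dict() to get a dict).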
"""
import mxnet as mx
import numpy as np

def train(network, data_shape, data, devs, devs_power):
    # partition the batch into each device
    batch_size = float(data_shape[0])
    workloads = [int(round(batch_size/sum(devs_power)*p)) for p in devs_power]
    # list() is needed in Python 3, where zip returns a lazy iterator
    print('workload partition: ', list(zip(devs, workloads)))
    # create an executor for each device
    ### [p] + data_shape[1:] = [p, num_features]
    ### the data argument sets the input shape: each device gets p samples (batch_size in total)
    exs = [network.simple_bind(ctx=d, data=tuple([p]+data_shape[1:])) for d, p in zip(devs, workloads)]
    args = [dict(zip(network.list_arguments(), ex.arg_arrays)) for ex in exs]    
    # initialize weight on dev 0
    for name in args[0]:
        arr = args[0][name]
        if 'weight' in name:
            arr[:] = mx.random.uniform(-0.1, 0.1, arr.shape)
        if 'bias' in name:
            arr[:] = 0
    # run 50 iterations
    learning_rate = 0.1 
    acc = 0
    for i in range(50):
        # broadcast weight from dev 0 to all devices
        for j in range(1, len(devs)):
            for name, src, dst in zip(network.list_arguments(), exs[0].arg_arrays, exs[j].arg_arrays):
                if 'weight' in name or 'bias' in name:
                    src.copyto(dst)
        # get data    
        ### data is a lambda function that returns one batch
        x, y = data() 
        for j in range(len(devs)):
            # partition and assign data
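            ### idx is a range. For example, if sum(workloads[:j]) = 12 and sum(workloads[:j+1]) = 15,
            ### idx is range(12, 15), so this device processes samples 12, 13 and 14 of the batch
            ### (batch_size samples in total). Compare list(range(12, 15)).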
            idx = range(sum(workloads[:j]), sum(workloads[:j+1]))
            ### x holds all samples of the batch; after the reshape, x[idx,:] matches the shape of args[j]['data']
            ### each device's input is filled with its own share of the data
            args[j]['data'][:] = x[idx,:].reshape(args[j]['data'].shape)
            args[j]['out_label'][:] = y[idx].reshape(args[j]['out_label'].shape)
            # forward and backward
            exs[j].forward(is_train=True)
            exs[j].backward()
            # sum over gradient on dev 0
            if j > 0:
                for name, src, dst in zip(network.list_arguments(), exs[j].grad_arrays, exs[0].grad_arrays):
                    if 'weight' in name or 'bias' in name:
                        ### as_in_context(context) moves data between devices: it returns a copy of the array on the given context
                        dst += src.as_in_context(dst.context)
        # update weight on dev 0        
        for weight, grad in zip(exs[0].arg_arrays, exs[0].grad_arrays):            
            weight[:] -= learning_rate * (grad / batch_size)
        # monitor
        if i % 10 == 0:
            ### np.concatenate joins the per-device prediction arrays into one array
            pred = np.concatenate([mx.nd.argmax_channel(ex.outputs[0]).asnumpy() for ex in exs])
            acc = (pred == y).sum() / batch_size
            print('iteration %d, accuracy %f' % (i, acc))
    return acc

# net, num_features and toy_data are defined in the previous post
batch_size = 100
acc = train(net, [batch_size, num_features], lambda : toy_data.get(batch_size), [mx.cpu(), mx.gpu()], [1, 5])
assert acc > 0.95, "Low training accuracy."

('workload partition: ', [(cpu(0), 17), (gpu(0), 83)])
iteration 0, accuracy 0.170000
iteration 10, accuracy 1.000000
iteration 20, accuracy 1.000000
iteration 30, accuracy 1.000000
iteration 40, accuracy 1.000000
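The train function relies on one fact: summing the gradients computed on each device's share of the batch gives the same result as computing the gradient on the whole batch at once, which is why it divides by batch_size only once after the sum. Below is a minimal sketch that checks this in plain Python for a linear model with squared loss. The model, the helper grad_sum and the 2/10 split are illustrative assumptions, not part of the MXNet code.

```python
import math
import random


def grad_sum(w, xs, ys):
    """Summed (not averaged) gradient of 0.5 * (w.x - y)^2 over the samples."""
    g = [0.0] * len(w)
    for x, y in zip(xs, ys):
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for k, xk in enumerate(x):
            g[k] += err * xk
    return g


random.seed(0)
n = 12
xs = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(n)]
ys = [random.uniform(-1, 1) for _ in range(n)]
w = [0.1, -0.2, 0.3]

# Gradient on the full batch, as a single device would compute it.
full = grad_sum(w, xs, ys)

# Gradient accumulated shard by shard, as in the data-parallel loop.
workloads = [2, 10]  # hypothetical uneven split across two devices
total = [0.0] * len(w)
start = 0
for p in workloads:
    part = grad_sum(w, xs[start:start + p], ys[start:start + p])
    total = [t + g for t, g in zip(total, part)]
    start += p

same = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(full, total))
print(same)
```

Because the sum does not depend on how the batch is split, an uneven split such as 17/83 changes how long each device works but not the update that is applied.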

Note that the network above is too small to show any performance benefit from using multiple devices. Next we consider a slightly more complex network: LeNet-5 for handwritten digit recognition. We first define the network.

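Note that I did not manage to run the code below: I could not download the MNIST data, and sklearn refused to install under Python 3.5 on CentOS 7. This is still unresolved.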

def lenet():
    data = mx.sym.Variable('data')
    # first conv
    conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
    tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
    pool1 = mx.sym.Pooling(data=tanh1, pool_type="max",
                           kernel=(2,2), stride=(2,2))
    # second conv
    conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
    tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
    pool2 = mx.sym.Pooling(data=tanh2, pool_type="max",
                           kernel=(2,2), stride=(2,2))
    # first fullc
    flatten = mx.sym.Flatten(data=pool2)
    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
    # second fullc
    fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
    # loss
    lenet = mx.sym.SoftmaxOutput(data=fc2, name='out')
    return lenet
mx.viz.plot_network(lenet(), shape={'data':(128,1,28,28)}).view()

[Figure: LeNet network graph produced by mx.viz.plot_network]
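To follow how a 28x28 image shrinks through the lenet definition above, the spatial sizes can be worked out by hand. The sketch below assumes the defaults that the definition leaves implicit (stride 1 and no padding for the convolutions, and pooling that discards incomplete windows); conv_out and pool_out are hypothetical helpers, not MXNet functions.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output width of a convolution along one spatial dimension."""
    return (size + 2 * pad - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Output width of a pooling layer that drops incomplete windows."""
    return (size - kernel) // stride + 1


size = 28
conv1 = conv_out(size, 5)       # 5x5 convolution
pool1 = pool_out(conv1, 2, 2)   # 2x2 max pooling, stride 2
conv2 = conv_out(pool1, 5)      # 5x5 convolution
pool2 = pool_out(conv2, 2, 2)   # 2x2 max pooling, stride 2
flat = 50 * pool2 * pool2       # 50 filters flattened into one vector

print(conv1, pool1, conv2, pool2, flat)  # 24 12 8 4 800
```

Under these assumptions the first fully connected layer receives 800 inputs per sample and maps them to 500 hidden units.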

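from sklearn.datasets import fetch_openml
import numpy as np
import matplotlib.pyplot as plt

class MNIST:
    def __init__(self):
        ### the original code called fetch_mldata('MNIST original'), which newer scikit-learn no longer provides;
        ### fetch_openml('mnist_784') is assumed here as its replacement
        mnist = fetch_openml('mnist_784', version=1, as_frame=False)
        ### np.random.permutation(n) returns the indices 0..n-1 in random order
        p = np.random.permutation(mnist.data.shape[0])
        self.X = mnist.data[p]
        ### OpenML stores the labels as strings, so convert them to numbers
        self.Y = mnist.target[p].astype(np.float32)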
        self.pos = 0        
    def get(self, batch_size):
        p = self.pos
        self.pos += batch_size
        return self.X[p:p+batch_size,:], self.Y[p:p+batch_size]
    def reset(self):
        self.pos = 0        
    def plot(self):
        for i in range(10):
            plt.subplot(1,10,i+1)
            plt.imshow(self.X[i].reshape((28,28)), cmap='Greys_r')
            plt.axis('off')
        plt.show()

mnist = MNIST()
mnist.plot()

[Figure: the first ten MNIST digits drawn by mnist.plot()]

Next, train lenet on a single GPU.

import time
batch_size = 1024
shape = [batch_size, 1, 28, 28]
mnist.reset()
tic = time.time()
acc = train(lenet(), shape, lambda:mnist.get(batch_size), [mx.gpu(),], [1,])
assert acc > 0.8, "Low training accuracy."
# the message below says 'cpu', but this run uses mx.gpu()
print('time for train lenent on cpu %f sec' % (time.time() - tic))

('workload partition: ', [(gpu(0), 1024)])
iteration 0, accuracy 0.071289
iteration 10, accuracy 0.815430
iteration 20, accuracy 0.896484
iteration 30, accuracy 0.912109
iteration 40, accuracy 0.932617
time for train lenent on cpu 2.708110 sec

Next we try training on multiple GPUs; up to 4 GPUs are needed.

for ndev in (2, 4):
    mnist.reset()
    tic = time.time()
    acc = train(lenet(), shape, lambda:mnist.get(batch_size), 
          [mx.gpu(i) for i in range(ndev)], [1]*ndev)
    assert acc > 0.9, "Low training accuracy."
    print('time for train lenent on %d GPU %f sec' % (
            ndev, time.time() - tic))

('workload partition: ', [(gpu(0), 512), (gpu(1), 512)])
iteration 0, accuracy 0.104492
iteration 10, accuracy 0.741211
iteration 20, accuracy 0.876953
iteration 30, accuracy 0.914062
iteration 40, accuracy 0.924805
time for train lenent on 2 GPU 1.623732 sec
('workload partition: ', [(gpu(0), 256), (gpu(1), 256), (gpu(2), 256), (gpu(3), 256)])
iteration 0, accuracy 0.092773
iteration 10, accuracy 0.777344
iteration 20, accuracy 0.887695
iteration 30, accuracy 0.908203
iteration 40, accuracy 0.916992
time for train lenent on 4 GPU 1.086430 sec

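As we can see, using multiple GPUs speeds up training. The speedup is modest, though, because the network is still very simple, so the cost of communication between GPUs is relatively large. We have seen better results with state-of-the-art networks. The figure below shows the speedup.
[Figure: speedup from training on multiple GPUs]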
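The timings printed in the runs above can be turned into speedup and parallel efficiency figures. This is a small sketch using only those three numbers; each timing covers just 50 iterations plus setup, so the result is a rough indication rather than a benchmark.

```python
# Wall-clock times in seconds copied from the runs above, keyed by GPU count.
times = {1: 2.708110, 2: 1.623732, 4: 1.086430}


def speedup(times, ndev):
    """How many times faster ndev GPUs are than a single GPU."""
    return times[1] / times[ndev]


def efficiency(times, ndev):
    """Speedup divided by device count; 1.0 would be perfect scaling."""
    return speedup(times, ndev) / ndev


for ndev in (2, 4):
    print('%d GPUs: speedup %.2fx, efficiency %.0f%%'
          % (ndev, speedup(times, ndev), 100 * efficiency(times, ndev)))
```

Two GPUs give roughly 1.67x and four give roughly 2.49x, so efficiency falls as devices are added, which is consistent with communication cost weighing heavily on a network this small.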


To repeat, I have not run the code in the last part of this post myself. Discussion is welcome.
