【Pytorch】RuntimeError: arguments are located on different GPUs

0x00 前言

Pytorch裏使用optimizer的時候,由於其會記錄step等信息,
有時會希望將optimizer的內容記錄下來,以備之後繼續使用,
那麼自然而然的會想到使用API中自帶的
torch.save(object, path)
torch.load(path)

再配合上
optimizer.state_dict()
optimizer.load_state_dict(obj)
來實現這一需求了~

於是,大家自然而然地會自信滿滿敲出如下這樣的語句——

torch.save(optimizer.state_dict(), path)
optimizer.load_state_dict(torch.load(path))

並收穫如下的Error——

RuntimeError                              Traceback (most recent call last)
<ipython-input-160-19f8d61b5e53> in <module>()
     37         optimizer.zero_grad()
     38         loss.backward()
---> 39         optimizer.step()
     40         print(model.state_dict()['linear_layer.weight'])
     41 

/usr/local/anaconda2/lib/python2.7/site-packages/torch/optim/adam.pyc in step(self, closure)
     63 
     64                 # Decay the first and second moment running average coefficient
---> 65                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
     66                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
     67 

RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215

0x01 解決方案

二話不說上解決方案是我的習慣

# Load from dict
optimizer.load_state_dict(check_point['optim'])

# Load from file
optimizer.load_state_dict(torch.load(optim_path))

# Add this
for state in optimizer.state.values():
    for k, v in state.items():
        print (type(v))
        if torch.is_tensor(v):
            state[k] = v.cuda(cuda_id)

0x02 原理解釋

然後在慢慢的講爲啥子~
首先,這個方案是我在Issue中翻看到的:
Thanks to pytorch/issues/2830

可以這麼理解,舉例說明,雖說你之前是放在GPU3上的,數據類型叫做 cuda.Tensor(GPU 3),
但是天曉得你這個GPU3是哪臺機器上的GPU3哦,機器問了一下GPU3:是不是你家的啊,
GPU3看了一眼計算完被打掃乾淨的戰場,已經空無一物——“不是吧,我家沒人啊”,
然後就委婉的拒絕了它。

所以,我們可以對load完畢的optimizer逐個詢問,只要是個tensor,我們就再把它介紹給GPU3一次~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章