PyTorch 1.0.0 fixes: RuntimeError: arguments are located on different GPUs at /pytorch/aten/src/THC/generic/THCTensorIndex.cu:519
Error message
RuntimeError: arguments are located on different GPUs at /pytorch/aten/src/THC/generic/THCTensorIndex.cu:519
Using multiple GPUs in PyTorch 1.0.0
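# Absolute-error loss function test
import argparse
import os
import re
import torch
import torch.nn as nn

# Accept arguments such as: python filename.py --gpu_id=0,1,2,3,4,5,6,7
parser = argparse.ArgumentParser(description='manual')
parser.add_argument('--gpu_id', help="GPU_ID", type=str, default="0,1,2,3,4")
# --run_type is read in the main block below, so it has to be declared too
parser.add_argument('--run_type', help="train or predict", type=str, default="train")
args = parser.parse_args()

USE_CUDA = torch.cuda.is_available()
gpu_id = list(map(int, re.split(r",", args.gpu_id)))
# DataParallel keeps the parameters on device_ids[0], so the main device is the first listed GPU
device = torch.device("cuda:" + str(gpu_id[0]) if USE_CUDA else "cpu")
# Important: every .cuda() call must be replaced with .to(device). A bare .cuda() puts the
# tensor on the current CUDA device (cuda:0 unless changed), which is not necessarily
# gpu_id[0], and the program then fails at run time with:
# RuntimeError: arguments are located on different GPUs at /pytorch/aten/src/THC/generic/THCTensorIndex.cu:519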
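class yourmodel(nn.Module):
    def __init__(self,
                 args1=None,   # placeholders: replace with your own arguments and defaults
                 argsn=None):
        super(yourmodel, self).__init__()
        ...

    def forward(self, *inputs):
        # forward() is the only method DataParallel splits across the GPUs
        ...

    # Custom methods. Note that this train() shadows nn.Module.train(),
    # so model.eval() (which calls train(False)) will no longer work.
    def train(self):
        ...

    def predict(self):
        ...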
if __name__ == '__main__':
    model = yourmodel()
    model = torch.nn.DataParallel(model, device_ids=gpu_id)
    model.to(device)  # parameters must live on device_ids[0]
    if args.run_type == 'train':
        # model.train()         # single GPU
        model.module.train()    # wrapped in DataParallel: add .module
    elif args.run_type == 'predict':
        # model.predict()       # single GPU
        model.module.predict()  # wrapped in DataParallel: add .module
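The device choice above depends on parsing --gpu_id correctly: with --gpu_id=2,3 the main device must be cuda:2, which is exactly the case where a bare .cuda() (going to cuda:0) breaks. A minimal, torch-free sketch of that logic follows; parse_gpu_ids and select_device are hypothetical helper names, and returning an empty id list on CPU is an assumption of this sketch, not something the script above does.

```python
import re


def parse_gpu_ids(gpu_id_str):
    """Turn a string like "0,1,2" into [0, 1, 2], skipping empty entries."""
    return [int(s) for s in re.split(r",", gpu_id_str) if s.strip()]


def select_device(gpu_id_str, cuda_available):
    """Return (device_string, device_ids) for a DataParallel setup.

    DataParallel keeps the master copy of the parameters on device_ids[0],
    so the model and its inputs have to be moved to that GPU.
    """
    ids = parse_gpu_ids(gpu_id_str)
    if not cuda_available:
        return "cpu", []
    return "cuda:%d" % ids[0], ids


if __name__ == "__main__":
    print(select_device("2,3", cuda_available=True))        # ('cuda:2', [2, 3])
    print(select_device("0,1,2,3,4", cuda_available=False))  # ('cpu', [])
```

The returned string can be passed straight to torch.device(...) and the list to device_ids.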
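If you still get RuntimeError: arguments are located on different GPUs, check the following.
Possible cause 1:
A submodule or custom method of the network is used outside the network without going through .module. After wrapping, the original network is stored in model.module, and the DataParallel wrapper does not expose the network's own attributes and methods such as train() and predict(). For example: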
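if __name__ == '__main__':
    model = yourmodel()
    model = torch.nn.DataParallel(model, device_ids=gpu_id)
    model.to(device)
    if args.run_type == 'train':
        model.train()    # single-GPU style: calls nn.Module.train() on the wrapper, not yourmodel.train()
    elif args.run_type == 'predict':
        model.predict()  # single-GPU style: the DataParallel wrapper has no predict()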
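The correct approach is to go through .module. Only calls to the wrapper itself (model(inputs)) are split across the GPUs; code reached through .module runs on the original network on device_ids[0].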
if __name__ == '__main__':
    model = yourmodel()
    model = torch.nn.DataParallel(model, device_ids=gpu_id)
    model.to(device)
    if args.run_type == 'train':
        model.module.train()    # multi-GPU (DataParallel-wrapped) model
    elif args.run_type == 'predict':
        model.module.predict()  # multi-GPU (DataParallel-wrapped) model
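A common way to avoid forgetting .module is to unwrap the model in one place. The sketch below assumes a hypothetical TinyModel with a custom predict() and a hypothetical unwrap() helper; nn.DataParallel and nn.parallel.DistributedDataParallel are real PyTorch classes. On a CPU-only machine nn.DataParallel simply stores the module, so the example runs without GPUs.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for yourmodel with one custom method."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

    def predict(self, x):
        # Custom method: the DataParallel wrapper does not expose it
        return self.forward(x).argmax(dim=1)


def unwrap(model):
    """Return the wrapped network if model is a (Distributed)DataParallel wrapper."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


if __name__ == "__main__":
    wrapped = nn.DataParallel(TinyModel())
    print(hasattr(wrapped, "predict"))                      # False
    print(unwrap(wrapped).predict(torch.randn(3, 4)).shape)  # torch.Size([3])
```

Because unwrap() returns plain modules unchanged, the same call works whether or not the model was wrapped.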
Possible cause 2:
Somewhere a tensor is still moved with .cuda() instead of .to(device), for example:
if self.use_cuda:
    decoder_input = decoder_input.cuda()
    decoder_context = decoder_context.cuda()
The correct approach:
if self.use_cuda:
    decoder_input = decoder_input.to(device)
    decoder_context = decoder_context.to(device)
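One caveat that goes beyond the original post: the global device is right for code that runs outside forward(), such as the train() and predict() methods reached through .module. Inside forward(), however, DataParallel runs one replica per GPU, so a tensor sent to the global device would land on device_ids[0] for every replica. A safer variant there is to take the device from the incoming tensor. TinyDecoder below is a hypothetical example, not the author's decoder.

```python
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Hypothetical decoder step that creates decoder_context inside forward()."""

    def __init__(self, hidden_size=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.GRUCell(hidden_size, hidden_size)

    def forward(self, decoder_input):
        # Each DataParallel replica runs on its own GPU, so follow the input's
        # device (and dtype) instead of using a global `device`.
        decoder_context = torch.zeros(
            decoder_input.size(0), self.hidden_size,
            device=decoder_input.device, dtype=decoder_input.dtype,
        )
        return self.rnn(decoder_input, decoder_context)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    decoder = TinyDecoder().to(device)
    out = decoder(torch.randn(5, 8, device=device))
    print(out.shape, out.device)
```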