The official PyTorch site has a simple example:
https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
It is actually fairly simple to use. Roughly:
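import torch
from torch.nn import DataParallel
model = model.cuda()  # parameters must live on device_ids[0] before wrapping
# Wrap the model; each input batch is split across all visible GPUs
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))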
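# After wrapping, custom attributes and methods of the original model
# are no longer reachable directly on the wrapper:
# AttributeError: 'DataParallel' object has no attribute XXX
# Access them through .module instead:
model.module.XXX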
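
To make the AttributeError above concrete, here is a minimal sketch. TinyNet, its num_classes attribute, its describe() method, and the unwrap() helper are all hypothetical names made up for illustration; only DataParallel and its .module attribute come from PyTorch. The sketch should also run on a machine without a GPU, where DataParallel simply calls the wrapped module.

```python
import torch
from torch import nn
from torch.nn import DataParallel


class TinyNet(nn.Module):
    """Hypothetical model with one custom attribute and one custom method."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.num_classes = 2  # custom attribute, not a parameter or submodule

    def forward(self, x):
        return self.fc(x)

    def describe(self):
        return f"TinyNet with {self.num_classes} classes"


def unwrap(model):
    """Return the underlying module if `model` is wrapped in DataParallel."""
    return model.module if isinstance(model, DataParallel) else model


model = TinyNet()
if torch.cuda.is_available():
    model = model.cuda()
wrapped = DataParallel(model)  # device_ids defaults to all visible GPUs

# Direct access to a custom attribute fails on the wrapper...
try:
    wrapped.num_classes
    direct_access_failed = False
except AttributeError:
    direct_access_failed = True

# ...but works through .module (or through the unwrap helper).
num_classes = wrapped.module.num_classes
description = unwrap(wrapped).describe()
```

A helper like unwrap() lets the same code handle both wrapped and unwrapped models, which is convenient when a script has to run on single-GPU and multi-GPU machines alike.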