The official PyTorch site has a simple example:
https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
In practice it is fairly simple to use. Roughly:
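import torch
from torch.nn import DataParallel
# model is an existing nn.Module; move it to the default GPU first
model = model.cuda()
# Replicate it across every visible GPU (device_ids defaults to all of them anyway)
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
# The wrapper hides the original model's custom attributes, so model.XXX raises
# AttributeError: 'DataParallel' object has no attribute 'XXX'; go through .module instead
model.module.XXX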
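Here is a minimal self-contained sketch of the same pattern. `Net`, its `num_classes` attribute and the `unwrap` helper are hypothetical names used only for illustration. The sketch falls back to the CPU when no GPU is present, in which case `DataParallel` simply calls the wrapped module. It shows that a custom attribute can only be reached through `.module` once the model is wrapped.

```python
import torch
import torch.nn as nn
from torch.nn import DataParallel


class Net(nn.Module):
    """Hypothetical model carrying a plain (non-parameter) Python attribute."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.num_classes = num_classes  # custom attribute, not a parameter/buffer/submodule
        self.fc = nn.Linear(4, num_classes)

    def forward(self, x):
        return self.fc(x)


def unwrap(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the inner module if wrapped by DataParallel."""
    return model.module if isinstance(model, DataParallel) else model


device = "cuda" if torch.cuda.is_available() else "cpu"
# Parameters must live on the first device before wrapping
model = DataParallel(Net().to(device))

x = torch.randn(8, 4, device=device)
out = model(x)  # batch is split across GPUs and gathered back (or passed through on CPU)
print(out.shape)  # torch.Size([8, 3])

try:
    model.num_classes  # the wrapper itself has no such attribute
except AttributeError as err:
    print("AttributeError:", err)

print(model.module.num_classes)  # 3
print(unwrap(model).num_classes)  # 3
```

A helper like `unwrap` lets the same code run whether or not the model is wrapped, instead of scattering `model.module` checks through the training code.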