1. Data loading with multiple workers
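Introductory PyTorch tutorials often use the CIFAR10 dataset and load it through DataLoader. On Windows, setting num_workers greater than 0 raises an error at run time if the script is not protected, whereas the same code runs without error on Linux. The reason is that each worker is a separate process, and Windows starts it by launching a new interpreter that re-imports the main script: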
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
The best approach is to wrap all operations in functions and then call them inside an if __name__ == '__main__': clause:
# Imports for dataset generation, training, etc.
def load_datasets():
    # Code to load the datasets with multiple workers
    ...

def train():
    # Code to train the model
    ...

if __name__ == '__main__':
    load_datasets()
    train()
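To see why the guard matters, the following is a minimal sketch that reproduces the same situation with Python's standard multiprocessing module, which the DataLoader workers are built on. It does not use PyTorch or CIFAR10; the names square and run_workers are hypothetical and exist only for this illustration. The "spawn" start method is requested explicitly so that the Windows behavior can also be observed on Linux.

```python
import multiprocessing as mp


def square(x):
    # Work carried out inside each worker process (hypothetical example task).
    return x * x


def run_workers(values, start_method="spawn"):
    # "spawn" is the start method used on Windows: every worker launches a
    # fresh interpreter and re-imports this module before calling square().
    ctx = mp.get_context(start_method)
    with ctx.Pool(processes=2) as pool:
        return pool.map(square, values)


if __name__ == "__main__":
    # This block runs only in the parent process. Without the guard, each
    # spawned worker would reach this call again while importing the module,
    # and Python would raise a RuntimeError.
    print(run_workers([1, 2, 3]))  # prints [1, 4, 9]
```

Moving the run_workers call out of the guarded block and running the file again should reproduce the kind of error described above, even on Linux, because the start method is forced to "spawn".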