I. Download the images and extract them
Link: https://pan.baidu.com/s/146E7BKbQzHO1q0r5x35Dgg
Extraction code: bczs
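The loaders in Part II use datasets.ImageFolder. It expects one subfolder per class under each root, for example d:/mnist/train/0/ and d:/mnist/train/1/, and it assigns class indices by sorting the folder names. The sketch below checks an extracted folder against that layout before training. The folder names, the image extensions in IMAGE_EXTS, and the helper count_images_per_class are assumptions introduced here. The download page does not state them.

```python
from pathlib import Path

# Image extensions assumed for the extracted archive (hypothetical; adjust as needed).
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}


def count_images_per_class(root):
    """Return {class_folder_name: number_of_images} for an ImageFolder-style root."""
    counts = {}
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts


if __name__ == "__main__":
    root = Path("d:/mnist/train")  # hypothetical path, same as in the loaders below
    if root.is_dir():
        for name, n in count_images_per_class(root).items():
            print(name, n)
```

If every digit folder reports a non-zero count, ImageFolder should find ten classes.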
II. Code implementation
Original code: https://github.com/pytorch/examples/tree/master/mnist. Only a few changes are needed. The images are RGB, so there are two options:
Option 1: Use the RGB images directly (modify the following code)
1. Change the number of input channels of the first convolution to 3
self.conv1 = nn.Conv2d(3, 32, 3, 1)  # in Net.__init__: 3 input channels (RGB) instead of 1
2. Change the training and test data loaders
# In main(): load the images from class folders; ToTensor() keeps all 3 RGB channels
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('d:/mnist/train',
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             # MNIST mean/std, written out for each of the 3 channels
                             transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))
                         ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('d:/mnist/test',
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))
                         ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
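Option 2: Convert the RGB images to grayscale (modify the following code; the network keeps its original single input channel, nn.Conv2d(1, 32, 3, 1))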
# In main(): convert every image to 1-channel grayscale before ToTensor()
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('d:/mnist/train',
                         transform=transforms.Compose([
                             transforms.Grayscale(num_output_channels=1),
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307,), (0.3081,))
                         ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('d:/mnist/test',
                         transform=transforms.Compose([
                             transforms.Grayscale(num_output_channels=1),
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307,), (0.3081,))
                         ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)