Pytorch使用本地圖片實現mnist

一、下載圖片解壓即可

鏈接:https://pan.baidu.com/s/146E7BKbQzHO1q0r5x35Dgg 
提取碼:bczs

二、代碼實現

原始代碼路徑:https://github.com/pytorch/examples/tree/master/mnist,只要稍作修改即可,因爲圖片是RGB的,所以有2種方案

方案 1、直接使用RGB(修改如下代碼)

1. 將channel數修改爲3

self.conv1 = nn.Conv2d(3, 32, 3, 1)

2. 修改訓練和測試代碼

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('d:/mnist/train',
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('d:/mnist/test', transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

 

方案2、RGB轉換成灰階圖(修改如下代碼)

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('d:/mnist/train',
                       transform=transforms.Compose([
                           transforms.Grayscale(num_output_channels=1),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('d:/mnist/test', transform=transforms.Compose([
                           transforms.Grayscale(num_output_channels=1),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章