import numpy as np
import torch
from torch import nn
from torch. autograd import Variable
from PIL import Image
import matplotlib. pylab as plt
% matplotlib inline
卷積層
導入圖片並且轉化爲灰度圖
注意:在Kaggle環境中,所有本地數據都必須先上傳才能加載!
# Open the cat picture, then convert it to PIL mode 'L' (single-channel
# grayscale) so that the later convolution works on one channel only.
im = Image.open('../input/cat.jpg')
im = im.convert('L')
將圖片轉化爲矩陣
# Turn the PIL image into a float32 NumPy matrix.
im = np.array(im, dtype=np.float32)
查看圖片,將圖片矩陣轉化爲uint8類型,uint8是無符號八位整型,表示範圍是[0, 255]的整數,colormap爲gray
# Show the image; the matrix is cast to uint8 for display and rendered with
# the gray colormap.
plt.imshow(im.astype(np.uint8), cmap='gray')
<matplotlib.image.AxesImage at 0x7f301f697160>
# Bare expression: in a notebook cell this echoes the matrix shape.
im.shape
(121, 121)
將圖片矩陣轉化爲Tensor,並將圖片大小轉化爲卷積輸入的要求大小
# Reshape the 2-D image to 4-D by prepending two axes of size 1, then wrap it
# as a torch tensor. The shape is printed before and after.
print(im.shape)
im = torch.from_numpy(im.reshape((1, 1) + im.shape[:2]))
print(im.shape)
(121, 121)
torch.Size([1, 1, 121, 121])
使用nn.Conv2d定義卷積層。參數依次爲:輸入通道數、輸出通道數、kernel大小;bias=False表示不使用偏置項
# Single-channel in, single-channel out, 3x3 kernel, and no bias term.
conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)
定義kernel矩陣內容:3x3矩陣,中間爲8,其餘爲-1(注意:變量雖然命名爲sobel_kernel,但這組數值其實是拉普拉斯式的邊緣檢測核,而非Sobel算子)
# 3x3 edge-detection kernel: 8 at the centre, -1 for all eight neighbours,
# so the weights sum to zero. NOTE: despite the variable name, these values
# are a Laplacian-style kernel rather than the Sobel operator.
sobel_kernel = np.array(
    [
        [-1, -1, -1],
        [-1, 8, -1],
        [-1, -1, -1],
    ],
    dtype='float32',
)
print(sobel_kernel.shape)
(3, 3)
將卷積核reshape爲與卷積層權重一致的形狀:(輸出通道數, 輸入通道數, 高, 寬)
# Give the kernel the 4-D layout (1, 1, 3, 3) so it lines up with the
# 1-in / 1-out / 3x3 convolution layer's weight.
sobel_kernel = sobel_kernel.reshape(1, 1, 3, 3)
print(sobel_kernel.shape)
(1, 1, 3, 3)
將定義好的卷積核轉化爲Tensor,並賦值給卷積層權重
# Replace the layer's randomly initialised weight with the hand-made kernel.
# torch.from_numpy wraps the array without copying, so the tensor shares
# memory with sobel_kernel.
conv1.weight.data = torch.from_numpy(sobel_kernel)
使用卷積層對圖片進行卷積
# Run the image through the convolution layer. With a 3x3 kernel and no
# padding each spatial side shrinks by 2 (121 -> 119).
edge1 = conv1(Variable(im))
print(edge1.shape)
torch.Size([1, 1, 119, 119])
用squeeze函數去掉卷積輸出中大小爲1的維度,並將Tensor轉化爲numpy數組
# Drop the size-1 axes from the convolution output, report the remaining
# shape, then hand the result over to NumPy for plotting.
edge1 = edge1.data.squeeze()
print(edge1.shape)
edge1 = edge1.numpy()
torch.Size([119, 119])
將經過卷積後的圖片顯示出來
# Render the convolved (edge) image with the gray colormap.
plt.imshow(edge1, cmap='gray')
<matplotlib.image.AxesImage at 0x7f301b5afac8>
池化層
導入最大池化層,並定義池化大小
# Max-pooling layer with a 2x2 window moved in steps of 2.
pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
打印出張量格式的圖片大小
# Report the spatial size (axes 2 and 3 of the 4-D tensor) before pooling.
print('before max pool, image shape: {} x {}'.format(*im.shape[2:4]))
before max pool, image shape: 121 x 121
將Tensor轉化爲Variable,進行最大池化操作,並將結果轉化爲numpy數組
# Max-pool the image tensor, strip the size-1 axes, and convert the result
# to a NumPy array in one chain.
small_im1 = pool1(Variable(im)).data.squeeze().numpy()
打印出經過池化層後的圖片大小
# Report the pooled image's height and width.
print('after max pool, image shape: {} X {}'.format(*small_im1.shape[:2]))
after max pool, image shape: 60 X 60
打印出圖片
# Render the pooled (downsampled) image with the gray colormap.
plt.imshow(small_im1, cmap='gray')
<matplotlib.image.AxesImage at 0x7f301b594438>