Learning both Weights and Connections for Efficient Neural Networks 論文 PyTorch 復現

網絡的定義以及初次訓練

``````##############先導入需要的包###############################
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
############## Load the MNIST handwritten-digit data ###############################
# NOTE(review): this snippet is truncated in the source — the DataLoader /
# datasets.MNIST call that the stray `]))` and the orphaned keyword arguments
# below belong to has been lost. Kept verbatim; it will not run as-is.
batch_size  = 128
transform=transforms.Compose([
transforms.ToTensor(),
]))
batch_size=batch_size,
shuffle=True,num_workers = 20)
############## Display the handwritten-digit images ###############################
# NOTE(review): the loop body apart from plt.figure() (presumably an imshow of
# one sample per class) is missing from the source.
for i in range(10):
plt.figure()
##############定義網絡Lenet-300-100網絡###############################
class Net(nn.Module):
    """LeNet-300-100: a three-layer fully connected classifier for 28x28 digits.

    The attribute names fc1/fc2/fc3 are part of the interface — the pruning and
    inspection code elsewhere accesses ``model.fc1.weight`` etc. directly.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Hidden sizes 300 and 100 follow the original LeNet-300-100 layout.
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        """Flatten the input batch and return per-class log-probabilities."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)
def train(model, device, train_loader, optimizer, epoch, log_interval=100):
    """Run one training epoch of NLL-loss SGD over `train_loader`.

    Args:
        model: network ending in log_softmax (nll_loss expects log-probs).
        device: torch device the batches are moved to.
        train_loader: iterable of (data, target) batches.
        optimizer: optimizer over model.parameters().
        epoch: epoch number, used only for logging.
        log_interval: log every this many batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # BUG FIX: the original never called zero_grad(), so gradients
        # accumulated across every batch of every epoch.
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # BUG FIX: the format string had five placeholders but only two
            # arguments were supplied, raising IndexError on the first log
            # line (and `{:0f}` was presumably meant to be `{:.0f}`).
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, batch_idx, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item()
            ))
# NOTE(review): this evaluation snippet is truncated in the source — the
# enclosing `def test(...)` header, the `for data, target in test_loader:`
# loop, a `with torch.no_grad():` guard, and the arguments of the final
# print(...).format( call are all missing. Kept verbatim; it will not run
# as-is.
model.eval()
test_loss = 0
correct = 0
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
############## Initial training ###############################
# NOTE(review): this setup block is truncated — `use_cuda` and `device` are
# never defined here, the `torch.utils.data.DataLoader(datasets.MNIST(...)`
# heads of both loader expressions are missing, and the epoch loop at the
# bottom has no body (presumably train(...)/test(...) calls). Kept verbatim.
lr = 0.01
momentum = 0.25
torch.manual_seed(53113)
batch_size = test_batch_size = 128
kwargs = {'num_workers': 40, 'pin_memory': True} if use_cuda else {}
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True, **kwargs)
datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=test_batch_size, shuffle=True, **kwargs)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epochs = 20
for epoch in range(1, epochs + 1):
``````

訓練好網絡信息的顯示

``````#####################獲取一張圖像信息#################################
# NOTE(review): `model` and `image` come from earlier notebook cells and are
# not defined here; `image` is presumably a single 28x28 digit tensor — confirm.
output = model.forward(image)
pred = output.argmax(dim=1, keepdim=True)
print(pred)  # predicted class
##################### Show the image #################################
plt.figure()
plt.imshow(image,cmap="gray")
plt.show()
##################### Image preprocessing #################################
image = image.reshape(-1,28*28)
image = image.data.numpy()
image_count = image.copy()
##################### Count the non-zero pixels in the image #################################
image_count[image_count != 0] = 1
count = np.sum(image_count)
print(count)
##################### Redo the forward pass in numpy #################################
# NOTE(review): only the weight matrices are used below — the Linear biases
# are ignored — so these activations differ slightly from model.forward.
fc1 = model.fc1.weight.data.cpu().numpy()
fc2 = model.fc2.weight.data.cpu().numpy()
fc3 = model.fc3.weight.data.cpu().numpy()
hidden1 = image.dot(fc1.T)
hidden1 = np.maximum(0, hidden1)   # ReLU activation
hidden2 = hidden1.dot(fc2.T)
hidden2 = np.maximum(0, hidden2)
hidden3 = hidden2.dot(fc3.T)
##################### Number of activated neurons in layer 1 #################################
hidden1_count = hidden1.copy()
hidden1_count[hidden1_count > 0] = 1
count = np.sum(hidden1_count)
print(count)
##################### Number of activated neurons in layer 2 #################################
hidden2_count = hidden2.copy()
hidden2_count[hidden2_count > 0] = 1
count = np.sum(hidden2_count)
print(count)
##################### Output the prediction (softmax over logits) #################################
out = np.exp(hidden3)
out = out / np.sum(out)
print(np.argmax(out))
##################### Heat-map of the first layer's |weights| (shows pruning effect) #################################
fc1_plt =  np.abs(fc1)
print("min",np.min(fc1_plt),"max",np.max(fc1_plt))
plt.figure(figsize=(50,50))
im = plt.imshow(fc1_plt, vmin = np.min(fc1_plt), vmax =  np.max(fc1_plt) ,cmap = 'seismic')
plt.show()
``````

剪枝操作

``````##########################函數定義###########################################
def expand_model(model, layers=None):
    """Concatenate the flattened weight matrices of every child layer.

    Biases are deliberately excluded — pruning only targets weights.
    NOTE(review): assumes every direct child of `model` has a `.weight`
    attribute (true for the Linear-only Net here) — confirm for other models.

    Args:
        layers: optional accumulator tensor prepended to the result; pass an
            empty CUDA tensor to keep the concatenation on the GPU.

    Returns:
        1-D tensor of all weights, preceded by `layers` if given.
    """
    if layers is None:
        # BUG FIX: the original default `layers=torch.Tensor()` was evaluated
        # once at definition time (shared object, always on the CPU); create a
        # fresh accumulator per call instead.
        layers = torch.Tensor()
    for layer in model.children():
        layers = torch.cat((layers.view(-1), layer.weight.view(-1)))
    return layers

def calculate_threshold(model, rate):
    """Return the weight magnitude at percentile `rate` over all layer weights.

    `rate` is a percentile in [0, 100]; e.g. rate=96 gives the cutoff below
    which 96% of the weights (by absolute value) fall.
    """
    seed = torch.Tensor()
    if torch.cuda.is_available():
        # Keep the concatenation on the GPU when the model lives there.
        seed = seed.cuda()
    flattened = expand_model(model, seed)  # one row holding every weight
    magnitudes = torch.abs(flattened)
    return np.percentile(magnitudes.detach().cpu().numpy(), rate)
def prune(model, threshold):
    """Zero out, in place, every fc-layer weight with magnitude <= threshold."""
    for layer in (model.fc1, model.fc2, model.fc3):
        weights = layer.weight.data
        # Boolean mask keeps only weights strictly above the threshold.
        layer.weight.data = weights * torch.gt(weights.abs(), threshold)
# NOTE(review): this driver snippet is truncated — the `for epoch ...` loop
# has no body (presumably the retraining train(...)/test(...) calls), and
# `model`, `lr`, `momentum`, `epochs` come from earlier cells. Kept verbatim.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for epoch in range(1, epochs + 1):
########################## Compute the pruning threshold as a percentile ###########################################
threshold = calculate_threshold(model, 96)
print(threshold)
########################## Prune and then measure test accuracy ###########################################
prune(model, threshold)
########################## Retrain after pruning ###########################################
``````

權重柱狀圖繪製

``````# 提取繪製的數據
def paraCount(layers, location, index):
    """Return the entries of `layers` whose `location` flag is 0.

    Args:
        layers: iterable of flattened weight values.
        location: flat flag vector aligned with the full parameter vector;
            a value of 0 at position `index + i` selects layers[i].
            NOTE(review): presumably the pruning mask from an earlier cell —
            confirm how `location` is built.
        index: offset of `layers` within the full parameter vector.

    Returns:
        np.ndarray of the selected values.
    """
    # Idiom fix: enumerate-based comprehension replaces the original manual
    # index-counter-and-append loop; behavior is unchanged.
    values = [value for i, value in enumerate(layers) if location[index + i] == 0]
    return np.array(values)
########################### Histogram of all parameters after pruning ############################
# NOTE(review): `model`, the pruning mask `location`, and the numpy weight
# matrices `fc1`/`fc2`/`fc3` are defined in earlier cells; this fragment will
# not run standalone. The offsets passed to paraCount below index each layer's
# position within the concatenated parameter vector.
layers = expand_model(model)
layers = layers.detach().cpu().numpy()
values = paraCount(layers, location, 0)
plt.figure(figsize=(20,5))
plt.hist(values, bins = 1000)
plt.show()
########################### Histogram of the first layer's parameters ############################
fc1_hist = fc1.reshape(-1)
fc1_hist = paraCount(fc1_hist, location, 0)
plt.figure(figsize=(20,5))
plt.hist(fc1_hist, bins = 300)
plt.show()
########################### Histogram of the second layer's parameters ############################
fc2_hist = fc2.reshape(-1)
fc2_hist = paraCount(fc2_hist, location, fc1.shape[0] * fc1.shape[1])
plt.figure(figsize=(20,5))
plt.hist(fc2_hist, bins = 100)
plt.show()
########################### Histogram of the third layer's parameters ############################
fc3_hist = fc3.reshape(-1)
fc3_hist = paraCount(fc3_hist,location, fc1.shape[0] * fc1.shape[1] + fc2.shape[0] * fc2.shape[1])
plt.figure(figsize=(20,5))
plt.hist(fc3_hist, bins = 100)
plt.show()
``````

實驗

fc1 169
fc2 70
fc3 結果是4

fc1 167
fc2 71
fc3 結果是4

fc1 172
fc2 73
fc3 結果是4

fc1 169
fc2 74
fc3 結果是4

fc1 167
fc2 73
fc3 結果是4

fc1 171
fc2 73
fc3 結果是4

fc1 176
fc2 74
fc3 結果是4

fc1 174
fc2 73
fc3 結果是4

fc1 181
fc2 77
fc3 結果是4

fc1 184
fc2 78
fc3 結果是4

fc1 189
fc2 84
fc3 結果是4

fc1 174
fc2 76
fc3 結果是4

fc1 169
fc2 74
fc3 結果是4

fc1 167
fc2 75
fc3 結果是4

fc1 167
fc2 77
fc3 結果是4

fc1 167
fc2 77
fc3 結果是4

剪枝99%參數

fc1 124
fc2 69
fc3 結果是4

fc1 101
fc2 66
fc3 結果是4

fc1 78
fc2 64
fc3 結果是4

fc1 25
fc2 56
fc3 結果是1

權重柱狀圖

99% 裁剪的情況下，一共使用了 2662 個參數。