六、測試網絡模型
(1) 基本概念理解
需要清楚幾個概念:準確度、精度、召回率
TP: True Positive,將正樣本預測爲正樣本的樣本數量(預測正確)
FN: False Negtive,將正樣本預測爲負樣本的樣本數量
FP: False Positive,將負樣本預測爲正樣本的樣本數量
TN: True Negtive,將負樣本預測爲正樣本的樣本數量(預測正確)
1. 準確度:準確度表示分類正確的樣本數所佔比例
ACC = ( TP + TN) / ( TP + TN + FP + FN)
2.精確度、精度:該概念是針對“預測結果”而言的。表示預測爲正類的樣本中有多少是真的正樣本
P = TP / TP + FP
3.召回率:該概念是針對“原始樣本”而言的。表示樣本中的正例有多少被分類正確了,也即一種是把原來的正類預測成正類(TP),另一種就是把原來的正類預測爲負類(FN)。
R = TP / TP + FN
(2) 測試網絡模型、計算準確度
from torch.utils.data import DataLoader
import torch
from MyData import MyDataset
import torchvision.transforms as trans
from PIL import ImageDraw
import matplotlib.pyplot as plt
def test(self):
testloader = DataLoader(dataset=self.test_dataset, batch_size=50, shuffle=True)
net = torch.load("models/net.pth")
total = 0
for x, y in testloader:
# x , y = x.cuda(), y.cuda()
category, axes = net(x)
total += (category.round() == y[:,4]).sum() # 預測值等於標籤的總數
index = category.round() == 1
"""
這裏表示有小黃人的圖片的索引集 (最後結果是True 和 False的集合)
形如:tensor([True, True, False, True, False, True, True, False, True, True])
"""
target = y[index] # 有小黃人的圖片的標籤(包括座標和分類標籤)
"""
還原有小黃人的圖片,因爲現在要可視化圖片,所以要把之前對圖片進行的歸一化和去均值操作逆向還原回去。
數據預處理的時候對其做了標準化:處理後的圖片=(原始img/255 - mean)/ std 那麼現在計算原始圖片,原始img =(處理後的圖片 * std + mean)*255
"""
x = (x[index].cpu() * MyDataset.std.reshape(-1, 3, 1, 1) + MyDataset.mean.reshape(-1, 3, 1, 1)) # 還原預測爲正樣本的數據。不用乘以255.。trans.ToPILImage("RGB"):自動會乘以255
for j, i in enumerate(axes[index]): # j 爲enumerate自動產生的索引
boxes = (i.data.cpu().numpy() * 224).astype(np.int32) # 還原預測座標並將其轉化爲無符號整型
target_box = (target[j, 0:4].data.cpu().numpy() * 224).astype(np.int32) # 還原目標座標並將其轉化爲無符號整型
img = trans.ToPILImage()(x[j]) # 轉換圖片
"""
torchvision.transforms.ToPILImage
對於一個Tensor的轉化過程是:
1. 將張量的每個元素乘上255
2. 將張量的數據類型有FloatTensor轉化成Uint8
3. 將張量轉化成numpy的ndarray類型
4. 對ndarray對象做transpose (1, 2, 0)的操作
5. 利用Image下的fromarray函數,將ndarray對象轉化成PILImage形式
6. 輸出PILImage
"""
plt.clf()
plt.axis("off")
draw = ImageDraw.Draw(img)
draw.rectangle(boxes.tolist(), outline="red") # 預測值
draw.rectangle(target_box.tolist(), outline="yellow") # 原始值
plt.imshow(img)
plt.pause(1)
# 刪除節點中的一些參數,爲了節省內存空間
del boxes, target_box, img, draw
del x, y, category, axes, index, target
print("正確率:", total/len(category.round)) # GC
(3) 計算網絡精度
"""
P = TP / TP + FP
TP: 預測爲正樣本的結果中,真正的正樣本的數量
FP: 預測爲正樣本的結果中,不是真正的正樣本的數量
1. 如何找到真正的正樣本?TP
分析: 因爲預測出來的正樣本中,既包含了真正的正樣本,也包含了假的正樣本。只有在標籤中才能準確的找到哪些是真的正樣本,哪些是負樣本。所以預測中的正樣本的下標與標籤中的正樣本的下標取交集後就可以找到預測結果中真的正樣本。
"""
# 先計算 TP +FP
TP + FP = (category.round() == 1).sum() # 預測爲正樣本的總數(包括真的正樣本和假的正樣本)
# 原始(標籤中)爲正樣本的下標
bool_index1 = y[:, 4] == 1
# 然後找出標籤中非零元素的索引, flatten()按行的方向降維,直接變成一行。
a_index = torch.nonzeros(bool_index1).flatten() # 找出了所有的1所在的位置,也就是真的正樣本的索引
# 預測爲正樣本的下標
bool_index2 = category.round() == 1 # 預測爲正樣本的索引集
# 取出預測值中非零元素的索引
b_index = torch.nonzero(bool_index2).flatten()
"""
求原始1所在的位置與預測後1所在的位置的交集。所得結果就是預測爲正樣本中,預測值中真的正樣本的位置。求len()得到的就是真的正樣本的個數.
"""
TP = np.intersectld(a_index,b_index)
print(TP)
# p = TP /(TP+FP)
# 精度
P = len(TP) / (category.round() == 1).sum()
(4) 計算網絡召回率
'''
召回率:表示樣本中的正例有多少被分類正確
R = TP / TP + FN
例如: 總共有100個樣本,80個正樣本,20個負樣本。但是預測的時候60個正樣本,將20個正樣本預測爲了負樣本。
將: 60/ (60 + 20)的值稱爲召回率
'''
# 召回率
R = len(TP) / (y[:,4] == 1).sum() # TP + FN 就表示正樣本的數量