分類網絡預測結果代碼解析

分類網絡到底預測了什麼

以前沒搞懂分類網絡中,一張圖片經過了神經網絡後怎麼就變成圖片類別的,現在研究出了一點自己的體會,分享給大家,純屬原創。

分類網絡基本套路代碼

outputs = net(inputs)
        loss = criterion(outputs, targets)
        # loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph.
        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct+=predicted.eq(targets.data).cpu().sum()
        

這裏的net是分類網絡,有很多種,當我們把一張圖片歸一化後輸入到網絡中,這張圖片就開始被各鍾各樣的卷積核給卷積,卷積的操作我就不在這解釋了。
torch.Size([100, 3, 32, 32])
torch.Size([100, 128])
torch.Size([100, 10])
tensor(-2.2718, device=‘cuda:0’, grad_fn=)
原本輸入了100張圖片的數據,每張圖片都是彩色(3表示彩色)每張圖片都是32×32的大小,輸入到網絡中經過卷積層卷積,變成了一個個的矩陣,矩陣維度是四維。
再經過連接層後變成了二維矩陣,也就是上面的最後一個torch size,100是100張圖片,10是連接層中最後一層的輸出通道,這個輸出通道自己根據要判斷的圖片類別而設置,每個通道能經過加權求和得到一個數字,這裏也就是得到10個數字。
到此,原本的一張彩色圖片,經過網絡後變成了10個數字(設置圖片類別數是10),這10個數字代表對應類別的概率,比如一張圖片它是0類別(用0到9的數字分別表示10類)那麼在10個通道中的第一個通道就是判斷爲0類別的概率。
我們只需用torch.max這個代碼就能得到一張圖片輸出後的10個數字中,哪個數字最大,那個最大的數字代表的通道就是預測的類別,比如第一個通道的數字最大,那麼就對應0類別。

到此就實現了網絡預測圖片是哪一類別了,當然準不準要另說。

學生黨一枚,如有不對的地方請給我留言,謝謝!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章