神經網絡批處理 | PyTorch系列(十九)

點擊上方AI算法與圖像處理”,選擇加"星標"或“置頂”

重磅乾貨,第一時間送達

文 |AI_study

原標題:Neural Network Batch Processing - Pass Image Batch To PyTorch CNN

  • 準備數據

  • 建立模型

    • 瞭解批處理如何傳遞到網絡

  • 訓練模型

  • 分析模型的結果

在上一節中,我們瞭解了前向傳播以及如何將單個圖像從訓練集中傳遞到我們的網絡。現在,讓我們看看如何使用一批圖像來完成此操作。我們將使用數據加載器來獲取批處理,然後,在將批處理傳遞到網絡之後,我們將解釋輸出。

傳遞一個 batch的圖像到網絡  

首先,回顧一下上一節的代碼設置。我們需要以下內容:

  1. imports。

  2. 訓練集。

  3. 網絡類定義。

  4. To disable gradient tracking。(可選的)

  5. 網絡類實例。

現在,我們將使用我們的訓練集來創建一個新的DataLoader實例,並設置我們的batch_size = 10,這樣輸出將更易於管理。

> data_loader = torch.utils.data.DataLoader(
     train_set, batch_size=10
)

我們將從數據加載器中提取一個批次,並從該批次中解壓縮圖像和標籤張量。我們將使用複數形式命名變量,因爲當我們在數據加載器迭代器上調用next時,我們知道數據加載器會返回一批10張圖片。

> batch = next(iter(data_loader))
> images, labels = batch

這給了我們兩個張量,一個圖像張量和一個對應標籤的張量。

在上一節中,當我們從訓練集中提取單個圖像時,我們不得不unsqueeze() 張量以添加另一個維度,該維度將有效地將單例圖像轉換爲一個大小爲1的batch。現在我們正在使用數據加載器,默認情況下我們正在處理批處理,因此不需要進一步的處理。

數據加載器返回一批圖像,這些圖像被打包到單個張量中,該張量具有反映以下軸的形狀。

(batch size, input channels, height, width)

這意味着張量的形狀是良好的形狀,無需將其unsqueeze()。

> images.shape
torch.Size([10, 1, 28, 28])


> labels.shape
torch.Size([10])

讓我們解釋這兩種形狀。圖像張量的第一個軸告訴我們,我們有一批十張圖像。這十個圖像具有一個高度和寬度爲28的單一顏色通道。

標籤張量的單軸形狀爲10,與我們批中的十張圖像相對應。每個圖像一個標籤。

好的。通過將圖像張量傳遞到網絡來進行預測。

> preds = network(images)


> preds.shape
torch.Size([10, 10])


> preds
tensor(
    [
        [ 0.1072, -0.1255, -0.0782, -0.1073,  0.1048,  0.1142, -0.0804, -0.0087,  0.0082,  0.0180],
        [ 0.1070, -0.1233, -0.0798, -0.1060,  0.1065,  0.1163, -0.0689, -0.0142,  0.0085,  0.0134],
        [ 0.0985, -0.1287, -0.0979, -0.1001,  0.1092,  0.1129, -0.0605, -0.0248,  0.0290,  0.0066],
        [ 0.0989, -0.1295, -0.0944, -0.1054,  0.1071,  0.1146, -0.0596, -0.0249,  0.0273,  0.0059],
        [ 0.1004, -0.1273, -0.0843, -0.1127,  0.1072,  0.1183, -0.0670, -0.0162,  0.0129,  0.0101],
        [ 0.1036, -0.1245, -0.0842, -0.1047,  0.1097,  0.1176, -0.0682, -0.0126,  0.0128,  0.0147],
        [ 0.1093, -0.1292, -0.0961, -0.1006,  0.1106,  0.1096, -0.0633, -0.0163,  0.0215,  0.0046],
        [ 0.1026, -0.1204, -0.0799, -0.1060,  0.1077,  0.1207, -0.0741, -0.0124,  0.0098,  0.0202],
        [ 0.0991, -0.1275, -0.0911, -0.0980,  0.1109,  0.1134, -0.0625, -0.0391,  0.0318,  0.0104],
        [ 0.1007, -0.1212, -0.0918, -0.0962,  0.1168,  0.1105, -0.0719, -0.0265,  0.0207,  0.0157]
    ]
)

預測張量的形狀爲10 x 10,這給了我們兩個長度爲10的軸。這反映了以下事實:我們有十個圖像,並且對於這十個圖像中的每一個,我們都有十個預測類別。

(batch size, number of prediction classes)

第一維的元素是長度爲十的數組。這些數組元素中的每一個包含對應圖像每個類別的十個預測。

第二維的元素是數字。每個數字都是特定輸出類別的分配值。輸出類別由索引編碼,因此每個索引代表一個特定的輸出類別。該映射由該表給出。

Fashion MNIST 類

Argmax的使用:預測與標籤

爲了對照標籤檢查預測,我們使用argmax() 函數找出哪個索引包含最高的預測值。一旦知道哪個索引具有最高的預測值,就可以將索引與標籤進行比較,以查看是否存在匹配項。

爲此,我們在預測張量上調用argmax() 函數,並指定第二維。

第二個維度是我們的預測張量的最後一個維度。請記住,在我們所有關於張量的工作中,張量的最後一個維度始終包含數字,而其他所有維度都包含其他較小的張量。

在預測張量的情況下,我們有十組數字。argmax() 函數的作用是查看這十組中的每組,找到最大值,然後輸出其索引。

對於每組十個數字:

  1. 查找最大值。

  2. 輸出指標

對此的解釋是,對於批次中的每個圖像,我們正在找到具有最高值的預測類別(每列的最大值)。這是網絡預測的類別。

> preds.argmax(dim=1)
tensor([5, 5, 5, 5, 5, 5, 4, 5, 5, 4])


> labels
tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

argmax() 函數的結果是十個預測類別的張量。每個數字是出現最大值的索引。我們有十個數字,因爲有十個圖像。一旦有了這個具有最大值的索引張量,就可以將其與標籤張量進行比較。

> preds.argmax(dim=1).eq(labels)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0], dtype=torch.uint8)


> preds.argmax(dim=1).eq(labels).sum()
tensor(1)

爲了實現比較,我們使用eq() 函數。eq() 函數計算argmax輸出和標籤張量之間的逐元素相等運算。

如果argmax輸出中的預測類別與標籤匹配,則爲1,否則爲0。

最後,如果在此結果上調用sum() 函數,則可以將輸出縮減爲該標量值張量內的單個正確預測數。

我們可以將最後一個調用包裝到名爲get_num_correct() 的函數中,該函數接受預測和標籤,並使用item()方法返回Python數目的正確預測。

def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

調用此函數,我們可以看到我們得到了值1。

> get_num_correct(preds, labels)
1

總結

現在,我們應該對如何將一批輸入傳遞到網絡以及在處理卷積神經網絡時預期的形狀有一個很好的瞭解。

文章中內容都是經過仔細研究的,本人水平有限,翻譯無法做到完美,但是真的是費了很大功夫,希望小夥伴能動動你性感的小手,分享朋友圈或點個“在看”,支持一下我 ^_^

英文原文鏈接是:

https://deeplizard.com/learn/video/p1xZ2yWU1eo

加羣交流

歡迎小夥伴加羣交流,目前已有交流羣的方向包括:AI學習交流羣,目標檢測,秋招互助,資料下載等等;加羣可掃描並回復感興趣方向即可(註明:地區+學校/企業+研究方向+暱稱)

 謝謝你看到這裏! ????

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章