Libtorch c++ 搭建全連接網絡識別MINST手寫數字

這是個完整的例子,用全連接網絡方法識別手寫數字,分爲三部分,(1)搭建網絡,(2)讀取MNIST數據,(3)優化器設置,(4)訓練網絡。

1、網絡搭建部分

用struct的方式建立自定義網絡Net,它繼承自torch::nn::Module,實現了forward函數,

該網絡中註冊的內置網絡模塊是三個線性網絡,fc1,fc2,fc3

神經元的個數,fc1爲(784,64),fc2(64,32),fc3(32,10),fc1的輸入層個數爲784,是因爲MNIST圖像的像素爲28*28=784,網絡的最後一層是log_softmax,輸入是fc3的輸出,fc3的輸出神經元個數爲10。它的輸出神經元個數與輸入個數一致,還是10,log_softmax()層的作用是接受一個實數向量計算概率分佈然後取對數。整體的網絡結構如下:

MNIST picture->tensor->fc1(784, 64)->relu()->dropout()->fc2(64,32)->relu()->fc3(32, 10)->log_softmax()->prediction

可以知道輸入是長度爲784的向量,輸出是長度爲10的概率分佈的對數。

2、數據讀取部分

與pytorch類似,libtorch也要求必須用make_data_loader這種多線形成數據加載的方式加載數據。這個沒有新建數據集,而是採用了libtorch自帶的MNIST數據集(libtorch通過torch::data::datasets封裝了一些常見的數據集,方便使用),直接加載即可。然後通過map映射的方式把數據形成批量的,數據轉化方式是棧(stack)的方式

3、優化器設置

優化器設置比較簡單,選擇的是隨機提取下降,固定學習率。

4、訓練部分

通過for訓練的方式進行迭代訓練,每個循環內,重置梯度爲0,然後計算損失函數loss,通過loss的反向計算backward計算梯度,通過優化器optimizer.step()調正網絡的權重,然後執行下次訓練,每個一定次數保存當前的網絡。

其中損失函數選擇的是nll_loss,它輸入要求是一個對數概率向量和一個目標標籤. 它不會爲我們計算對數概率. 適合網絡的最後一層是log_softmax.

保存當前網絡的方法是,torch::save(net, "net.pt")

#include <iostream>
#include <torch/script.h>
#include <torch/csrc/api/include/torch/torch.h>

// Define a new Module.
struct Net : torch::nn::Module {
    Net() {
        // Construct and register two Linear submodules.
        fc1 = register_module("fc1", torch::nn::Linear(784, 64));
        fc2 = register_module("fc2", torch::nn::Linear(64, 32));
        fc3 = register_module("fc3", torch::nn::Linear(32, 10));
    }

    // Implement the Net's algorithm.
    torch::Tensor forward(torch::Tensor x) {
        // Use one of many tensor manipulation functions.
        x = torch::relu(fc1->forward(x.reshape({ x.size(0), 784 })));
        x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
        x = torch::relu(fc2->forward(x));
        x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
        return x;
    }

    // Use one of many "standard library" modules.
    torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };
};

int main()
{
    std::cout << "Hello World!\n";
    //torch::Tensor tensor = torch::eye(3);
    torch::Tensor tensor = torch::rand({ 2,3 });
    std::cout << tensor << std::endl;
    torch::save(tensor, "tensor.pt");

    const static int WIDTH = 512, HEIGHT = 512;

    // Create a new Net.
    auto net = std::make_shared<Net>();

    // Create a multi-threaded data loader for the MNIST dataset.
    auto data_loader = torch::data::make_data_loader(
        torch::data::datasets::MNIST("../../pytorchCpp/data/mnist/MNIST/raw").map(
            torch::data::transforms::Stack<>()),
        /*batch_size=*/64);

    // Instantiate an SGD optimization algorithm to update our Net's parameters.
    torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
    
    std::vector<double> lossVec;
    for (size_t epoch = 1; epoch <= 10; ++epoch) {
        size_t batch_index = 0;
        // Iterate the data loader to yield batches from the dataset.
        for (auto& batch : *data_loader) {
            // Reset gradients.
            optimizer.zero_grad();
            // Execute the model on the input data.
            torch::Tensor prediction = net->forward(batch.data);
            // Compute a loss value to judge the prediction of our model.
            torch::Tensor loss = torch::nll_loss(prediction, batch.target);
            // Compute gradients of the loss w.r.t. the parameters of our model.
            loss.backward();
            // Update the parameters based on the calculated gradients.
            optimizer.step();
            // Output the loss and checkpoint every 100 batches.
            if (++batch_index % 100 == 0) {
                std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                    << " | Loss: " << loss.item<float>() << std::endl;
                lossVec.push_back(loss.item<double>());
                // Serialize your model periodically as a checkpoint.
                torch::save(net, "net.pt");
            }
            
        }
    }
  
        
    
}

 

 損失函數

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章