Building and Training a DCGAN with LibTorch C++ to Generate MNIST Handwritten Digits

Contents

What Is a Generative Adversarial Network (GAN)?

The Generator Module

The Discriminator Module

Dataset Definition

Data Loading

Inspecting the Data

Defining the Optimizers

Training the Networks

Periodic Model Checkpointing

Full Source Code


Our goal is to generate images from the MNIST dataset, and we will use a generative adversarial network (GAN) to do it. Specifically, we will use the DCGAN architecture, one of the earliest and simplest GANs, but entirely sufficient for this task.

What Is a Generative Adversarial Network (GAN)?

A GAN consists of two distinct neural network models: a generator and a discriminator. The generator receives samples from a noise distribution, and its goal is to transform each noise sample into an image that resembles the target distribution, in our case the MNIST dataset. The discriminator, in turn, receives either real images from the MNIST dataset or fake images from the generator, and is asked to emit a probability that a given image is real (closer to 1) or fake (closer to 0). The discriminator's feedback on the realness of the generator's images is used to further train the generator, while its feedback on real versus fake images is used to optimize the discriminator itself. In theory, a delicate balance between generator and discriminator makes them improve in tandem, until the generator produces images indistinguishable from the target distribution, fooling the discriminator's (by then excellent) eye into emitting a probability of 0.5 for real and fake images alike. For us, the end result is a machine that takes noise as input and produces realistic images of digits as output.

The Generator Module

The generator module consists of a series of 2D transposed convolutions, batch normalizations, and ReLU activation units. The forward method passes the input through these modules in sequence.

The generator's job is to take a vector of random numbers and produce a grayscale image. The input has kNoiseSize = 100 channels at spatial size 1*1; the first layer outputs 256 channels at size 4*4, and the subsequent layers produce feature maps of size 7*7, 14*14, and finally 28*28, which is exactly the image size of the MNIST dataset. The output size of a 2D transposed convolution is given by:

$$H_{out}=(H_{in}-1)\times stride[0]-2\times padding[0]+kernel\_size[0]+output\_padding[0]$$

Applying this formula to each layer in the code below (output_padding is 0 throughout), the size progression is as follows:
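  • conv1: (1-1)*1 - 2*0 + 4 = 4, giving 1*1 → 4*4
  • conv2: (4-1)*2 - 2*1 + 3 = 7, giving 4*4 → 7*7
  • conv3: (7-1)*2 - 2*1 + 4 = 14, giving 7*7 → 14*14
  • conv4: (14-1)*2 - 2*1 + 4 = 28, giving 14*14 → 28*28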

The custom network inherits from nn::Module. A constructor initializer list is used here; one benefit is that no elaborate constructor body is needed, and another is that an initializer list avoids an extra call to each member's default constructor, which is efficient for data-heavy classes. In addition, since C++ has no reflection, every layer must be registered manually via the register_module() function. Finally, after defining the network, the custom module itself must also be registered with a macro: TORCH_MODULE(DCGANGenerator);

struct DCGANGeneratorImpl : nn::Module {
    DCGANGeneratorImpl(int kNoiseSize)
        : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
            .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
            .stride(2)
            .padding(1)
            .bias(false))
    {
        // register_module() is needed if we want to use the parameters() method later on
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("conv4", conv4);
        register_module("batch_norm1", batch_norm1);
        register_module("batch_norm2", batch_norm2);
        register_module("batch_norm3", batch_norm3);
    }

    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(batch_norm1(conv1(x)));
        x = torch::relu(batch_norm2(conv2(x)));
        x = torch::relu(batch_norm3(conv3(x)));
        x = torch::tanh(conv4(x));
        return x;
    }

    nn::ConvTranspose2d conv1, conv2, conv3, conv4;
    nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);
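As a quick sanity check (a minimal sketch, not part of the original program), we can push a batch of noise through the generator and confirm the 28*28 output size derived above:

DCGANGenerator generator(100);                        // kNoiseSize = 100
torch::Tensor noise = torch::randn({64, 100, 1, 1});  // a batch of 64 noise vectors
std::cout << generator->forward(noise).sizes() << std::endl;  // prints [64, 1, 28, 28]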

The Discriminator Module

The discriminator uses a similar sequence of convolutions, batch normalizations, and activations. However, the convolutions are now regular rather than transposed, and we use a leaky ReLU with an alpha of 0.2 instead of a vanilla ReLU. The discriminator receives either real images or fake images from the generator; after a series of convolutions and batch normalizations, the final activation is a Sigmoid, which squashes values into the range 0 to 1. We can then interpret these squashed values as the probability the discriminator assigns to an image being real.

Through successive convolutions, the discriminator reduces the spatial size from 28*28 down to 1. The output size of a convolution is given by:

$$L_{out}=\left\lfloor\frac{L_{in}+2\times padding-dilation\times(kernel\_size-1)-1}{stride}+1\right\rfloor$$

Plugging in each layer's parameters (dilation = 1 throughout), the concrete progression is:
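  • Layer 1: floor((28 + 2*1 - 3 - 1)/2 + 1) = 14, giving 28*28 → 14*14
  • Layer 2: floor((14 + 2*1 - 3 - 1)/2 + 1) = 7, giving 14*14 → 7*7
  • Layer 3: floor((7 + 2*1 - 3 - 1)/2 + 1) = 3, giving 7*7 → 3*3
  • Layer 4: floor((3 + 2*0 - 2 - 1)/1 + 1) = 1, giving 3*3 → 1*1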

To build the discriminator, we will try something different: a Sequential module. As in Python, LibTorch provides two APIs for model definition: a functional API where input is passed through successive functions (as in the generator module example), and a more object-oriented API where we build a Sequential module containing the entire model as submodules. Sequential simply performs function composition: the output of each submodule becomes the input of the next, so there is no need to define a custom forward function. The nn::Sequential constructor takes the list of nn modules as its arguments.

nn::Sequential discriminator(
    // Layer 1
    nn::Conv2d(
        nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 2
    nn::Conv2d(
        nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(128),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 3
    nn::Conv2d(
        nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(256),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 4
    nn::Conv2d(
        nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
    nn::Sigmoid());
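A similar sanity check (again a sketch, assuming a batch of 64 MNIST-sized inputs) shows that the discriminator collapses each image to a single probability:

torch::Tensor probe = torch::randn({64, 1, 28, 28});
std::cout << discriminator->forward(probe).sizes() << std::endl;  // prints [64, 1, 1, 1]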

Dataset Definition

Having defined the generator and discriminator modules, we need data we can train on. The C++ frontend, like the Python one, comes with a powerful parallel data loader, which can read batches of data from a dataset and provides many configuration options.

Whereas the Python data loader uses multiprocessing, the C++ data loader is truly multithreaded and does not launch any new processes.

The data loader is part of the C++ frontend's data API, contained in the torch::data:: namespace. This API consists of a few different components:

  • The data loader class
  • An interface for defining datasets
  • An interface for defining transforms, which can be applied to datasets
  • An interface for defining samplers, which produce the indices with which datasets are indexed
  • A library of existing (built-in) datasets, transforms, and samplers

For this example we can use the built-in MNIST dataset that the C++ frontend provides, torch::data::datasets::MNIST, and apply two transforms. First, we normalize the images so that their values fall between -1 and +1 instead of the original range of 0 to 1 (each pixel x is mapped to (x - 0.5) / 0.5). Second, we apply the Stack collation, which takes a batch of tensors and stacks them into a single tensor along the first dimension.

The dataset's size() function returns the number of samples in the dataset; dividing it by the batch size kBatchSize gives the number of batches per epoch.

// Load the dataset.
auto dataset = torch::data::datasets::MNIST("../../pytorchCpp/data/mnist/MNIST/raw")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());
const int64_t batches_per_epoch =
    std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

Data Loading

The next step is to create a data loader and hand it the dataset. The data loader is created with torch::data::make_data_loader, which returns a std::unique_ptr of the correct type (which depends on the type of the dataset, the type of the sampler, and some other implementation details).

The data loader comes with many options; the full set can be inspected in the DataLoaderOptions documentation. For example, to speed up data loading we can increase the number of workers. The default is zero, which means the main thread will be used. If we set workers to 2, two threads are spawned that load data concurrently. We should also raise the batch size from its default of 1 to something more reasonable, like 64 (the value of kBatchSize). So let's create a DataLoaderOptions object and set the appropriate properties:

// Define the data loader.
auto data_loader = torch::data::make_data_loader(
    std::move(dataset),
    torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

Inspecting the Data

The data loader yields values of type torch::data::Example. This type is a simple struct with a data field for the data and a target field for the labels. Because we applied the Stack collation earlier, the data loader returns a single such example per batch; without Stack, it would instead return std::vector<torch::data::Example<>>, with one element per sample in the batch.

// Print labels to check the data.
for (torch::data::Example<>& batch : *data_loader) {
    std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
    for (int64_t i = 0; i < batch.data.size(0); ++i) {
        std::cout << batch.target[i].item<int64_t>() << " ";
    }
    std::cout << std::endl;
}

Here is the inspection output with kBatchSize = 64:

Batch size: 64 | Labels: 8 7 5 9 1 0 5 9 5 1 7 9 5 7 1 0 6 7 5 2 8 2 2 7 0 2 4 1 8 7 8 7 5 0 2 0 2 7 7 6 5 8 5 8 5 1 6 1 1 0 9 8 7 4 0 5 4 9 8 9 0 3 9 2
Batch size: 64 | Labels: 5 2 6 7 5 3 7 4 2 1 5 3 2 6 2 1 7 6 4 4 7 4 9 7 6 5 4 7 9 2 2 1 7 4 0 8 6 0 5 1 2 8 9 5 9 9 6 9 7 8 1 2 1 0 3 2 3 9 2 5 5 2 8 7
Batch size: 64 | Labels: 9 3 9 5 6 7 6 6 8 2 3 9 8 8 0 1 9 2 2 8 4 0 1 2 7 9 6 8 9 5 6 6 9 4 3 7 8 5 2 9 3 0 7 5 2 8 2 9 7 5 4 3 2 1 9 8 7 2 7 2 0 8 3 3
Batch size: 64 | Labels: 4 7 0 1 3 6 4 8 0 3 2 2 4 8 8 4 8 8 6 5 6 5 7 8 1 9 2 5 3 2 8 5 8 0 6 9 5 7 9 8 5 2 4 4 6 8 2 0 5 0 0 4 3 5 0 9 0 3 2 8 8 1 1 6
Batch size: 64 | Labels: 4 1 0 3 9 6 2 1 1 5 3 4 3 7 7 7 4 4 7 4 5 3 4 1 0 7 8 1 6 0 6 8 8 4 1 8 4 0 3 3 1 9 7 5 6 2 4 1 3 8 9 4 7 1 0 8 6 8 9 8 5 7 2 5
Batch size: 64 | Labels: 0 6 3 9 9 6 0 9 9 3 0 0 0 5 9 0 9 6 9 8 1 8 7 5 1 0 1 1 6 8 4 7 6 2 8 1 8 6 7 8 5 8 9 6 1 2 9 3 8 2 0 8 4 7 6 9 6 1 1 1 4 2 8 8
Batch size: 64 | Labels: 3 0 8 9 3 5 4 9 6 3 2 3 3 9 7 9 6 0 7 2 7 8 2 4 8 7 9 3 4 7 9 0 5 6 3 8 1 1 3 9 9 1 6 3 7 3 1 7 0 1 5 6 2 1 2 1 7 8 7 9 6 2 7 7
Batch size: 64 | Labels: 7 8 7 3 1 7 7 4 9 1 4 6 7 6 4 2 0 8 1 0 5 5 8 4 1 1 8 9 5 3 1 7 4 1 2 8 1 7 8 5 7 4 0 3 8 3 8 3 6 3 7 0 4 2 1 1 8 2 8 5 7 6 5 0
Batch size: 64 | Labels: 2 9 8 6 9 1 4 5 8 9 0 2 5 7 2 9 3 9 4 1 3 5 0 1 1 4 0 4 6 9 0 1 9 6 9 5 4 9 7 4 0 6 2 0 7 6 6 8 6 0 9 9 6 2 9 8 5 2 2 3 4 8 7 7
Batch size: 64 | Labels: 8 4 9 8 5 8 4 2 9 8 0 0 1 9 1 8 6 3 2 3 4 0 2 2 5 6 6 0 7 1 9 9 1 1 8 7 9 3 1 8 2 1 0 9 1 7 2 3 1 3 8 2 8 2 9 6 5 0 1 2 1 6 8 6
.....

Defining the Optimizers

Now for the algorithmic part of the example: implementing the delicate dance between generator and discriminator. We first create two optimizers, one for the generator and one for the discriminator. Both use the Adam algorithm.

Besides Adam, as used in this example, the C++ frontend provides implementations of Adagrad, Adam, LBFGS, RMSprop, and SGD; the complete list of optimizers can be found in the documentation.

// Define the optimizers for the two networks.
torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.999)));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.999)));

Training the Networks

Next we update the training loop: an outer loop runs over epochs, an inner loop fetches each batch from the data loader, and the GAN model is trained on every batch.

Within the loop, we first train the discriminator on real images, which should be assigned high probabilities; torch::empty(batch.data.size(0)).uniform_(0.8, 1.0) supplies the label probabilities. Drawing the target probabilities uniformly between 0.8 and 1.0 rather than pinning them at 1.0 makes the discriminator's training more robust. This trick is called label smoothing.

Before evaluating the discriminator we zero the gradients of its parameters. After computing the loss, calling d_loss.backward() backpropagates through the network and computes fresh gradients. We do this not only on real data but also on synthesized images. The synthesized images are produced by the generator, whose input is a batch of random noise. We pass these fake images to the discriminator and want it to emit a low realness score, ideally 0. Once we have computed the discriminator's losses on both the real and the fake batch, we advance its optimizer by one step to update its parameters.

To train the generator, we likewise zero its gradients and re-evaluate the discriminator on the fake images. This time, however, we want the labels to be probability 1, meaning the generator should produce images that the discriminator judges to be real. To do this, we fill fake_labels with 1, and finally step the generator's optimizer to update its parameters.

for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    int64_t batch_index = 0;
    for (torch::data::Example<>& batch : *data_loader) {
        // Train discriminator with real images.
        discriminator->zero_grad();
        torch::Tensor real_images = batch.data;
        torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
        torch::Tensor real_output = discriminator->forward(real_images);
        torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
        d_loss_real.backward();

        // Train discriminator with fake images.
        torch::Tensor noise = torch::randn({ batch.data.size(0), kNoiseSize, 1, 1 });
        torch::Tensor fake_images = generator->forward(noise);
        torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
        torch::Tensor fake_output = discriminator->forward(fake_images.detach());
        torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
        d_loss_fake.backward();

        torch::Tensor d_loss = d_loss_real + d_loss_fake;
        discriminator_optimizer.step();

        // Train generator.
        generator->zero_grad();
        fake_labels.fill_(1);
        fake_output = discriminator->forward(fake_images);
        torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
        g_loss.backward();
        generator_optimizer.step();

        std::printf(
            "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
            epoch,
            kNumberOfEpochs,
            ++batch_index,
            batches_per_epoch,
            d_loss.item<float>(),
            g_loss.item<float>());
    }
}

At this point we are basically able to train the model on the CPU. We haven't yet covered capturing state or sampling outputs, which we will address shortly. For now, note that the only way to judge whether the model is doing anything meaningful is whether the generated images look sensible (i.e. look real). Running the code produces output like this:

[ 1/30][200/938] D_loss: 0.3507 | G_loss: 7.64503
-> checkpoint 2
[ 1/30][400/938] D_loss: 2.7487 | G_loss: 3.29385
-> checkpoint 3
[ 1/30][600/938] D_loss: 0.9987 | G_loss: 1.8063
-> checkpoint 4
[ 1/30][800/938] D_loss: 0.7328 | G_loss: 1.8110
-> checkpoint 5
[ 2/30][200/938] D_loss: 0.9540 | G_loss: 0.9474
-> checkpoint 6
[ 2/30][400/938] D_loss: 0.7088 | G_loss: 2.2973
-> checkpoint 7
[ 2/30][600/938] D_loss: 0.4907 | G_loss: 2.4834
-> checkpoint 8
[ 2/30][800/938] D_loss: 0.5548 | G_loss: 2.5090
-> checkpoint 9
[ 3/30][200/938] D_loss: 0.6886 | G_loss: 3.2052
-> checkpoint 10
[ 3/30][400/938] D_loss: 0.5958 | G_loss: 2.8089
-> checkpoint 11
[ 3/30][522/938] D_loss: 0.6508 | G_loss: 2.5090

Periodic Model Checkpointing

Inside the training loop, every kCheckpointEvery batches we save the state of the models and optimizers, and also sample a few images from the generator:

if (batch_index % kCheckpointEvery == 0) {
  // Checkpoint the model and optimizer state.
  torch::save(generator, "generator-checkpoint.pt");
  torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
  torch::save(discriminator, "discriminator-checkpoint.pt");
  torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  // Sample the generator and save the images.
  torch::Tensor samples = generator->forward(torch::randn({8, kNoiseSize, 1, 1}));
  torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
  std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
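Note that the full source below declares a kRestoreFromCheckpoint flag but never uses it. A minimal sketch of how restoring from these checkpoints could look (placed before the training loop, and assuming the checkpoint files already exist) is:

if (kRestoreFromCheckpoint) {
    // Restore models and optimizer state saved by a previous run.
    torch::load(generator, "generator-checkpoint.pt");
    torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
    torch::load(discriminator, "discriminator-checkpoint.pt");
    torch::load(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
}

The saved sample tensors can be read back the same way, e.g. torch::Tensor samples; torch::load(samples, "dcgan-sample-1.pt");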

Full Source Code

#include <cmath>
#include <iostream>
#include <tuple>
#include <torch/torch.h>

using namespace torch;

struct DCGANGeneratorImpl : nn::Module {
    DCGANGeneratorImpl(int kNoiseSize)
        : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
            .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
            .stride(2)
            .padding(1)
            .bias(false))
    {
        // register_module() is needed if we want to use the parameters() method later on
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("conv4", conv4);
        register_module("batch_norm1", batch_norm1);
        register_module("batch_norm2", batch_norm2);
        register_module("batch_norm3", batch_norm3);
    }

    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(batch_norm1(conv1(x)));
        x = torch::relu(batch_norm2(conv2(x)));
        x = torch::relu(batch_norm3(conv3(x)));
        x = torch::tanh(conv4(x));
        return x;
    }

    nn::ConvTranspose2d conv1, conv2, conv3, conv4;
    nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);

nn::Sequential discriminator(
    // Layer 1
    nn::Conv2d(
        nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 2
    nn::Conv2d(
        nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(128),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 3
    nn::Conv2d(
        nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(256),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 4
    nn::Conv2d(
        nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
    nn::Sigmoid());


int main() {
    // The size of the noise vector fed to the generator.
    const int64_t kNoiseSize = 100;

    // The batch size for training.
    const int64_t kBatchSize = 64;

    // The number of epochs to train.
    const int64_t kNumberOfEpochs = 30;

    // Where to find the MNIST dataset.
    const char* kDataFolder = "../../pytorchCpp/data/mnist/MNIST/raw";

    // After how many batches to create a new checkpoint periodically.
    const int64_t kCheckpointEvery = 200;

    // How many images to sample at every checkpoint.
    const int64_t kNumberOfSamplesPerCheckpoint = 10;

    // Set to `true` to restore models and optimizers from previously saved
    // checkpoints.
    const bool kRestoreFromCheckpoint = false;

    // After how many batches to log a new update with the loss value.
    const int64_t kLogInterval = 10;

    DCGANGenerator generator(kNoiseSize);

    // Load the dataset.
    auto dataset = torch::data::datasets::MNIST(kDataFolder)
        .map(torch::data::transforms::Normalize<>(0.5, 0.5))
        .map(torch::data::transforms::Stack<>());
    const int64_t batches_per_epoch =
        std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

    // Define the data loader.
    auto data_loader = torch::data::make_data_loader(
        std::move(dataset),
        torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));


    // Print labels to check the data.
    for (torch::data::Example<>& batch : *data_loader) {
        std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
        for (int64_t i = 0; i < batch.data.size(0); ++i) {
            std::cout << batch.target[i].item<int64_t>() << " ";
        }
        std::cout << std::endl;
    }
    // Define the optimizers for the two networks.
    torch::optim::Adam generator_optimizer(
        generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.999)));
    torch::optim::Adam discriminator_optimizer(
        discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.999)));

    int64_t checkpoint_counter = 0;
    for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
        int64_t batch_index = 0;
        for (torch::data::Example<>& batch : *data_loader) {
            // Train discriminator with real images.
            discriminator->zero_grad();
            torch::Tensor real_images = batch.data;
            torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
            torch::Tensor real_output = discriminator->forward(real_images);
            torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
            d_loss_real.backward();

            // Train discriminator with fake images.
            torch::Tensor noise = torch::randn({ batch.data.size(0), kNoiseSize, 1, 1 });
            torch::Tensor fake_images = generator->forward(noise);
            torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
            torch::Tensor fake_output = discriminator->forward(fake_images.detach());
            torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
            d_loss_fake.backward();

            torch::Tensor d_loss = d_loss_real + d_loss_fake;
            discriminator_optimizer.step();

            // Train generator.
            generator->zero_grad();
            fake_labels.fill_(1);
            fake_output = discriminator->forward(fake_images);
            torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
            g_loss.backward();
            generator_optimizer.step();

            // Print training status.
            std::printf(
                "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
                epoch,
                kNumberOfEpochs,
                ++batch_index,
                batches_per_epoch,
                d_loss.item<float>(),
                g_loss.item<float>());

            // Periodically save checkpoints of the current model.
            if (batch_index % kCheckpointEvery == 0) {
                // Checkpoint the model and optimizer state.
                torch::save(generator, "generator-checkpoint.pt");
                torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
                torch::save(discriminator, "discriminator-checkpoint.pt");
                torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
                // Sample the generator and save the images.
                torch::Tensor samples = generator->forward(torch::randn({ 8, kNoiseSize, 1, 1 }));
                torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
                std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
            }
        }
    }
}

 
