Our goal is to generate images resembling the MNIST dataset using a generative adversarial network (GAN). Specifically, we will use the DCGAN architecture, one of the earliest and simplest GANs, yet fully sufficient for this task.
What Is a Generative Adversarial Network (GAN)?
A GAN consists of two distinct neural network models: a generator and a discriminator. The generator receives samples from a noise distribution, and its goal is to transform each noise sample into an image resembling the target distribution, in our case the MNIST dataset. The discriminator, in turn, receives either real images from the MNIST dataset or fake images from the generator, and is asked to emit a probability that a particular image is real (close to 1) or fake (close to 0). Feedback from the discriminator about the realism of the generator's images is used to train the generator further; feedback about how well it told real from fake is used to optimize the discriminator. In theory, a delicate balance between the two makes them improve in tandem, until the generator produces images indistinguishable from the target distribution, fooling the discriminator's (by then excellent) eye into emitting a probability of 0.5 for real and fake images alike. For us, the end result is a machine that takes noise as input and produces realistic images of digits as output.
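The dynamic just described is commonly written as a minimax game; this is the standard GAN objective, with the discriminator D maximizing and the generator G minimizing:

$$\min_G \max_D \;\mathbb{E}_{x\sim p_{data}}\left[\log D(x)\right]+\mathbb{E}_{z\sim p_z}\left[\log\left(1-D(G(z))\right)\right]$$

At the optimum the discriminator outputs D(x) = 0.5 everywhere, which is exactly the equilibrium described above.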
The Generator Module
The generator consists of a series of 2-D transposed convolutions, batch normalizations, and ReLU activations; its forward method passes the input through these modules in sequence.
The generator's job is to turn a vector of random numbers into a grayscale image. The input enters with kNoiseSize = 100 channels at spatial size 1×1; the first layer produces 256 channels at size 4×4, and the following layers produce 7×7, 14×14, and finally 28×28 feature maps, 28×28 being the image size of the MNIST dataset. The output size of a 2-D transposed convolution is
$$H_{out}=(H_{in}-1)\times stride[0]-2\times padding[0]+kernel\_size[0]+output\_padding[0]$$
so the sizes progress as 1×1 → 4×4 → 7×7 → 14×14 → 28×28.
Our custom network inherits from nn::Module. The member initializer list is used here: one benefit is that no elaborate constructor body is needed; another is that the initializer list avoids an extra call to each member's default constructor, which is efficient for data-heavy classes. Also, because C++ has no reflection, every layer must be registered manually via register_module(). Finally, after defining the network, the module itself must be registered with the macro TORCH_MODULE(DCGANGenerator);
struct DCGANGeneratorImpl : nn::Module {
  DCGANGeneratorImpl(int kNoiseSize)
      : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                  .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)) {
    // register_module() is needed if we want to use the parameters() method later on
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv3", conv3);
    register_module("conv4", conv4);
    register_module("batch_norm1", batch_norm1);
    register_module("batch_norm2", batch_norm2);
    register_module("batch_norm3", batch_norm3);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(batch_norm1(conv1(x)));
    x = torch::relu(batch_norm2(conv2(x)));
    x = torch::relu(batch_norm3(conv3(x)));
    x = torch::tanh(conv4(x));
    return x;
  }

  nn::ConvTranspose2d conv1, conv2, conv3, conv4;
  nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);
The Discriminator Module
The discriminator uses a similar sequence of convolutions, batch normalizations, and activations. However, the convolutions are now regular rather than transposed, and a leaky ReLU with a negative slope of 0.2 replaces the vanilla ReLU. The discriminator receives either a real image or a fake image from the generator, passes it through a series of convolutions and batch normalizations, and ends with a Sigmoid activation that squashes its output into the range 0 to 1. That squashed value can be interpreted as the probability the discriminator assigns to the image being real.
Through its successive convolutions the discriminator shrinks the spatial size from 28×28 down to 1. The output size of a convolution is
$$L_{out}=\lfloor(L_{in}+2\times padding-dilation\times(kernel\_size-1)-1)/stride+1\rfloor$$
so the sizes progress as 28×28 → 14×14 → 7×7 → 3×3 → 1×1.
To build the discriminator, we will try something different: a Sequential module. As in Python, LibTorch provides two APIs for model definition: a functional one where the input is passed through successive functions (as in the generator example above), and a more object-oriented one where we build a Sequential module containing the entire model as submodules. A Sequential module simply performs function composition: the output of the first submodule becomes the input of the second, the output of the second the input of the third, and so on, with no need to define a custom forward function. The nn::Sequential constructor takes the list of layers as its arguments.
nn::Sequential discriminator(
    // Layer 1
    nn::Conv2d(
        nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 2
    nn::Conv2d(
        nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(128),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 3
    nn::Conv2d(
        nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(256),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 4
    nn::Conv2d(
        nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
    nn::Sigmoid());
Defining the Dataset
With the generator and discriminator defined, we need data to train on. The C++ frontend, like the Python one, provides a powerful parallel data loader that reads batches from a dataset and offers many configuration options.
Whereas the Python data loader uses multiple processes, the C++ data loader is truly multi-threaded and does not launch any new processes.
The data loader lives in the torch::data:: namespace of the C++ frontend, which comprises several components:
- the data loader class,
- an interface for defining datasets,
- an interface for defining transforms, which can be applied to datasets,
- an interface for defining samplers, which produce the indices used to address a dataset,
- a library of existing (built-in) datasets, transforms, and samplers.
For this example we use the MNIST dataset built into the C++ frontend, torch::data::datasets::MNIST, and apply two transforms. First we normalize the images so their pixel values lie between -1 and +1 (originally they lie between 0 and 1); then we apply the Stack collation, which takes a batch of tensors and stacks them into a single tensor along their first dimension.
The dataset's size() method returns the number of examples; dividing it by kBatchSize and rounding up gives the number of batches needed per epoch.
//load the dataset
auto dataset = torch::data::datasets::MNIST("../../pytorchCpp/data/mnist/MNIST/raw")
                   .map(torch::data::transforms::Normalize<>(0.5, 0.5))
                   .map(torch::data::transforms::Stack<>());
const int64_t batches_per_epoch =
    std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
Loading the Data
Next we create a data loader and hand it the dataset. torch::data::make_data_loader returns a std::unique_ptr to a data loader of the correct type (which depends on the type of the dataset, the type of the sampler, and other implementation details).
The data loader comes with many options; you can inspect the full set in the documentation. For example, to speed up loading we can raise the number of worker threads: the default is zero, meaning the main thread does the loading, while setting workers to 2 spawns two threads that load data concurrently. We should also raise the batch size from its default of 1 to something more reasonable, such as 64 (the value of kBatchSize). So let's create a DataLoaderOptions object and set the appropriate properties:
//define the data_loader
auto data_loader = torch::data::make_data_loader(
std::move(dataset),
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
Inspecting the Data
The data loader yields values of type torch::data::Example<>, a simple struct with a data field holding the data and a target field holding the label. Because we applied the Stack collation earlier, the loader returns a single such example per batch; without stacking it would instead yield a std::vector<torch::data::Example<>>, with one element per sample in the batch.
//print to check the data
for (torch::data::Example<>& batch : *data_loader) {
  std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
  for (int64_t i = 0; i < batch.data.size(0); ++i) {
    std::cout << batch.target[i].item<int64_t>() << " ";
  }
  std::cout << std::endl;
}
Here is the output of this check with kBatchSize = 64:
Batch size: 64 | Labels: 8 7 5 9 1 0 5 9 5 1 7 9 5 7 1 0 6 7 5 2 8 2 2 7 0 2 4 1 8 7 8 7 5 0 2 0 2 7 7 6 5 8 5 8 5 1 6 1 1 0 9 8 7 4 0 5 4 9 8 9 0 3 9 2
Batch size: 64 | Labels: 5 2 6 7 5 3 7 4 2 1 5 3 2 6 2 1 7 6 4 4 7 4 9 7 6 5 4 7 9 2 2 1 7 4 0 8 6 0 5 1 2 8 9 5 9 9 6 9 7 8 1 2 1 0 3 2 3 9 2 5 5 2 8 7
Batch size: 64 | Labels: 9 3 9 5 6 7 6 6 8 2 3 9 8 8 0 1 9 2 2 8 4 0 1 2 7 9 6 8 9 5 6 6 9 4 3 7 8 5 2 9 3 0 7 5 2 8 2 9 7 5 4 3 2 1 9 8 7 2 7 2 0 8 3 3
Batch size: 64 | Labels: 4 7 0 1 3 6 4 8 0 3 2 2 4 8 8 4 8 8 6 5 6 5 7 8 1 9 2 5 3 2 8 5 8 0 6 9 5 7 9 8 5 2 4 4 6 8 2 0 5 0 0 4 3 5 0 9 0 3 2 8 8 1 1 6
Batch size: 64 | Labels: 4 1 0 3 9 6 2 1 1 5 3 4 3 7 7 7 4 4 7 4 5 3 4 1 0 7 8 1 6 0 6 8 8 4 1 8 4 0 3 3 1 9 7 5 6 2 4 1 3 8 9 4 7 1 0 8 6 8 9 8 5 7 2 5
Batch size: 64 | Labels: 0 6 3 9 9 6 0 9 9 3 0 0 0 5 9 0 9 6 9 8 1 8 7 5 1 0 1 1 6 8 4 7 6 2 8 1 8 6 7 8 5 8 9 6 1 2 9 3 8 2 0 8 4 7 6 9 6 1 1 1 4 2 8 8
Batch size: 64 | Labels: 3 0 8 9 3 5 4 9 6 3 2 3 3 9 7 9 6 0 7 2 7 8 2 4 8 7 9 3 4 7 9 0 5 6 3 8 1 1 3 9 9 1 6 3 7 3 1 7 0 1 5 6 2 1 2 1 7 8 7 9 6 2 7 7
Batch size: 64 | Labels: 7 8 7 3 1 7 7 4 9 1 4 6 7 6 4 2 0 8 1 0 5 5 8 4 1 1 8 9 5 3 1 7 4 1 2 8 1 7 8 5 7 4 0 3 8 3 8 3 6 3 7 0 4 2 1 1 8 2 8 5 7 6 5 0
Batch size: 64 | Labels: 2 9 8 6 9 1 4 5 8 9 0 2 5 7 2 9 3 9 4 1 3 5 0 1 1 4 0 4 6 9 0 1 9 6 9 5 4 9 7 4 0 6 2 0 7 6 6 8 6 0 9 9 6 2 9 8 5 2 2 3 4 8 7 7
Batch size: 64 | Labels: 8 4 9 8 5 8 4 2 9 8 0 0 1 9 1 8 6 3 2 3 4 0 2 2 5 6 6 0 7 1 9 9 1 1 8 7 9 3 1 8 2 1 0 9 1 7 2 3 1 3 8 2 8 2 9 6 5 0 1 2 1 6 8 6
.....
Defining the Optimizers
Now for the algorithmic part of the example: implementing the delicate dance between generator and discriminator. We create two optimizers, one for each network, both using the Adam algorithm.
As this example shows, the C++ frontend provides implementations of Adagrad, Adam, LBFGS, RMSprop, and SGD; see the documentation for the complete list of available optimizers.
//define the optimizer for these two net
torch::optim::Adam generator_optimizer(
generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.999)));
torch::optim::Adam discriminator_optimizer(
discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.999)));
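For reference, the value passed to AdamOptions (e.g. 2e-4) is the learning rate α, and .betas(0.5, 0.999) sets the decay rates β₁ and β₂ in the standard Adam update for a gradient g_t:

$$m_t=\beta_1 m_{t-1}+(1-\beta_1)g_t,\qquad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t^2$$
$$\theta_t=\theta_{t-1}-\alpha\,\hat m_t/(\sqrt{\hat v_t}+\epsilon),\qquad \hat m_t=m_t/(1-\beta_1^t),\quad \hat v_t=v_t/(1-\beta_2^t)$$

Lowering β₁ from its usual default of 0.9 to 0.5, as the DCGAN paper recommends, shortens the momentum horizon and tends to stabilize adversarial training.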
Training the Network
What remains is to update the training loop: two nested loops, the outer over epochs and the inner fetching one batch at a time from the data loader and training the GAN on it.
Within each iteration we first train the discriminator on real images, which should be assigned a high probability. Here torch::empty(batch.data.size(0)).uniform_(0.8, 1.0) supplies the target probabilities: drawing the targets uniformly from 0.8 to 1.0, rather than using exactly 1.0, makes discriminator training more robust. This trick is called label smoothing.
Before evaluating the discriminator we must zero the gradients of its parameters; after computing the loss, calling d_loss.backward() runs backpropagation through the network and computes fresh gradients. We do this not only on real images but also on synthesized ones, which are produced by feeding the generator a batch of random noise. The fake images are passed to the discriminator, which we want to assign them a low realism score, ideally zero. Once the discriminator's losses on both real and fake images are computed, its optimizer updates its parameters.
To train the generator, we likewise zero its gradients and re-evaluate the discriminator on the fake images. This time, however, the targets are probabilities of 1, meaning we want the generator to produce images the discriminator judges as real. For this we fill fake_labels with ones, and finally step the generator's optimizer to update its parameters.
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {
    // Train discriminator with real images.
    discriminator->zero_grad();
    torch::Tensor real_images = batch.data;
    torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
    torch::Tensor real_output = discriminator->forward(real_images);
    torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
    d_loss_real.backward();
    // Train discriminator with fake images.
    torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
    torch::Tensor fake_images = generator->forward(noise);
    torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
    torch::Tensor fake_output = discriminator->forward(fake_images.detach());
    torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
    d_loss_fake.backward();
    torch::Tensor d_loss = d_loss_real + d_loss_fake;
    discriminator_optimizer.step();
    // Train generator.
    generator->zero_grad();
    fake_labels.fill_(1);
    fake_output = discriminator->forward(fake_images);
    torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
    g_loss.backward();
    generator_optimizer.step();
    std::printf(
        "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
        epoch,
        kNumberOfEpochs,
        ++batch_index,
        batches_per_epoch,
        d_loss.item<float>(),
        g_loss.item<float>());
  }
}
At this point we can train the model end to end on the CPU. We have not yet covered capturing state or sampling outputs; that comes next. For now, whether the model is doing anything useful can only be judged by whether its generated images look meaningful. Running the code produces output like this:
[ 1/30][200/938] D_loss: 0.3507 | G_loss: 7.64503
-> checkpoint 2
[ 1/30][400/938] D_loss: 2.7487 | G_loss: 3.29385
-> checkpoint 3
[ 1/30][600/938] D_loss: 0.9987 | G_loss: 1.8063
-> checkpoint 4
[ 1/30][800/938] D_loss: 0.7328 | G_loss: 1.8110
-> checkpoint 5
[ 2/30][200/938] D_loss: 0.9540 | G_loss: 0.9474
-> checkpoint 6
[ 2/30][400/938] D_loss: 0.7088 | G_loss: 2.2973
-> checkpoint 7
[ 2/30][600/938] D_loss: 0.4907 | G_loss: 2.4834
-> checkpoint 8
[ 2/30][800/938] D_loss: 0.5548 | G_loss: 2.5090
-> checkpoint 9
[ 3/30][200/938] D_loss: 0.6886 | G_loss: 3.2052
-> checkpoint 10
[ 3/30][400/938] D_loss: 0.5958 | G_loss: 2.8089
-> checkpoint 11
[ 3/30][522/938] D_loss: 0.6508 | G_loss: 2.5090
Checkpointing the Model
Every kCheckpointEvery batches we save the model and optimizer state, sample the generator with fresh noise, and save the resulting images (rescaled from [-1, 1] back to [0, 1]) for inspection:
if (batch_index % kCheckpointEvery == 0) {
  // Checkpoint the model and optimizer state.
  torch::save(generator, "generator-checkpoint.pt");
  torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
  torch::save(discriminator, "discriminator-checkpoint.pt");
  torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  // Sample the generator and save the images.
  torch::Tensor samples = generator->forward(torch::randn({8, kNoiseSize, 1, 1}));
  torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
  std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
Full Source Code
#include <iostream>
#include <tuple>
#include <torch/script.h>
#include <torch/csrc/api/include/torch/torch.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/torch.h>
using namespace torch;
struct DCGANGeneratorImpl : nn::Module {
DCGANGeneratorImpl(int kNoiseSize)
: conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
.bias(false)),
batch_norm1(256),
conv2(nn::ConvTranspose2dOptions(256, 128, 3)
.stride(2)
.padding(1)
.bias(false)),
batch_norm2(128),
conv3(nn::ConvTranspose2dOptions(128, 64, 4)
.stride(2)
.padding(1)
.bias(false)),
batch_norm3(64),
conv4(nn::ConvTranspose2dOptions(64, 1, 4)
.stride(2)
.padding(1)
.bias(false))
{
// register_module() is needed if we want to use the parameters() method later on
register_module("conv1", conv1);
register_module("conv2", conv2);
register_module("conv3", conv3);
register_module("conv4", conv4);
register_module("batch_norm1", batch_norm1);
register_module("batch_norm2", batch_norm2);
register_module("batch_norm3", batch_norm3);
}
torch::Tensor forward(torch::Tensor x) {
x = torch::relu(batch_norm1(conv1(x)));
x = torch::relu(batch_norm2(conv2(x)));
x = torch::relu(batch_norm3(conv3(x)));
x = torch::tanh(conv4(x));
return x;
}
nn::ConvTranspose2d conv1, conv2, conv3, conv4;
nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);
nn::Sequential discriminator(
// Layer 1
nn::Conv2d(
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 2
nn::Conv2d(
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(128),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(256),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 4
nn::Conv2d(
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
nn::Sigmoid());
int main() {
// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;
// The batch size for training.
const int64_t kBatchSize = 64;
// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;
// Where to find the MNIST dataset.
const char* kDataFolder = "./data";
// After how many batches to create a new checkpoint periodically.
const int64_t kCheckpointEvery = 200;
// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;
// Set to `true` to restore models and optimizers from previously saved
// checkpoints.
const bool kRestoreFromCheckpoint = false;
// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;
DCGANGenerator generator(kNoiseSize);
//load the dataset
auto dataset = torch::data::datasets::MNIST(kDataFolder)
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
.map(torch::data::transforms::Stack<>());
const int64_t batches_per_epoch = std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
//define the data_loader
auto data_loader = torch::data::make_data_loader(
std::move(dataset),
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
//print to check the data
for (torch::data::Example<>& batch : *data_loader) {
std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
for (int64_t i = 0; i < batch.data.size(0); ++i) {
std::cout << batch.target[i].item<int64_t>() << " ";
}
std::cout << std::endl;
}
//define the optimizer for these two net
torch::optim::Adam generator_optimizer(
generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.999)));
torch::optim::Adam discriminator_optimizer(
discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.999)));
int64_t checkpoint_counter = 0;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
int64_t batch_index = 0;
for (torch::data::Example<>& batch : *data_loader) {
// Train discriminator with real images.
discriminator->zero_grad();
torch::Tensor real_images = batch.data;
torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
torch::Tensor real_output = discriminator->forward(real_images);
torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
d_loss_real.backward();
// Train discriminator with fake images.
torch::Tensor noise = torch::randn({ batch.data.size(0), kNoiseSize, 1, 1 });
torch::Tensor fake_images = generator->forward(noise);
torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
d_loss_fake.backward();
torch::Tensor d_loss = d_loss_real + d_loss_fake;
discriminator_optimizer.step();
// Train generator.
generator->zero_grad();
fake_labels.fill_(1);
fake_output = discriminator->forward(fake_images);
torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
g_loss.backward();
generator_optimizer.step();
//print the status
std::printf(
"\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
epoch,
kNumberOfEpochs,
++batch_index,
batches_per_epoch,
d_loss.item<float>(),
g_loss.item<float>());
//save current model
if (batch_index % kCheckpointEvery == 0) {
// Checkpoint the model and optimizer state.
torch::save(generator, "generator-checkpoint.pt");
torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::save(discriminator, "discriminator-checkpoint.pt");
torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
// Sample the generator and save the images.
torch::Tensor samples = generator->forward(torch::randn({ 8, kNoiseSize, 1, 1 }));
torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
}
}
}