《CRNN_training代碼解析》

在進行CRNN訓練時,我在training過程中遇到了很多問題,主要是對training中數據處理的具體細節不太瞭解,而不瞭解就會帶來困擾。下面是我自己寫的一個小demo,通過調試終於慢慢理解了其中的細節。下面是我梳理的思路:

[圖片:思路示意圖(原圖已缺失)]

下面是對應的demo:

from __future__ import print_function
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
from warpctc_pytorch import CTCLoss
import os
import utils
import dataset
import models.crnn as crnn
from alphabets import alphabet

# Command-line options. Several of them (valroot, niter, lr, beta1, cuda, ngpu,
# crnn, experiment, the *Interval options, adam, adadelta, random_sample) are
# not read by the visible demo code below.
parser = argparse.ArgumentParser()
parser.add_argument('--trainroot', default='/home/gobills/PycharmProjects/crnn_chinese_characters_rec/to_lmdb/train/',
                    help='path to dataset')
parser.add_argument('--valroot', default='/home/gobills/PycharmProjects/crnn_chinese_characters_rec/to_lmdb/val/',
                    help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=0)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=280, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
# Help text previously advertised a stale default (0.00005); keep it in sync with `default`.
parser.add_argument('--lr', type=float, default=0.00001, help='learning rate, default=0.00001')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
# NOTE(review): '--cuda' takes a value (no action=), so `--cuda` alone is an
# argparse error and any non-empty string, even 'False', is truthy. Kept as-is
# so existing `--cuda <value>` invocations keep working.
parser.add_argument('--cuda', default=False, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=0, help='number of GPUs to use')
# NOTE(review): purpose of '--crnn' is undocumented and it is not read below.
parser.add_argument('--crnn', default='', help="")
parser.add_argument('--alphabet', type=str, default=alphabet)
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=1000, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=1000, help='Interval to be validated')
parser.add_argument('--saveInterval', type=int, default=1000, help='Interval to be saved')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
# NOTE(review): store_true combined with default=True makes this option always
# True; it cannot be switched off from the command line.
parser.add_argument('--random_sample', action='store_true', default=True,
                    help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)


# custom weights initialization called on crnn
def weights_init(m):
    """Initialise one module in place; meant to be used via ``model.apply(weights_init)``.

    Modules whose class name contains 'Conv' get their weights drawn from
    N(0, 0.02). Modules whose class name contains 'BatchNorm' get weights
    drawn from N(1, 0.02) and a zero bias. All other modules are left untouched.

    Args:
        m: a ``torch.nn.Module`` visited by ``Module.apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # A BatchNorm built with affine=False has weight/bias set to None;
        # skip them instead of raising AttributeError.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)

# Number of input image channels fed to the CRNN.
nc = 1

# Pre-allocated input buffers, filled below via utils.loadData. The image buffer
# now matches the network input (nc channels, imgH x imgW); the old placeholder
# used 3 channels and imgH x imgH.
image = torch.FloatTensor(opt.batchSize, nc, opt.imgH, opt.imgW)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

image = Variable(image)
text = Variable(text)
length = Variable(length)

loss_avg = utils.averager()

# Load the LMDB training set.
train_dataset = dataset.lmdbDataset(root=opt.trainroot)
print('length of train_dataset: ', len(train_dataset))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=True, sampler=None,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
print('length of train_loader: ', len(train_loader))

# One class more than the alphabet, presumably reserved for the CTC blank symbol.
nclass = len(opt.alphabet) + 1

converter = utils.strLabelConverter(opt.alphabet)
criterion = CTCLoss()

train_iter = iter(train_loader)
print('length of train_iter: ', len(train_iter))

# Use the builtin next(): DataLoader iterators in recent PyTorch releases no
# longer provide a `.next()` method.
data = next(train_iter)

# `data` comes from the LMDB-backed loader: a batch of images plus the
# ground-truth label strings (Chinese characters) for that batch.
cpu_images, cpu_texts = data
print('cpu_texts', cpu_texts)

# Actual batch size (the last batch of an epoch can be smaller than opt.batchSize).
batch_size = cpu_images.size(0)
print('batch_size: ', batch_size)
utils.loadData(image, cpu_images)

# Characters -> label indices, plus the length of each label string.
encoded_text, text_lengths = converter.encode(cpu_texts)
utils.loadData(text, encoded_text)
utils.loadData(length, text_lengths)

# Named `model` so the imported `models.crnn` module is not shadowed.
model = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
model.apply(weights_init)

# preds holds per-class scores for every time step, not indices. Its shape is
# presumably (T, batch, nclass), given how dims 0 and 2 are used below.
preds = model(image)
# Every sample in the batch uses the full output sequence length T.
preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
cost = criterion(preds, text, preds_size, length) / batch_size
loss_avg.add(cost)

# Greedy decoding: best class index per time step, then transpose to
# batch-major order and flatten into one long index sequence.
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)

# Indices -> characters. raw=False presumably collapses repeats and drops blanks.
sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
print(sim_preds)
print(cpu_texts)

# A sample is correct only if the whole predicted string equals the label
# (or the label's lower-cased form); partial matches do not count.
n_correct = sum(
    1 for pred, target in zip(sim_preds, cpu_texts)
    if pred == target or pred == target.lower()
)

# Raw network output before CTC collapsing (presumably with '-' blanks);
# only the first opt.n_test_disp samples are displayed.
raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]

for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
    print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

# Only one batch was evaluated, so normalise by its actual size. The previous
# 100 * opt.batchSize denominator assumed 100 full batches.
accuracy = n_correct / float(batch_size)
print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))

運行結果:(原文此處的運行結果未能保留)

本文參考的是GitHub上的項目:https://github.com/Aurora11111/crnn-train-pytorch

其中的alphabets(字符表)可以根據自己數據集的情況進行調整。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章