街景字符編碼識別之模型訓練與驗證

點贊再看,養成習慣!覺得不過癮的童鞋,歡迎關注公衆號《機器學習算法工程師》,有非常多大神的乾貨文章可供學習噢…

前言

這篇文章彙總這個系列之前的博客的所有代碼(沒有看過的童鞋,最好去看看),小編按照自己構思的比較不錯的設計架構來組織代碼,把數據預處理模塊(data_precessing)、模型模塊(model)、工具模塊(utils)等等分開,各自封裝成類,然後處理任務的流程在main模塊中實現。如下圖所示:
架構

正文

數據預處理模塊

這個模塊存放的是自定義的數據集類

class SVHNDataset(Dataset):
    """Dataset of SVHN street-view character images.

    Each item is an RGB image (optionally transformed) paired with a
    fixed-length label tensor: labels shorter than ``MAX_LEN`` are
    right-padded with the sentinel class ``PAD`` (10), which never
    collides with the valid digit classes 0-9.
    """

    # Fixed-length recognition strategy: pad every label to this length.
    MAX_LEN = 6
    # Padding class; outside the valid digit range 0-9.
    PAD = 10

    def __init__(self, img_path, img_label, transform=None):
        # img_path: sequence of image file paths
        # img_label: sequence of per-image digit sequences (values 0-9)
        # transform: optional torchvision transform applied per image
        self.img_path = img_path
        self.img_label = img_label
        # The original if/else was redundant: both branches stored the value.
        self.transform = transform

    def __getitem__(self, index):
        # Load one sample; force RGB so grayscale files get 3 channels too.
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        # Bug fix: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the documented replacement.
        lbl = np.array(self.img_label[index], dtype=int)
        lbl = list(lbl) + (self.MAX_LEN - len(lbl)) * [self.PAD]

        return img, torch.from_numpy(np.array(lbl))

    def __len__(self):
        return len(self.img_path)

模型模塊

這個模塊存放的是機器學習模型,這裏有兩個:自定義的神經網絡(用於學習目的);繼承預訓練模型的神經網絡(實操用的)。這裏給出後者的代碼,包括模型結構、訓練、驗證以及預測的功能。

class SVHN_Model2(nn.Module):
    """Fixed-length (6-character) SVHN recognizer on a pretrained ResNet-18.

    The backbone is ResNet-18 with its final FC layer dropped and the last
    pooling layer replaced by adaptive global average pooling; six parallel
    11-way heads (digits 0-9 plus the padding class 10) share the 512-d
    feature, one head per character position.
    """

    def __init__(self):
        super(SVHN_Model2, self).__init__()

        # Pretrained ResNet-18 backbone.
        model_conv = models.resnet18(pretrained=True)
        # Replace the last pooling layer with adaptive global average pooling
        # so any input size collapses to a 512-d feature.
        model_conv.avgpool = nn.AdaptiveAvgPool2d(1)
        # Fine-tuning: drop the original FC layer, keep everything before it.
        model_conv = nn.Sequential(*list(model_conv.children())[:-1])
        self.cnn = model_conv
        # One independent 11-way classifier per character position.
        self.fc1 = nn.Linear(512, 11)
        self.fc2 = nn.Linear(512, 11)
        self.fc3 = nn.Linear(512, 11)
        self.fc4 = nn.Linear(512, 11)
        self.fc5 = nn.Linear(512, 11)
        self.fc6 = nn.Linear(512, 11)

    def forward(self, img):
        """Return the six per-position logit tensors, each (batch, 11)."""
        feat = self.cnn(img)
        feat = feat.view(feat.shape[0], -1)
        c1 = self.fc1(feat)
        c2 = self.fc2(feat)
        c3 = self.fc3(feat)
        c4 = self.fc4(feat)
        c5 = self.fc5(feat)
        c6 = self.fc6(feat)
        return c1, c2, c3, c4, c5, c6

    def mytraining(self, train_loader, criterion, optimizer, device=torch.device('cpu')):
        """Train one epoch; return the epoch's mean loss rounded to 4 places."""
        # Switch to training mode (enables dropout/batchnorm updates).
        self.train()
        train_loss = []

        for data, label in train_loader:
            c0, c1, c2, c3, c4, c5 = self(data.to(device))
            label = label.long().to(device)
            # Average the six per-position cross-entropy losses.
            loss = criterion(c0, label[:, 0]) + \
                   criterion(c1, label[:, 1]) + \
                   criterion(c2, label[:, 2]) + \
                   criterion(c3, label[:, 3]) + \
                   criterion(c4, label[:, 4]) + \
                   criterion(c5, label[:, 5])
            loss /= 6
            train_loss.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return round(np.mean(train_loss), 4)

    def myvalidating(self, val_loader, criterion, device=torch.device('cpu')):
        """Evaluate on the validation set; return mean loss rounded to 4 places."""
        # Switch to evaluation mode.
        self.eval()
        val_loss = []

        # No gradient bookkeeping during evaluation.
        with torch.no_grad():
            for data, label in val_loader:
                c0, c1, c2, c3, c4, c5 = self(data.to(device))
                label = label.long().to(device)
                loss = criterion(c0, label[:, 0]) + \
                       criterion(c1, label[:, 1]) + \
                       criterion(c2, label[:, 2]) + \
                       criterion(c3, label[:, 3]) + \
                       criterion(c4, label[:, 4]) + \
                       criterion(c5, label[:, 5])
                loss /= 6
                val_loss.append(loss.item())

        return round(np.mean(val_loss), 4)

    def myPredicting(self, test_loader, device=torch.device('cpu')):
        """Predict on the test set; return an (N, 6) array of class indices."""
        self.eval()
        batches = []

        with torch.no_grad():
            for data, _ in test_loader:
                # Bug fix: the original ignored `device` (ran on CPU regardless)
                # and called .numpy() without .cpu(), which fails on CUDA tensors.
                outputs = self(data.to(device))
                # One (batch, 1) argmax column per character position.
                cols = [out.cpu().numpy().argmax(axis=1).reshape(-1, 1)
                        for out in outputs]
                batches.append(np.concatenate(cols, axis=1))

        # Concatenate once at the end instead of re-allocating per batch.
        return np.concatenate(batches, axis=0)

工具模塊

這個模塊包含一些零零散散的工具方法,都是些靜態方法,有數據導入、結果保存等功能。

class Tools:
    """Miscellaneous static helpers: data loading, metrics, logging, submission."""

    @staticmethod
    def dataFromPath(img_path, label_path=None):
        """Return (sorted image paths, labels); fake [10] labels when no label file."""
        imgs = glob.glob(img_path)
        imgs.sort()
        if label_path:
            label_json = json.load(open(label_path))
            labels = [label_json[x]['label'] for x in label_json]
        else:
            # Fake test-set labels. Bug fix: [[10]]*n shared ONE list object
            # across every entry; a fresh list per image avoids aliasing.
            labels = [[10] for _ in imgs]
        return imgs, labels

    @staticmethod
    def calAcc(pred_label, true_label):
        """Sequence-level accuracy in percent (all valid positions must match)."""
        length = len(true_label)
        count = 0
        for i in range(length):
            for j in range(len(true_label[i])):
                if true_label[i][j] == 10:
                    # Hit padding: every valid position matched so far.
                    count += 1
                    break
                if true_label[i][j] != pred_label[i][j]:
                    break
            else:
                # Bug fix: a full-length label (no padding 10) that matched
                # every position fell through uncounted in the original.
                count += 1
        return round(count / length, 4) * 100

    @staticmethod
    def printInfo(epoch, train_loss, val_loss,
                  best_epoch, best_val_loss,
                  train_acc='--', val_acc='--', best_val_acc='--'):
        """Print one epoch's metrics plus the best-so-far summary."""
        print("epoch {}: train_loss {}, train_acc {}; val_loss {}, val_acc {}; " 
              "(best_epoch,best_val_loss,best_val_acc):({},{},{})".format(
            epoch,train_loss,train_acc,val_loss,val_acc,best_epoch,best_val_loss,best_val_acc))

    @staticmethod
    def submit(demo_submit_path, pred_labels, out_path='Submit_files/'):
        """Decode padded predictions and write a competition submission CSV."""
        submit = pd.read_csv(demo_submit_path)
        pred_result = []
        for label in pred_labels:
            tmp = []
            for char in label:
                if char != 10:
                    tmp.append(char)
                else:
                    # Padding class marks end of the character sequence.
                    break
            # Degenerate case: no valid characters predicted; default to 0.
            if not tmp:
                tmp.append(0)
            pred_result.append("".join(map(str, tmp)))
        # Fill the sample submission frame and save as submit.csv.
        submit['file_code'] = pred_result
        out_path += "submit.csv"
        submit.to_csv(out_path, index=False)

主模塊

這個模塊是程序入口,顯式實現了整個字符識別任務的處理邏輯。

if __name__=='__main__':
    # Select GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Data locations and hyperparameters.
    train_img_path = r'E:\Datas\StreetCharsRecognition\mchar_train\*.png'
    train_label_path = r'E:\Datas\StreetCharsRecognition\mchar_train.json'
    val_img_path = r'E:\Datas\StreetCharsRecognition\mchar_val\*.png'
    val_label_path = r'E:\Datas\StreetCharsRecognition\mchar_val.json'
    test_img_path = r'E:\Datas\StreetCharsRecognition\mchar_test_a\*.png'
    demo_submit_path = r'E:\Datas\StreetCharsRecognition\mchar_sample_submit_A.csv'
    batch_size = 100
    epochs = 20
    lr = .001
    is_predicting = False  # False -> training phase; True -> prediction phase

    # Bug fix: the original applied ColorJitter/RandomRotation augmentation to
    # the validation AND test sets, which makes evaluation noisy and hurts
    # prediction quality. Augment training data only; val/test get a
    # deterministic resize + normalize pipeline.
    train_transform = transforms.Compose([
        transforms.Resize((64, 128)),
        transforms.ColorJitter(0.3, 0.3, 0.2),
        transforms.RandomRotation(5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((64, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    if not is_predicting:
        # Training phase: load train/val data.
        train_path, train_label = Tools.dataFromPath(train_img_path, train_label_path)
        train_dataset = SVHNDataset(train_path, train_label, train_transform)
        val_path, val_label = Tools.dataFromPath(val_img_path, val_label_path)
        val_dataset = SVHNDataset(val_path, val_label, eval_transform)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=5,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=5,
        )

        # Build model, loss, and optimizer.
        # Bug fix: the original instantiated SVHN_Model1, which is not defined
        # in this file; SVHN_Model2 (the pretrained-backbone model described
        # above) is the one intended for actual use.
        model = SVHN_Model2().to(device)
        criterion = nn.CrossEntropyLoss(reduction='sum')
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        best_epoch, best_loss, best_acc = -1, 1000.0, 0

        # Train, checkpointing the parameters with the lowest validation loss.
        for epoch in range(epochs):
            train_loss = model.mytraining(train_loader, criterion, optimizer, device)
            val_loss = model.myvalidating(val_loader, criterion, device)
            if val_loss < best_loss:
                best_epoch, best_loss = epoch, val_loss
                # Persist only the learnable parameters.
                torch.save(model.state_dict(), 'Model/model.pt')
            # Report this epoch plus the best-so-far summary.
            Tools.printInfo(epoch, train_loss, val_loss,
                            best_epoch, best_loss)
    else:
        # Prediction phase: no labels, deterministic transform.
        test_path, test_label = Tools.dataFromPath(test_img_path)
        test_dataset = SVHNDataset(test_path, test_label, eval_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=5,
        )
        model = SVHN_Model2()
        model.load_state_dict(torch.load("Model/model.pt", map_location='cpu'))
        pred_labels = model.myPredicting(test_loader)
        Tools.submit(demo_submit_path, pred_labels)

其他

一些自定義的文件目錄,Model目錄存放訓練過程中最優的模型參數,Submit_files目錄存放滿足可提交格式的預測結果csv文件。

結語

這篇文章貼的代碼已經是一份完整的代碼啦,有需要的可以去參考文獻中的github鏈接下載代碼。這份代碼,小編近期將會持續更新,還有很多沒講到的知識點噢。

參考文獻

  1. https://github.com/Ggmatch/CV_StreetCharsRecognition

童鞋們,讓小編聽見你們的聲音,點贊評論,一起加油。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章