Pytorch實戰之驗證碼識別

驗證碼識別與之前的幾個任務不同,這是一個多標籤的分類的任務,也就是是一個數據對應着幾個標籤,只有所用的標籤都預測對時,纔算真正的預測成功了。


一. 數據的準備工作

  1. 與以往不同,這次的數據,我們是利用python的第三方庫來生成驗證碼圖片,下面進行代碼演示,非常簡單。
a = ['1','2','3','4']    
img = ImageCaptcha()
captcha=img.generate(a)   #生成圖片,根據a中的內容
captcha_image = PIL.Image.open(captcha)    #讀取圖片
captcha_image.show()				#顯示圖片

在這裏插入圖片描述
2. 基於上面的代碼我們可以很輕鬆裝備幾萬張數據,然後用一個Excel保存圖片以及圖片對應的標籤,這裏就不作代碼展示了,展示一下Excel文件。
在這裏插入圖片描述
3. 下面開始讀取我們的Excel文件,來構造我們的數據集,標籤有4個,這裏我們需要one-hot編碼一下,弄成長度爲40的向量,這也是一個需要特別需要注意的地方。

  • . 讀取csv文件
def read_data():
    data = pd.read_csv("qwe.csv")
    img_path = data["ID"].values
    label = data.iloc[:,data.columns!="ID"].values
    y = []
    for x in label:
        t = one_hot(x)
        y.append(np.array(t))
    return img_path,np.array(y)
  • 進行one-hot編碼
def one_hot(x):
    tmp = [0 for i in range(40)]
    for step,i in enumerate(x):
        tmp[i+10*step] = 1
    return tmp
  • 最後構造DataLoader,與顯示最後的標籤形式,到這裏數據的準備工作就基本上完成了。
class DataSet(Dataset):
    def __init__(self):
        self.img_path,self.label = read_data()
    def __getitem__(self, index):
        img_path = self.img_path[index]
        img = cv2.imread(img_path,0)
        img = img/255.
        img = torch.from_numpy(img).float()
        img = torch.unsqueeze(img,0)
        label = torch.from_numpy(self.label[index]).float()
        return img,label
    def __len__(self):
        return len(self.img_path)
data = DataSet()
data_loader = DataLoader(data,shuffle=True,batch_size=64,drop_last=True)

在這裏插入圖片描述


二. 網絡的構建與優化、損失函數的選取以及訓練

  1. 網絡的構建和優化函數在這裏就不做多的說明了,直接看代碼。
class CNN_Network(nn.Module):
    def __init__(self):
        super(CNN_Network, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(stride=2, kernel_size=2),  # 30 80
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2,stride=2),   # 15 40
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * 15 * 40, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 40)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
model = CNN_Network()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
  1. 損失函數我們這裏使用的是多標籤分類的損失函數,和交叉熵損失函數的公式比較相像,loss(x,y)=iy[i]log(11+exp(x[i])+(1y[i])log(exp(x[i])1+exp(x[i])) loss(x,y)=−∑iy[i]∗log(11+exp(−x[i])+(1−y[i])∗log(exp(−x[i])1+exp(−x[i])),上面是公式,有興趣的可以自己研究一下。
error = nn.MultiLabelSoftMarginLoss()   #注意輸入的數據要是float32類型的,否則會出錯。
  1. 所有的準備工作都完成了,下面就開始訓練吧。
for i in range(2):
    for x_index,y in data_loader:
        pass
        x = Variable(x_index)

        optimizer.zero_grad()
        label = Variable(y)

        out = model(x)
        loss = error(out,label)
        print(loss)
        loss.backward()
        optimizer.step()
torch.save(model.state_dict(),"驗證碼識別.pth")

三. 測試模型

  1. 訓練完成後,來測試一下我們的模型吧.
cnn = CNN_Network()
cnn.load_state_dict(torch.load("驗證碼識別.pth"))


a = cv2.imread("./data/9354.jpg",0)
b = cv2.resize(a,(200,200))
cv2.imshow('a',b)
cv2.waitKey(0)
a = a/255.
a = torch.from_numpy(a).float()
a = torch.unsqueeze(a,0)
a = torch.unsqueeze(a,0)
pred = cnn(a)
print(pred.size())
a1 = torch.argmax(pred[0,:10],dim=0)    #第一個標籤
a2 = torch.argmax(pred[0,10:20],dim=0)	#第二個標籤
a3 = torch.argmax(pred[0,20:30],dim=0)	#第三個標籤
a4 = torch.argmax(pred[0,30:],dim=0)	#第四的標籤
pred = [a1,a2,a3,a4]
print(pred)

預測的圖片
在這裏插入圖片描述
預測結果
在這裏插入圖片描述
上面就完成了所有的工作了。
github地址.
Thank for your reading !!!


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章