pytorch實現帶標籤格式數據的模型訓練

1.訓練數據讀入

注:以下模擬數據,主要講解方法。

標籤數據


下面函數即爲實現標籤數據的讀入

def reader(txt):

    fh = open(txt)  
    c=0  
    imgs=[]  
    class_names=[]  
    for line in  fh.readlines():  
        if c==0:  
            class_names=[n.strip() for n in line.rstrip().split('   ')]  
        else:  
            cls = line.split()   
            fn = cls.pop(0)
            imgs.append((fn, tuple([float(v) for v in cls])))  
        c=c+1

    return class_names,imgs

其中,返回imgs是標籤元組,即[1,0,0,1],class_names爲屬性名,即sex。

如人臉特徵數據,也可以通過reader()讀入。

2.簡單模型設計(以全連層爲例)

cmodel=nn.Linear(100, 2) ,(或者nn.Sequential(nn.Linear(100, 2))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.classify=cmodel
    def forward(self, x):
        x=self.classify(x)
        return x,

3.模型訓練

訓練集讀入

train_data_loader = torch.utils.data.DataLoader(  \
         ImageFloder(root = "./fea.txt", label = "./label.txt"), batch_size= 2, shuffle= False, num_workers= 4)

其中,root,label分別是特徵與標籤文件地址, ImageFloder類定義如下:

class ImageFloder(data.Dataset):  
    def __init__(self, root, label):

self.classes1,self.imgs1 = reader(label)
        self.classes2,self.imgs2 = reader(root)

    def __getitem__(self, index):  
        fn1, label1 = self.imgs1[index]
        fn2, label2 = self.imgs2[index]

return torch.Tensor(label1),torch.Tensor(label2)

    def __len__(self):  
        return len(self.imgs1)

訓練代碼詳見項目:

https://github.com/eeric/pytorch-model-training-label

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章