Pytorch學習(二) --- 模型定義之torchvivsion.models快速構建預訓練模型

torchvision.model是torchvision一個很重要的包,裏面包含了以下模型結構:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet

並且提供了預訓練模型,可以通過簡單調用來讀取網絡結構和預訓練模型。
在進行深度學習的圖像分類任務時,我們可以利用torchvision.model這個包快速構建模型,做適當調整即可運用於分類訓練。

使用例子1

import torchvision.models as models
# pretrained=True就可以使用預訓練的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)

上述就創建了resnet18和alexnet兩個模型,並且用預訓練模型的參數來初始化。
如果不需要預訓練模型,那麼pretrained=False即可。如下:

import torchvision.models as models
# pretrained=True就可以使用預訓練的模型
resnet18 = models.resnet18(pretrained=False)
alexnet = models.alexnet(pretrained=False)

不過,構建的這些模型,都是在imagenet上訓練得到的,他們的默認輸出類別數是1000,那如果我們需要訓練自己的數據,並且數據類別數目不是1000時,我們需要在最後一層微調。

使用例子2

class ResNet_101(nn.Module):
    def __init__(self, num_classes):
        super(ResNet_101, self).__init__()
        model = models.resnet101(pretrained=True)
        model.fc = nn.Sequential(
                nn.Linear(2048, num_classes, bias=True),
        )
        self.net = model
    
    def forward(self, img):
        output = self.net(img)
        return output

下面代碼就構建了用預訓練模型進行參數初始化的輸出類別數目爲3的resnet101模型。

model = ResNet_101(num_classes=3)

從torchvision 0.3.0開始,torchvision.models中就集成了目標檢測、分割、關鍵點檢測的models。

Semantic Segmentation:

  • FCN ResNet101
  • DeepLabV3 ResNet101

Object Detection:

  • Faster R-CNN ResNet-50 FPN

Instance Segmentation:

  • Mask R-CNN ResNet-50 FPN

Person Keypoint Detection:

  • keypointrcnn_resnet50_fpn

https://pytorch.org/docs/stable/torchvision/models.html#

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章