torchvision.model是torchvision一個很重要的包,裏面包含了以下模型結構:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- …
並且提供了預訓練模型,可以通過簡單調用來讀取網絡結構和預訓練模型。
在進行深度學習的圖像分類任務時,我們可以利用torchvision.model這個包快速構建模型,做適當調整即可運用於分類訓練。
使用例子1:
import torchvision.models as models
# pretrained=True就可以使用預訓練的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
上述就創建了resnet18和alexnet兩個模型,並且用預訓練模型的參數來初始化。
如果不需要預訓練模型,那麼pretrained=False
即可。如下:
import torchvision.models as models
# pretrained=True就可以使用預訓練的模型
resnet18 = models.resnet18(pretrained=False)
alexnet = models.alexnet(pretrained=False)
不過,構建的這些模型,都是在imagenet上訓練得到的,他們的默認輸出類別數是1000,那如果我們需要訓練自己的數據,並且數據類別數目不是1000時,我們需要在最後一層微調。
使用例子2:
class ResNet_101(nn.Module):
def __init__(self, num_classes):
super(ResNet_101, self).__init__()
model = models.resnet101(pretrained=True)
model.fc = nn.Sequential(
nn.Linear(2048, num_classes, bias=True),
)
self.net = model
def forward(self, img):
output = self.net(img)
return output
下面代碼就構建了用預訓練模型進行參數初始化的輸出類別數目爲3的resnet101模型。
model = ResNet_101(num_classes=3)
從torchvision 0.3.0開始,torchvision.models中就集成了目標檢測、分割、關鍵點檢測的models。
Semantic Segmentation:
- FCN ResNet101
- DeepLabV3 ResNet101
Object Detection:
- Faster R-CNN ResNet-50 FPN
Instance Segmentation:
- Mask R-CNN ResNet-50 FPN
Person Keypoint Detection:
- keypointrcnn_resnet50_fpn