在之前的文章中講的AlexNet、VGG、GoogLeNet以及ResNet網絡,它們都是傳統卷積神經網絡(都是使用的傳統卷積層),缺點在於內存需求大、運算量大導致無法在移動設備以及嵌入式設備上運行。而本文要講的MobileNet網絡就是專門爲移動端,嵌入式端而設計。
MobileNet v1
MobileNet網絡是由google團隊在2017年提出的,專注於移動端或者嵌入式設備中的輕量級CNN網絡。相比傳統卷積神經網絡,在準確率小幅降低的前提下大大減少模型參數與運算量。(相比VGG16準確率減少了0.9%,但模型參數只有VGG的1/32)。
要說MobileNet網絡的優點,無疑是其中的Depthwise Convolution結構(大大減少運算量和參數數量)。下圖展示了傳統卷積與DW卷積的差異,在傳統卷積中,每個卷積核的channel與輸入特徵矩陣的channel相等(每個卷積核都會與輸入特徵矩陣的每一個維度進行卷積運算)。
而在DW卷積中,每個卷積核的channel都是等於1的(每個卷積核只負責輸入特徵矩陣的一個channel,故卷積核的個數必須等於輸入特徵矩陣的channel數,從而使得輸出特徵矩陣的channel數也等於輸入特徵矩陣的channel數)
剛剛說了使用DW卷積後輸出特徵矩陣的channel是與輸入特徵矩陣的channel相等的,如果想改變/自定義輸出特徵矩陣的channel,那只需要在DW卷積後接上一個PW卷積即可.
如下圖所示,其實PW卷積就是普通的卷積而已(只不過卷積核大小爲1)。通常DW卷積和PW卷積是放在一起使用的,一起叫做Depthwise Separable Convolution(深度可分卷積)。
那Depthwise Separable Convolution(深度可分卷積)與傳統的卷積相比有到底能節省多少計算量呢,下圖對比了這兩個卷積方式的計算量,其中Df是輸入特徵矩陣的寬高(這裏假設寬和高相等),Dk是卷積核的大小,M是輸入特徵矩陣的channel,N是輸出特徵矩陣的channel,卷積計算量近似等於卷積核的高 x 卷積核的寬 x 卷積核的channel x 輸入特徵矩陣的高 x 輸入特徵矩陣的寬(這裏假設stride等於1),在我們mobilenet網絡中DW卷積都是是使用3x3大小的卷積核。所以理論上普通卷積計算量是DW+PW卷積的8到9倍(公式來源於原論文):
在瞭解完Depthwise Separable Convolution(深度可分卷積)後在看下mobilenet v1的網絡結構,左側的表格是mobileNetv1的網絡結構,表中標Conv的表示普通卷積,Conv dw代表剛剛說的DW卷積,s表示步距,根據表格信息就能很容易的搭建出mobileNet v1網絡。
在mobilenetv1原論文中,還提出了兩個超參數,一個是α一個是β。
寬度因子
爲了構造這些結構更小且計算量更小的模型,我們引入了一個參數α,稱爲寬度因子。寬度因子α的作用是在每層均勻地稀疏網絡,爲每層通道乘以一定的比例,從而減少各層的通道數。常用值有1、0.75、0.5、0.25。
分辨率因子
爲了減少計算量,引入了第二個參數ρ,稱爲分辨率因子。其作用是在每層特徵圖的大小乘以一定的比例。
下圖右側給出了使用不同α和β網絡的分類準確率,計算量以及模型參數:
MobileNet v2
在MobileNet v1的網絡結構表中能夠發現,網絡的結構就像VGG一樣是個直筒型的,不像ResNet網絡有shorcut之類的連接方式。而且有人反映說MobileNet v1網絡中的DW卷積很容易訓練廢掉,效果並沒有那麼理想。所以我們接着看下MobileNet v2網絡。
MobileNet v2網絡是由google團隊在2018年提出的,相比MobileNet V1網絡,準確率更高,模型更小。
MobileNet v2 模型的特點:
如上圖,mobileNet v2在V1基礎上進行了改進。
剛剛說了MobileNet v1網絡中的亮點是DW卷積,那麼在MobileNet v2中的亮點就是Inverted residual block(倒殘差結構),同時分析了v1的幾個缺點並針對性的做了改進。v2的改進策略非常簡單,但是在編寫論文時,缺點分析的時候涉及了流行學習等內容,將優化過程弄得非常難懂。我們在這裏簡單總結一下v2中給出的問題分析,希望能對論文的閱讀有所幫助,對v2的motivation感興趣的同學推薦閱讀論文。
當我們單獨去看Feature Map的每個通道的像素的值的時候,其實這些值代表的特徵可以映射到一個低維子空間的一個流形區域上。在進行完卷積操作之後往往會接一層激活函數來增加特徵的非線性性,一個最常見的激活函數便是ReLU。根據我們在殘差網絡中介紹的數據處理不等式(DPI),ReLU一定會帶來信息損耗,而且這種損耗是沒有辦法恢復的,ReLU的信息損耗是當通道數非常少的時候更爲明顯。爲什麼這麼說呢?我們看圖6中這個例子,其輸入是一個表示流形數據的矩陣,和卷機操作類似,他會經過 n個ReLU的操作得到 n個通道的Feature Map,然後我們試圖通過這n個Feature Map還原輸入數據,還原的越像說明信息損耗的越少。從圖6中我們可以看出,當 n的值比較小時,ReLU的信息損耗非常嚴重,但是當n 的值比較大的時候,輸入流形就能還原的很好了。
根據對上面提到的信息損耗問題分析,我們可以有兩種解決方案:
- 既然是ReLU導致的信息損耗,那麼我們就將ReLU替換成線性激活函數;
- 如果比較多的通道數能減少信息損耗,那麼我們就使用更多的通道。
如下下圖所示,左側是ResNet網絡中的殘差結構,右側就是MobileNet v2中的到殘差結構。
在殘差結構中是1x1卷積降維->3x3卷積->1x1卷積升維,在倒殘差結構中正好相反,是1x1卷積升維->3x3DW卷積->1x1卷積降維。爲什麼要這樣做,原文的解釋是高維信息通過ReLU激活函數後丟失的信息更少(注意倒殘差結構中基本使用的都是ReLU6激活函數,但是最後一個1x1的卷積層使用的是線性激活函數)。
在使用倒殘差結構時需要注意下,並不是所有的倒殘差結構都有shortcut連接,只有當stride=1且輸入特徵矩陣與輸出特徵矩陣shape相同時纔有shortcut連接(只有當shape相同時,兩個矩陣才能做加法運算,當stride=1時並不能保證輸入特徵矩陣的channel與輸出特徵矩陣的channel相同)。
下圖是MobileNet v2網絡的結構表,其中t代表的是擴展因子(倒殘差結構中第一個1x1卷積的擴展因子),c代表輸出特徵矩陣的channel,n代表倒殘差結構重複的次數,s代表步距(注意:這裏的步距只是針對重複n次的第一層倒殘差結構,後面的都默認爲1)。
一些問題
- MobileNet V2中的bottleneck爲什麼先擴張通道數在壓縮通道數呢?
因爲MobileNet 網絡結構的核心就是Depth-wise,此卷積方式可以減少計算量和參數量。而爲了引入shortcut結構,若參照Resnet中先壓縮特徵圖的方式,將使輸入給Depth-wise的特徵圖大小太小,接下來可提取的特徵信息少,所以在MobileNet V2中採用先擴張後壓縮的策略。
- MobileNet V2中的bottleneck爲什麼在1*1卷積之後使用Linear激活函數?
因爲在激活函數之前,已經使用1*1卷積對特徵圖進行了壓縮,而ReLu激活函數對於負的輸入值,輸出爲0,會進一步造成信息的損失,所以使用Linear激活函數。
3. 總結
在這篇文章中,我們介紹了兩個版本的MobileNet,它們和傳統卷積的對比如下。
如圖(b)所示,MobileNet v1最主要的貢獻是使用了Depthwise Separable Convolution,它又可以拆分成Depthwise卷積和Pointwise卷積。MobileNet v2主要是將殘差網絡和Depthwise Separable卷積進行了結合。通過分析單通道的流形特徵對殘差塊進行了改進,包括對中間層的擴展(d)以及bottleneck層的線性激活©。Depthwise Separable Convolution的分離式設計直接將模型壓縮了8倍左右,但是精度並沒有損失非常嚴重,這一點還是非常震撼的。
Depthwise Separable卷積的設計非常精彩但遺憾的是目前cudnn對其的支持並不好,導致在使用GPU訓練網絡過程中我們無法從算法中獲益,但是使用串行CPU並沒有這個問題,這也就給了MobileNet很大的市場空間,尤其是在嵌入式平臺。
最後,不得不承認v2的論文的一系列證明非常精彩,雖然沒有這些證明我們也能明白v2的工作原理,但是這些證明過程還是非常值得仔細品鑑的,尤其是對於從事科研方向的工作人員。
代碼
注:
- 本次訓練集下載在AlexNet博客有詳細解說:https://blog.csdn.net/weixin_44023658/article/details/105798326
- 使用遷移學習方法實現收錄在我的這篇blog中: 遷移學習 TransferLearning—通俗易懂地介紹(pytorch實例)
#model.py
from torch import nn
import torch
def _make_divisible(ch, divisor=8, min_ch=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_ch is None:
min_ch = divisor
new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_ch < 0.9 * ch:
new_ch += divisor
return new_ch
class ConvBNReLU(nn.Sequential):
def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1):#groups=1普通卷積
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU6(inplace=True)
)
#到殘差結構
class InvertedResidual(nn.Module):
def __init__(self, in_channel, out_channel, stride, expand_ratio):#expand_ratio擴展因子
super(InvertedResidual, self).__init__()
hidden_channel = in_channel * expand_ratio
self.use_shortcut = stride == 1 and in_channel == out_channel
layers = []
if expand_ratio != 1:
# 1x1 pointwise conv
layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1))
layers.extend([
# 3x3 depthwise conv
ConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel),
# 1x1 pointwise conv(linear)
nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channel),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_shortcut:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8):#alpha超參數
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = _make_divisible(32 * alpha, round_nearest)
last_channel = _make_divisible(1280 * alpha, round_nearest)
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
features = []
# conv1 layer
features.append(ConvBNReLU(3, input_channel, stride=2))
# building inverted residual residual blockes
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * alpha, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, last_channel, 1))
# combine feature layers
self.features = nn.Sequential(*features)
# building classifier
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(last_channel, num_classes)
)
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
#train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from model import MobileNetV2
import torchvision.models.mobilenet
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
data_root = os.path.abspath(os.path.join(os.getcwd(), "../../..")) # get data root path
image_path = data_root + "/data_set/flower_data/" # flower data set path
train_dataset = datasets.ImageFolder(root=image_path+"train",
transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=0)
net = MobileNetV2(num_classes=5)
# load pretrain weights
model_weight_path = "./mobilenet_v2.pth"
pre_weights = torch.load(model_weight_path)
# delete classifier weights
pre_dict = {k: v for k, v in pre_weights.items() if "classifier" not in k}
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)
# freeze features weights
for param in net.features.parameters():
param.requires_grad = False
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
best_acc = 0.0
save_path = './MobileNetV2.pth'
for epoch in range(5):
# train
net.train()
running_loss = 0.0
for step, data in enumerate(train_loader, start=0):
images, labels = data
optimizer.zero_grad()
logits = net(images.to(device))
loss = loss_function(logits, labels.to(device))
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
# print train process
rate = (step+1)/len(train_loader)
a = "*" * int(rate * 50)
b = "." * int((1 - rate) * 50)
print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
print()
# validate
net.eval()
acc = 0.0 # accumulate accurate number / epoch
with torch.no_grad():
for val_data in validate_loader:
val_images, val_labels = val_data
outputs = net(val_images.to(device)) # eval model only have last output layer
# loss = loss_function(outputs, test_labels)
predict_y = torch.max(outputs, dim=1)[1]
acc += (predict_y == val_labels.to(device)).sum().item()
val_accurate = acc / val_num
if val_accurate > best_acc:
best_acc = val_accurate
torch.save(net.state_dict(), save_path)
print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
(epoch + 1, running_loss / step, val_accurate))
print('Finished Training')
#pridict.py
import torch
from model import MobileNetV2
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# load image
img = Image.open("sunflower.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
try:
json_file = open('./class_indices.json', 'r')
class_indict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# create model
model = MobileNetV2(num_classes=5)
# load model weights
model_weight_path = "./MobileNetV2.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img))
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()