every blog every motto: You will never know unless you try
0. 前言
本文旨在講解有關語義分割的基礎內容,並進行實戰。
1. 正文
1.1 前提梗概
1.1.1 引言
不要強求不可知,要從已知推未知
爲了更好的瞭解語義分割,會先簡單回顧一下作爲深度學習“Hello World ”的mnist / fashion-mnist
1.1.2 mnist / fashion_mnsit
mnist是一個手寫數字識別的數據集,一張圖像上只有一個數字,如下圖所示
上面每張圖片即是一個原圖/樣本(x),他的標籤是對應數字的字符(y / label)
主要任務是分類,即對識別每張圖片的數字!
有關fashion_mnist數據集也是類似,具體可參考想看就點我
總結: mnist/fashion_mnist ,原圖 / 訓練樣本(x)是一張僅含一種物體的圖片,標籤(y) 是一個字符或者說one_hot編碼。對其處理也較爲簡單,屬於分類任務
1.2 語義分割基礎
1.2.1 概念
- 直觀理解:
一言以蔽之: 是將一張(含有多種物體的)圖片,識別出其中的每一種物體。
語義分割的直觀理解,如下圖所示。
動畫理解(實際意義):
最簡單的語義分割,當屬於“二分類”,即識別出圖片中的兩種地物,一種歸爲背景,另一種歸爲想要提取的物體(如我們接下來要說的斑馬線) - 進一步理解:
我們以遙感影像爲例,如下圖所示:
左邊是原圖/樣本(x),右圖是標籤(label,黑色的是背景,紅色的是標註的建築物)
注: 標註自己的訓練樣本
這裏的標籤是一張圖片,不同於mnsit(標籤是一個字符串)
這裏的標籤是一張圖片,不同於mnsit(標籤是一個字符串)
這裏的標籤是一張圖片,不同於mnsit(標籤是一個字符串)
- 更進一步理解:
通過如下過程,不斷計算損失函數,調整參數,直至loss函數最小,即參數最優爲止。
- 深入理解:
整個網絡如下圖所示,分爲兩部分,分別是編碼、解碼部分。
圖像shape的變化:
編碼:圖像(size)不斷變小,同時不斷變厚(通道數增加)
解碼:圖像(size)不斷變大(最終恢復到原圖的1/2),同時不斷變薄(通道數減小)
1.2.2 結果展示
使用下方代碼,進行預測,最終結果。
原圖:
預測結果:
可以發現,基本將斑馬線標出,由於訓練數據(191個)較少,有這樣的結果還算不錯。
1.3 代碼部分
說明: 具體代碼,文後附鏈接。
1.3.1 模型部分(model.py)
主函數如下:
如上所說,模型分爲兩部分,編碼、解碼。
def main():
"""
model 的主程序,語義分割,分爲兩部分,第一部分特徵提取,第二部分放大圖片
:param Height: 圖片長
:param Width: 圖片寬
:return: (H,W,3) -> (h,w,2)
"""
# 第一部分 編碼,提取特徵,圖像size減小,通道增加
img_input, feature_map_list = encoder(input_height=Height, input_width=Width)
# 第二部分 解碼,將圖像上採樣,size放大,通道減小
output = decoder(feature_map_list, class_number=class_number, input_height=Height, input_width=Width,
encoder_level=3)
# 構建模型
model = Model(img_input, output)
# model.summary()
print('模型輸入shape:', img_input.shape, '模型輸出圖像shape:', output.shape)
print('-'*100)
return model
模型的前半部分,編碼部分:
下采樣過程,下采樣包括卷積和池化。
基於VGG16的編碼網絡,用於提取特徵,有關卷積運算可參考文章1、文章2
def encoder(input_height, input_width):
"""
語義分割的第一部分,特徵提取,主要用到VGG網絡,函數式API
:param input_height: 輸入圖像的長
:param input_width: 輸入圖像的寬
:return: 返回:輸入圖像,提取到的5個特徵
"""
# 輸入
img_input = Input(shape=(input_height, input_width, 3))
# print('--')
# print(img_input.shape)
# 三行爲一個結構單元,size減半
# 416,416,3 -> 208,208,64,
x = Conv2D(64, (3, 3), activation='relu', padding='same')(img_input)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = MaxPool2D((2, 2), strides=(2, 2))(x)
f1 = x # 暫存提取的特徵
# 208,208,64 -> 104,104,128
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = MaxPool2D((2, 2), strides=(2, 2))(x)
f2 = x # 暫存提取的特徵
# 104,104,128 -> 52,52,256
x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
x = MaxPool2D((2, 2), strides=(2, 2))(x)
f3 = x # 暫存提取的特徵
# 52,52,256 -> 26,26,512
x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
x = MaxPool2D((2, 2), strides=(2, 2))(x)
f4 = x # 暫存提取的特徵
# 26,26,512 -> 13,13,512
x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
x = MaxPool2D((2, 2), strides=(2, 2))(x)
f5 = x # 暫存提取的特徵
# print(img_input.shape[1:])
return img_input, [f1, f2, f3, f4, f5]
模型的後半部分,解碼部分:
上採樣過程,
def decoder(feature_map_list, class_number, input_height=416, input_width=416, encoder_level=3):
"""
語義分割的後半部分,上採樣,將圖片放大,
:param feature_map_list: 特徵圖(多個),encoder得到
:param class_number: 分類數
:param input_height: 輸入圖像長
:param input_width: 輸入圖像寬
:param encoder_level: 利用的特徵圖,這裏利用f4
:return: output , 返回放大後的特徵圖 (208*208,2)
"""
# 獲取一個特徵圖,特徵圖來源encoder裏面的f1,f2,f3,f4,f5; 這裏獲取f4
feature_map = feature_map_list[encoder_level]
# 解碼過程 ,以下 (26,26,512) -> (208,208,64)
# f4.shape=(26,26,512) -> 26,26,512
x = ZeroPadding2D((1, 1))(feature_map)
x = Conv2D(512, (3, 3), padding='valid')(x)
x = BatchNormalization()(x)
# 上採樣,圖像長寬擴大2倍,(26,26,512) -> (52,52,256)
x = UpSampling2D((2, 2))(x)
x = ZeroPadding2D((1, 1))(x)
x = Conv2D(256, (3, 3), padding='valid')(x)
x = BatchNormalization()(x)
# 上採樣,圖像長寬擴大2倍 (52,52,512) -> (104,104,128)
x = UpSampling2D((2, 2))(x)
x = ZeroPadding2D((1, 1))(x)
x = Conv2D(128, (3, 3), padding='valid')(x)
x = BatchNormalization()(x)
# 上採樣,圖像長寬擴大2倍,(104,104,128) -> (208,208,64)
x = UpSampling2D((2, 2))(x)
x = ZeroPadding2D((1, 1))(x)
x = Conv2D(64, (3, 3), padding='valid')(x)
x = BatchNormalization()(x)
# 再進行一次卷積,將通道數變爲2(要分類的數目) (208,208,64) -> (208,208,2)
x = Conv2D(class_number, (3, 3), padding='same')(x)
# print(x.shape)
# reshape: (208,208,2) -> (208*208,2)
# y = Reshape((208,208,-1))(x)
# print('y shape:',y.shape)
x = Reshape((int(input_height / 2) * int(input_width / 2), -1))(x)
# print(x.shape)
# 求取概率
output = Softmax()(x)
return output
網絡結構,summary:
1.3.2 訓練部分(train.py)
主函數,如下:
def main():
# 獲取已建立的模型,並加載官方與訓練參數,模型編譯
model = get_model()
# 打印模型摘要
# model.summary()
# 獲取樣本(訓練集&驗證集) 和標籤的對應關係,trian_num,val_num
lines, train_nums, val_nums = get_data()
# 設置回調函數 並返回保存的路徑
callbacks, logdir = set_callbacks()
# 生成樣本和標籤
# generate_arrays_from_file(lines, batch_size=batch_size)
# 訓練
model.fit_generator(generate_arrays_from_file(lines[:train_nums], batch_size),
steps_per_epoch=max(1, train_nums // batch_size),
epochs=50, verbose=1, callbacks=callbacks,
validation_data=generate_arrays_from_file(lines[train_nums:], batch_size),
validation_steps=max(1, val_nums // batch_size),
initial_epoch=0)
save_weight_path = os.path.join(logdir, 'last.h5') # 保存模型參數的路徑
model.save_weights(save_weight_path)
1.3.3 預測部分(predict.py)
主函數,如下:
def main():
""" 模型預測"""
model = get_model()
predicting(model)
1.3.4 源碼鏈接
https://github.com/onceone/Semantic-segmentation
參考文獻
[1] https://blog.csdn.net/weixin_39190382/article/details/105890812
[2] https://blog.csdn.net/weixin_44791964/article/details/102979289
[3] https://baijiahao.baidu.com/s?id=1602428106371812559&wfr=spider&for=pc
[4] https://blog.csdn.net/weixin_40446557/article/details/85624579
[5] https://www.sohu.com/a/301097998_120054440
[6] https://www.cnblogs.com/xianhan/p/9145966.html
[7] https://blog.csdn.net/weixin_39190382/article/details/105692853
[8] https://blog.csdn.net/weixin_39190382/article/details/105702100
[9] https://blog.csdn.net/weixin_39190382/article/details/104083347