訓練一個好的卷積神經網絡模型進行圖像分類不僅需要計算資源還需要很長的時間。特別是模型比較複雜和數據量比較大的時候。普通的電腦動不動就需要訓練幾天的時間。爲了能夠快速地訓練好自己的花朵圖片分類器,我們可以使用別人已經訓練好的模型參數,在此基礎之上訓練我們的模型。這個便屬於遷移學習。本文提供訓練數據集和代碼下載。
原理:卷積神經網絡模型總體上可以分爲兩部分,前面的卷積層和後面的全連接層。卷積層的作用是圖片特徵的提取,全連接層作用是特徵的分類。我們的思路便是在inception-v3網絡模型上,修改全連接層,保留卷積層。卷積層的參數使用的是別人已經訓練好的,全連接層的參數需要我們初始化並使用我們自己的數據來訓練和學習。
上面inception-v3模型圖紅色箭頭前面部分是卷積層,後面是全連接層。我們需要修改修改全連接層,同時把模型的最終輸出改爲5。
由於這裏使用了tensorflow框架,所以,我們需要獲取上圖紅色箭頭所在位置的張量BOTTLENECK_TENSOR_NAME
(最後一個卷積層激活函數的輸出值,個數爲2048)以及模型最開始的輸入數據的張量JPEG_DATA_TENSOR_NAME
。獲取這兩個張量的作用是,圖片訓練數據通過JPEG_DATA_TENSOR_NAME
張量輸入模型,通過BOTTLENECK_TENSOR_NAME
張量獲取通過卷積層之後的圖片特徵。
BOTTLENECK_TENSOR_SIZE = 2048
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
通過下面的鏈接下載inception-v3模型,其中包含已經訓練好的參數。
通過下面的代碼加載模型,同時獲取上面所述的兩個張量。
# 讀取已經訓練好的Inception-v3模型。
with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(
graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
由於我們模型的功能是對五種花進行分類,所以,我們需要修改全連接層,這裏,我們只增加一個全連接層。全連接層的輸入數據便是BOTTLENECK_TENSOR_NAME
張量。
# 定義一層全鏈接層
with tf.name_scope('final_training_ops'):
weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))
biases = tf.Variable(tf.zeros([n_classes]))
logits = tf.matmul(bottleneck_input, weights) + biases
final_tensor = tf.nn.softmax(logits)
最後便是定義交叉熵損失函數。模型使用反向傳播訓練,而訓練的參數並不是模型的所有參數,僅僅是全連接層的參數,卷積層的參數是不變的。
# 定義交叉熵損失函數。
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=ground_truth_input)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy_mean)
那麼接下來的是如何給我們的模型輸入數據了,這裏提供了幾個操作數據的函數。由於訓練數據集比較小,先把所有的圖片通過JPEG_DATA_TENSOR_NAME
張量輸入模型,然後獲取BOTTLENECK_TENSOR_NAME
張量的值並保存到硬盤中。在模型訓練的時候,從硬盤中讀取所保存的BOTTLENECK_TENSOR_NAME
張量的值作爲全連接層的輸入數據。因爲一張圖片可能會被使用多次。
# 輸入圖片並獲取`BOTTLENECK_TENSOR_NAME`張量的值
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor)
# 從硬盤中讀取`BOTTLENECK_TENSOR_NAME`張量的值,用於訓練
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor):
# 從硬盤中讀取`BOTTLENECK_TENSOR_NAME`張量的值,用於測試。
def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor, bottleneck_tensor)
不到5分鐘就可以訓練好我們的模型,精確度還蠻高的。下圖是本人運行的結果。
源碼地址:https://github.com/liangyihuai/my_tensorflow/tree/master/com/huai/converlution/transfer_learning