我們已經訓練過幾個神經網絡了,識別手寫數字,房價預測或者是區分貓和狗,那隨之而來就有一個問題,這些訓練出的網絡怎麼用,每個問題我都需要重新去訓練網絡嗎?因爲程序員都不太喜歡做重複的事情,因此答案肯定是已經有輪子了。
我們先來介紹一個數據集,ImageNet。這就不得不提一個大名鼎鼎的華裔 AI 科學家李飛飛。
2005 年左右,李飛飛結束了他的博士生涯,開始了他的學術研究不就她就意識到了一個問題,在此之前,人們都儘可能優化算法,認爲無論數據如何,只要算法夠好,就能做出更好的決策,李飛飛意識到了這個問題的侷限性,恰巧她還是一個行動派,她要做出一個無比龐大的數據集,儘可能描述世界上一切物體的數據集,下載圖片,給沒一張圖片做標註,簡單而無聊,當然後來這項工作放到了亞馬遜的衆包平臺上,全世界無數的人蔘與了這個偉大的項目,到此刻爲止,已經有 14,197,122 張圖片(一千四百萬張),21841 個分類。在這個發展的過程中,人們也發現了這個數據集帶來的成功遠比預想的要多,甚至現在被認爲最有前景的深度卷積神經網絡的提出也與 ImageNet 不無關係。我忘記了誰這麼說過:“就單單這一個數據集,就可以讓李飛飛數據科學這個領域擁有一席之地”。暫且不說這麼說是否準確,但這個數據集仍然在創造新的突破。(我曾經在臺下聽過李飛飛一次演講,現在想想還覺得甚是激動,她真的充滿熱情)。
基於這個數據集,我們是不是可以訓練出一些網絡,一般情況下,大家就不用耗時再去訓練網絡了呢?答案是肯定的,並且在 Keras 就有個一些這樣的模型,還是內置的,Keras 就是這麼懂你,那就不用客氣了,我們拿來用就好了,謝謝啦!
特徵提取
我們之前用到的卷積神經網絡都是分成了兩部分,第一部分是由池化層和卷積層組成的卷積積,第二部分是由分類器,特徵提取的含義就是第一部分不變,改變第二部分。
爲什麼可以這麼做?我們之前解釋過神經網絡的運行原理,跟人腦的認識過程非常類似,還記得嗎?我們還是看一看原來的圖吧。
我們可以看出來,網絡識別圖像是有層次結構的,比如一開始的網絡層是用來識別圖像或者拼裝線條的,這是通用且類似的,因此我們可以複用。而後面的分類器往往是根據具體的問題所決定的,比如識別貓或狗的眼睛就與識別桌子腿是不一樣的,因此有越靠前越具有通用性的特點。Keras 中很多的內置模型都可以直接下載,如果你沒有下載在使用的時候會自動下載:
https://github.com/fchollet/deep-learning-models/releases
我們舉一個例子,用 VGG16 去識別貓或狗,這次的解釋都比較簡單且都是以前說明過的,因此放在代碼註釋中:
#!/usr/bin/env python3
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
def extract_features(directory, sample_count):
# 圖片轉換區間
datagen = ImageDataGenerator(rescale=1. / 255)
batch_size = 20
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(150, 150, 3))
conv_base.summary()
features = np.zeros(shape=(sample_count, 4, 4, 512))
labels = np.zeros(shape=(sample_count))
# 讀出圖片,處理成神經網絡需要的數據格式,上一篇文章中有介紹
generator = datagen.flow_from_directory(
directory,
target_size=(150, 150),
batch_size=batch_size,
class_mode='binary')
i = 0
for inputs_batch, labels_batch in generator:
print(i, '/', len(generator))
# 提取特徵
features_batch = conv_base.predict(inputs_batch)
features[i * batch_size: (i + 1) * batch_size] = features_batch
labels[i * batch_size: (i + 1) * batch_size] = labels_batch
i += 1
if i * batch_size >= sample_count:
break
# 特徵和標籤
return features, labels
def cat():
base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
# 提取出的特徵
train_features, train_labels = extract_features(train_dir, 2000)
validation_features, validation_labels = extract_features(validation_dir, 1000)
# 對特徵進行變形展平
train_features = np.reshape(train_features, (2000, 4 * 4 * 512))
validation_features = np.reshape(validation_features, (1000, 4 * 4 * 512))
# 定義密集連接分類器
model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
# 對模型進行配置
model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
loss='binary_crossentropy',
metrics=['acc'])
# 對模型進行訓練
history = model.fit(train_features, train_labels,
epochs=30,
batch_size=20,
validation_data=(validation_features, validation_labels))
# 畫圖
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.show()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
if __name__ == "__main__":
time_start = time.time()
cat()
time_end = time.time()
print('Time Used: ', time_end - time_start)
有點巧合的是這裏居然看不到太多的過擬合的痕跡,其實也是有可能會有過擬合的隱患的,那樣就需要進行數據增強,與以前是一樣的,只不過這裏的區別就是用到了內置模型,模型的參數需要凍結,我們是不希望對已經訓練好的模型進行更改的,具體關鍵代碼寫法如下:
conv_base.trainable = False
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
以上就是模型複用的一種方法,我們對模型都是原封不動的拿來用,我們下一篇文章將介紹另外一種方法,對模型進行微調。
首發自公衆號:RAIS