學習網址:Keras 中文文檔、Keras 官方文檔(Keras documentation)
常用API導入
import numpy as np
from keras import layers
from keras.layers import Input, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D # X = Input(...), X = ZeroPadding2D(...)
from keras.layers import AveragePooling2D, MaxPooling2D, Dropout, GlobalMaxPooling2D, GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing import image
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import preprocess_input
import pydot
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model
from kt_utils import *
import keras.backend as K
K.set_image_data_format('channels_last')
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
建立模型
def model(input_shape):
    """Build the small convolutional 'HappyModel' classifier.

    Architecture: zero-padding -> 7x7 conv (32 filters) -> batch norm ->
    ReLU -> 2x2 max-pool -> flatten -> one sigmoid unit.

    Args:
        input_shape: Shape of a single input image, without the batch
            dimension. Batch normalization is applied on axis 3, so the
            input is expected to be channels-last:
            (height, width, channels).

    Returns:
        An uncompiled Keras ``Model`` named 'HappyModel' whose output is a
        single sigmoid activation per sample.
    """
    X_input = Input(input_shape)

    # Pad 3 pixels on every side before the 7x7 convolution.
    X = ZeroPadding2D((3, 3))(X_input)

    # CONV -> BN -> RELU
    X = Conv2D(32, (7, 7), strides=(1, 1), name='conv0')(X)
    X = BatchNormalization(axis=3, name='bn0')(X)
    X = Activation('relu')(X)

    # MAXPOOL
    X = MaxPooling2D((2, 2), name='max_pool')(X)

    # FLATTEN (convert the feature maps to a vector) + FULLY CONNECTED
    X = Flatten()(X)
    X = Dense(1, activation='sigmoid', name='fc')(X)

    # Use a distinct local name so the function's own name is not shadowed.
    happy_model = Model(inputs=X_input, outputs=X, name='HappyModel')
    return happy_model


# The training code instantiates the network as ``HappyModel(...)``; expose
# the builder under that name as well so the call does not raise NameError.
HappyModel = model
訓練和測試模型
# Build the model from the per-sample shape of the training data
# (axes 1..3 of X_train, i.e. everything except the batch axis).
happyModel = HappyModel(X_train.shape[1:4])
# Compile the model. The network ends in a single sigmoid unit, so binary
# cross-entropy is the matching loss (mean squared error gives weak gradients
# when the sigmoid saturates).
# NOTE(review): assumes Y_train / Y_test hold 0/1 labels -- confirm.
happyModel.compile(optimizer='adam', loss='binary_crossentropy', metrics=["accuracy"])
# Train the model.
happyModel.fit(x=X_train, y=Y_train, epochs=40, batch_size=20)
# -----------------------------
# Evaluate on the test set; with one metric configured, evaluate() returns
# [loss, accuracy].
preds = happyModel.evaluate(x=X_test, y=Y_test)
print ("Loss = " + str(preds[0]))
print ("Test Accuracy = " + str(preds[1]))
利用模型測試自己的圖像
# Run the trained model on a single image read from disk.
img_path = 'images/my_image.jpg'
# Load the image and resize it to 64x64.
img = image.load_img(img_path, target_size=(64, 64))
imshow(img)
# Convert the image to an array and prepend a batch axis of size 1.
x = np.expand_dims(image.img_to_array(img), axis=0)
# Apply the preprocessing imported from keras.applications.imagenet_utils.
x = preprocess_input(x)
print(happyModel.predict(x))
其他有用的函數
# summary() prints the details of the model's layers.
happyModel.summary()
# plot_model() draws the model graph and writes it to an image file.
plot_model(happyModel, to_file='HappyModel.png')
# Render the same graph as an SVG; kept as the final bare expression so a
# notebook displays it inline.
SVG(
    model_to_dot(happyModel).create(
        prog='dot',
        format='svg',
    )
)