Keras 筆記

學習網址：Keras 中文文檔、Keras 官方文檔（Keras documentation）
常用API導入

import numpy as np
from keras import layers
from keras.layers import Input, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D     # X = Input(...), X = ZeroPadding2D(...)
from keras.layers import AveragePooling2D, MaxPooling2D, Dropout, GlobalMaxPooling2D, GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing import image
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import preprocess_input
import pydot
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model
from kt_utils import *
import keras.backend as K
K.set_image_data_format('channels_last')
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

建立模型

def model(input_shape):
    """Build the "HappyModel" binary image classifier.

    Architecture: zero-pad by 3 pixels -> 7x7 Conv2D with 32 filters ->
    BatchNormalization -> ReLU -> 2x2 max-pool -> flatten -> one sigmoid unit.

    Arguments:
    input_shape -- shape of a single input image, without the batch axis.
        Channels are expected last, i.e. (height, width, channels): the file
        sets the 'channels_last' data format and BatchNormalization below
        normalizes axis 3.

    Returns:
    A keras ``Model`` named 'HappyModel' with a single sigmoid output.
    """
    X_input = Input(input_shape)

    # Pad 3 pixels on every side so the 7x7, stride-1 convolution below
    # (default 'valid' padding) keeps the original height and width.
    X = ZeroPadding2D((3, 3))(X_input)

    # CONV -> BN -> RELU
    X = Conv2D(32, (7, 7), strides=(1, 1), name='conv0')(X)
    X = BatchNormalization(axis=3, name='bn0')(X)  # axis 3 = channels (channels_last)
    X = Activation('relu')(X)

    X = MaxPooling2D((2, 2), name='max_pool')(X)

    # Flatten to a vector, then a fully-connected layer with one sigmoid unit.
    X = Flatten()(X)
    X = Dense(1, activation='sigmoid', name='fc')(X)

    # Use a distinct local name so the function's own name is not shadowed.
    happy_model = Model(inputs=X_input, outputs=X, name='HappyModel')
    return happy_model


# Also expose the builder as ``HappyModel`` (the name given to the built
# Model); ``model`` keeps working for existing callers.
HappyModel = model

訓練和測試模型

# Build the model for inputs shaped like one training example, i.e. X_train's
# shape without the leading batch axis: (height, width, channels).
# NOTE(review): X_train, Y_train, X_test and Y_test are not defined in this
# snippet; presumably they are loaded via a kt_utils helper -- confirm.
happyModel = model(X_train.shape[1:4])

# The network ends in a single sigmoid unit (a binary classifier), so train
# with binary cross-entropy rather than mean squared error: MSE on a sigmoid
# output yields tiny gradients when a prediction is confidently wrong.
happyModel.compile(optimizer='adam', loss='binary_crossentropy', metrics=["accuracy"])
happyModel.fit(x=X_train, y=Y_train, epochs=40, batch_size=20)  # train

# -----------------------------
# Evaluate on the test set. With exactly one metric requested, evaluate()
# returns [loss, accuracy].
preds = happyModel.evaluate(x=X_test, y=Y_test)
print("Loss = " + str(preds[0]))
print("Test Accuracy = " + str(preds[1]))

利用模型測試自己的圖像

# Classify a custom image with the trained model.
img_path = 'images/my_image.jpg'

# Load the image from disk, resized to 64x64, and display it.
img = image.load_img(img_path, target_size=(64, 64))
imshow(img)

# Image -> array, then prepend a batch axis because predict() expects a batch.
x = np.expand_dims(image.img_to_array(img), axis=0)

# NOTE(review): preprocess_input comes from keras' imagenet_utils; confirm it
# matches however the training arrays were normalized (not shown here).
x = preprocess_input(x)

print(happyModel.predict(x))

其他有用的函數

# Inspecting the model:
#   summary()    -- prints a table describing each layer.
#   plot_model() -- renders the layer graph into an image file.
happyModel.summary()
plot_model(happyModel, to_file='HappyModel.png')

# Render the same layer graph as SVG; it is displayed only when this is the
# final expression of a notebook cell.
dot_graph = model_to_dot(happyModel)
SVG(dot_graph.create(prog='dot', format='svg'))
發表評論
所有評論
還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並點擊發布。
相關文章