Keras深度學習:卷積神經網絡——手寫數字識別

引言:最近在閉關學習中,由於多久沒有寫博客了,今天給大家帶來學習的一些內容,還在學習神經網絡的同學,跑一跑下面的代碼,給你一些自信吧!Nice 奧裏給!

正文:首先該impor的庫就不多說了,不會的就pip install something  that you got it

備註:mnist.npz資源姐姐曉的你沒有,來躺好了:https://s3.amazonaws.com/img-datasets/mnist.npz 自己VPN哈

讀者也可以在下方代碼片裏面直接下載:mnist.load_data(data)=mnist.load_data()

# Practice mnist
from keras.datasets import mnist
import  numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.convolutional import MaxPooling2D
from keras.layers.convolutional import Conv2D
from keras.utils import np_utils
from keras import  backend
backend.set_image_data_format('channels_first')
data ='E:\\KerasRain\\resoucre\\mnist.npz'
#set random seed
seed=7
np.random.seed(seed)
# import MNIST dataset from Keras
(X_train,y_train),(X_validation,y_validation)=mnist.load_data(data)
X_train=X_train.reshape(X_train.shape[0],1,28,28).astype('float32')
X_validation=X_validation.reshape(X_validation.shape[0],1,28,28).astype('float32')
#Normalized to 0-1
X_train=X_train/255
X_validation=X_validation/255
#make one-hot code
y_train=np_utils.to_categorical(y_train)
y_validation=np_utils.to_categorical(y_validation)
#creat model
def create_model():
    # initialize model
    model=Sequential()
    # define input layer (1x28x28)
    # define Convolutional layer 32 maps, 5x5
    model.add(Conv2D(32,(5,5),input_shape=(1,28,28),activation='relu'))
    # define Pooling layer (2x2)
    model.add(MaxPooling2D(pool_size=(2,2)))
    # define Dropout layer 20%
    model.add(Dropout(0.2))
    # define Flattem layer
    model.add(Flatten())
    # define Fully connected layer 128
    model.add(Dense(units=128,activation='relu'))
    # define output layer 10
    model.add(Dense(units=10,activation='softmax'))
    # compile model
    model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
    return  model
model=create_model()
model.fit(X_train,y_train,epochs=10,batch_size=200,verbose=1)
score=model.evaluate(X_validation,y_validation,verbose=0)
print('accuracy: %.2f%%'%(score[1]*100))

 

技術羣:1090519856

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章