6、MNIST數據分類(詳細函數代碼)

1、MNIST數據訓練要點

手寫數字識別:

01 像素:28*28=784

02 標籤:神經網絡對圖像進行分類,分配正確的標籤,這些標籤是0到9共10個數字中的一個,這意味着神經網絡有10個輸出層節點,每個節點對應一個可能得答案或標籤,如果答案是“0”,輸出吃呢個第一個節點激發而其餘的輸出節點則保持抑制狀態。

03 輸出:試圖讓神經網絡生成0和1的輸出,對於激活函數而言是不可能的,這回導致大的權重和飽和網絡。因此需要重新調整這些數據,這裏使用0.01和0.99來代替0和1.

注:其餘基礎知識,見之前更新的參數知識

2、詳細代碼(備有詳細註釋)

# -*- coding: utf-8 -*-

import numpy

import scipy

import scipy.special

class neuralNetwork:
    
    # initialise the neural network 初始化網絡
    def _init_(self,inputnodes,hiddennodes,outputnodes,learningrate):
        
        # set number of nodes in each input,hidden,output layer
        self.inodes=inputnodes
        self.hnodes=hiddennodes
        self.onodes=outputnodes
        
        # link weight matrices,with and who 初始化權重
        # w11,w12,w21,w22
        # 正態分佈中心設置爲0.0,pow(self.hnodes,-0.5)表示節點數目的0.5次方,最後一個參數是numpy數組的形狀大小
        
        self.wih=numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes))
        self.who=numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes))
        
        # learning rate 學習率
        self.lr=learningrate
        
        # 激活函數 activation function is the sigmoid function 
        # lambda 是定義函數的簡短形式
        
        self.activation_function =lambda x: scipy.special.expit(x)
        
        # train the network 訓練網絡
        
    def train(self,inputs_list,targets_list):
            
            # convert input list to 2d array
        inputs=numpy.array(inputs_list,ndmin=2).T
        targets=numpy.array(targets_list,ndmin=2).T
            
            # calculate signals into hidden layer  計算隱藏層輸入
        hidden_inputs=numpy.dot(self.wih,inputs)
            
            # calculate the signals emerging from hidden layer 計算隱藏層輸出信號
        hidden_outputs=self.activation_function(hidden_inputs)
            # 計算輸出的
        
        final_inputs=numpy.dot(self.who,hidden_outputs)
            
        final_outputs=self.activation_function(final_inputs)
            
            # output layer error is the (target-actual) 計算誤差
        output_errors=targets-final_outputs
            # hidden layer error is the output_error,split by weights.. 隱藏層誤差
        hidden_errors=numpy.dot(self.who.T,output_errors)
            
            # 在隱藏層和輸出層之間更新權重
            
        self.who += self.lr * numpy.dot((output_errors * final_outputs*(1.0-final_outputs)),numpy.transpose(hidden_outputs))
            
        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), numpy.transpose(inputs))
            
        pass
            
        
        # query the neural network 查詢網絡
        
    def query(self,inputs_list):
            
            # convert inputs list tp 2d array 把列表編程數組
        inputs=numpy.array(inputs_list,ndmin=2).T
           
            # calculate signals into hidden layer  計算隱藏層輸入
        hidden_inputs=numpy.dot(self.wih,inputs)
            
            # calculate the signals emerging from hidden layer 計算隱藏層輸出信號
        hidden_outputs=self.activation_function(hidden_inputs)
            
            # 計算輸出的
        
        final_inputs=numpy.dot(self.who,hidden_outputs)
            
        final_outputs=self.activation_function(final_inputs)
            
        return final_outputs
    
# 輸入相關參數進行計算

input_nodes = 784
hidden_nodes = 100
output_nodes = 10

learning_rate = 0.3

# create instance of neural network

n = neuralNetwork()

n._init_(784,100,10,0.3)
# load the mnist trai

training_data_file = open("D:/DATA/pycase/number2/project/pretice/mnist_train_100.csv", 'r')
training_data_list = training_data_file.readlines()
training_data_file.close()

for record in training_data_list:

    all_values = record.split(',') # split the record by the ',' commas
# scale and shift the inputs
    inputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01

 # create the target output values (all 0.01, except the desiredlabel which is 0.99)

    targets = numpy.zeros(output_nodes) + 0.01 # all_values[0] is the target label for this record

    targets[int(all_values[0])] = 0.99
    
    n.train(inputs, targets)
    
    pass

# load the mnist test data CSV file into a list
test_data_file = open("D:/DATA/pycase/number2/project/pretice/mnist_test_10.csv", 'r')
test_data_list = test_data_file.readlines()
test_data_file.close()


all_values=test_data_list[0].split(',')

print(all_values[0])

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章