用圖片數據集訓練神經網絡 tensorflow

        爲了學tensorflow,網上教程看了不少,大部分是利用mnist數據集,但是大部分都是利用已經處理好的非圖片形式進行訓練的;而且很多人都說不要自己造輪子,直接跑別人的代碼,然後修改和學習,不過我覺得其實從數據集,到訓練,到測試寫下來,也有不少收穫的。

        本篇博客主要講的是如何將自己的圖片數據集進行處理,然後搭建神經網絡結構,訓練數據,保存和加載模型,測試等過程,其中利用mnist圖片數據集(mnist),代碼(Github)。下面我分步講解一下,跟大家一起學習。參考博客附後。

         如前文所述,現在的入門教程基本是跑mnist代碼,但是數據集是處理後的,那麼如何處理自己的圖像數據集?首先要將圖片數據集製作成tfrecords格式的數據。簡單的說,tfrecords是一種tensorflow方便快速讀取數據的一種二進制文件,適用於大量數據的處理。下面的代碼將說明如何將圖像製作成tfrecords文件,該mnist圖片數據集大小爲(28,28,3)。


# current work dir
cwd = os.getcwd()

# data to int64List
def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))
# data to floatlist
def _float_feature(value):
    return tf.train.Feature(float_list = tf.train.FloatList(value=[value]))
# data to byteslist
def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))

# convert image data to tfrecords
def generate_tfrecords(data_dir,filepath):
    # gen a tfrecords object write
    write = tf.python_io.TFRecordWriter(filepath)
    #print cwd
    for index,name in enumerate(num_classes):
	#print class_path
	file_dir = data_dir + name + '/'
	for img_name in os.listdir(file_dir):
	    img_path = file_dir + img_name
	    print img_path
	    img = cv2.imread(img_path)
	    img_raw = img.tobytes()
	    example = tf.train.Example(features=tf.train.Features(feature={
					'label':_int64_feature(index),
					'img_raw':_bytes_feature(img_raw)}))
	    # convert example to binary string
	    write.write(example.SerializeToString())
    write.close()


        tf.python_io.TFRecordWriter返回一個writer對象用於將data_dir製作後的數據存入filepath中保存,該文件就是tfrecords文件。另外,tf.train.Example將數據處理成key-value(在這裏就是標籤-圖像)的格式返回一個example對象。最後writer將數據寫到filepath中,關閉writer,就完成了圖片數據到二進制文件的製作過程。
       製作完成之後,在神經網絡中如何讀取和解析呢?如下代碼

# read and decode tfrecord
def read_and_decode_tfrecord(filename):
    # produce file deque
    filename_deque = tf.train.string_input_producer([filename])
    # generate reader object
    reader = tf.TFRecordReader()
    # read data from filename_deque
    _, serialized_example = reader.read(filename_deque)
    # decode into fixed form
    features = tf.parse_single_example(serialized_example,features={
					'label':tf.FixedLenFeature([],tf.int64),
					'img_raw':tf.FixedLenFeature([],tf.string)})

    label = tf.cast(features['label'],tf.int32)
    img = tf.decode_raw(features['img_raw'],tf.uint8)
    img = tf.reshape(img,[28,28,3])
    img = tf.cast(img,tf.float32)/255.-0.5
    return label,img

        其中,tf.train.string_input_producer([filename])是將filename的文件內容製作成一個隊列,然後tf.parse_single_example按照固定的格式將內容解析出來,稍加處理即可得到label和img,當然[filename]中可以有很多file,因爲當圖片數據太大時可能會將數據分成好幾個部分分別製作tfrecords進行存儲和讀取。
       然後搭建神經網絡,這裏就搭建一個簡單點的,如下

# create network 
class network(object):
    # define parameters w and b
    def __init__(self):
	with tf.variable_scope("Weight"):
	   self.weights={
		'conv1':tf.get_variable('conv1',[5,5,3,32],initializer=tf.contrib.layers.xavier_initializer_conv2d()),
		'conv2':tf.get_variable('conv2',[5,5,32,64],initializer=tf.contrib.layers.xavier_initializer_conv2d()),
		'fc1'  :tf.get_variable('fc1',  [7*7*64,1024],initializer=tf.contrib.layers.xavier_initializer()),
		'fc2'  :tf.get_variable('fc2',  [1024,10],     initializer=tf.contrib.layers.xavier_initializer()),}
	with tf.variable_scope("biases"):
	    self.biases={
		'conv1':tf.get_variable('conv1',[32,],initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),
		'conv2':tf.get_variable('conv2',[64,],initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),
		'fc1'  :tf.get_variable('fc1',  [1024,],initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),
		'fc2'  :tf.get_variable('fc2',  [10,] ,initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),}

    # define model
    def model(self,img):
	conv1 = tf.nn.bias_add(tf.nn.conv2d(img,self.weights['conv1'],strides=[1,1,1,1],padding='SAME'),self.biases['conv1'])
	relu1 = tf.nn.relu(conv1)
	pool1 = tf.nn.max_pool(relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

	conv2 = tf.nn.bias_add(tf.nn.conv2d(pool1,self.weights['conv2'],strides=[1,1,1,1],padding='SAME'),self.biases['conv2'])
	relu2 = tf.nn.relu(conv2)
	pool2 = tf.nn.max_pool(relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

	flatten = tf.reshape(pool2,[-1,self.weights['fc1'].get_shape().as_list()[0]])
	
	drop1 = tf.nn.dropout(flatten,0.8)
	fc1   = tf.matmul(drop1,self.weights['fc1']) + self.biases['fc1']
	fc_relu1 = tf.nn.relu(fc1)
	fc2   = tf.matmul(fc_relu1,self.weights['fc2'])+self.biases['fc2']

	return fc2

    # define model test
    def test(self,img):
	img = tf.reshape(img,shape=[-1,28,28,3])

	conv1 = tf.nn.bias_add(tf.nn.conv2d(img,self.weights['conv1'],strides=[1,1,1,1],padding='SAME'),self.biases['conv1'])
	relu1 = tf.nn.relu(conv1)
	pool1 = tf.nn.max_pool(relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

	conv2 = tf.nn.bias_add(tf.nn.conv2d(pool1,self.weights['conv2'],strides=[1,1,1,1],padding='SAME'),self.biases['conv2'])
	relu2 = tf.nn.relu(conv2)
	pool2 = tf.nn.max_pool(relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

	flatten = tf.reshape(pool2,[-1,self.weights['fc1'].get_shape().as_list()[0]])

	drop1 = tf.nn.dropout(flatten,1)
	fc1   = tf.matmul(drop1,self.weights['fc1']) + self.biases['fc1']
	fc_relu1 = tf.nn.relu(fc1)
	fc2   = tf.matmul(fc_relu1,self.weights['fc2'])+self.biases['fc2']

	return fc2

    #loss
    def softmax_loss(self,predicts,labels):
	predicts = tf.nn.softmax(predicts)
	labels   = tf.one_hot(labels,len(num_classes))
	loss = -tf.reduce_mean(labels*tf.log(predicts))
	self.cost = loss
	return self.cost

    # optimizer
    def optimizer(self,loss,lr=0.001):
	train_optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)
	return train_optimizer

        tf.contrib.layers.xavier_initializer_conv2d()是對參數進行初始化的,據某位童鞋的博客說,當激活函數是sigmoid或tanh時,這個初始化方法比較好,但是當激活函數是relu時,使用tf.contrib.layers.variance_scaling_initializer比較好,具體我沒有嘗試過,大家可以試一試,另外tf.contrib.layers.xavier_initializer()也是一種權值初始化方式。而在神經網絡中,權值的初始化非常重要,可以按照某種特定的分佈來初始化,以後可以嘗試使用其他初始化方式從而加快收斂速度和準確率。
        dropout層是一種解決過擬合的方法,它是在訓練的過程中對網絡中的某些神經單元按照一定概率屏蔽,此時的概率選擇爲0.8,但是在測試時可以不用dropout層,或者將概率設置爲1,即使用所有的神經單元,這樣應該能提高正確率。
        模型搭建完畢就可以開始訓練了,代碼如下

def train():
    label, img = read_and_decode_tfrecord(train_tfrecords_dir)
    img_batch,label_batch = tf.train.shuffle_batch([img,label],num_threads=16,batch_size=batch_size,capacity=50000,min_after_dequeue=49000)

    net = network()
    predicts = net.model(img_batch)
    loss = net.softmax_loss(predicts,label_batch)
    opti = net.optimizer(loss)
    # add trace
    tf.summary.scalar('cost fuction',loss)
    merged_summary_op = tf.summary.merge_all()


    train_correct = tf.equal(tf.cast(tf.argmax(predicts,1),tf.int32),label_batch)
    train_accuracy = tf.reduce_mean(tf.cast(train_correct,tf.float32))

    #evaluate
    test_label,test_img = read_and_decode_tfrecord(test_tfrecords_dir)
    test_img_batch,test_label_batch = tf.train.shuffle_batch([test_img,test_label],num_threads=16,batch_size=batch_size,capacity=50000,min_after_dequeue=40000)
    test_out = net.test(test_img_batch)
    
    test_correct = tf.equal(tf.cast(tf.argmax(test_out,1),tf.int32),test_label_batch)
    test_accuracy = tf.reduce_mean(tf.cast(test_correct,tf.float32))

    # init varibels
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
	sess.run(init)
	# manerge different threads
	coord = tf.train.Coordinator()
        summary_writer = tf.summary.FileWriter('log',sess.graph)
    	# run deque
    	threads = tf.train.start_queue_runners(sess=sess,coord=coord)
	path =cwd+'/'+'model/model.ckpt'
	#if os.path.exists(path):
	try:
	    print "try to reload model ......"
	    tf.train.Saver(max_to_keep=None).restore(sess,path)
	    print 'reload successful ......'
	except:
	    print 'reload model failed ......'
	finally:
	    print 'training .......'
    	for i in range(1,epoch+1):
	    #val,l= sess.run([img_batch,label_batch])
	    if i%50 ==0:
	        loss_np,_,label_np,img_np,predict_np = sess.run([loss,opti,label_batch,img_batch,predicts])
		tr_accuracy_np = sess.run([train_accuracy])
		print i,' epoch loss :',loss_np,'    train accuracy: ', tr_accuracy_np
	    if i%200==0:
		summary_str,_l,_o = sess.run([merged_summary_op,loss,opti])
		summary_writer.add_summary(summary_str,i)
		te_accuracy = sess.run([test_accuracy])
		print 'test accuracy: ', te_accuracy
	    if i%1000==0:
	        tf.train.Saver(max_to_keep=None).save(sess,os.path.join('model','model.ckpt'))
	# somethind happend that the thread should stop
	coord.request_stop()
	# wait for all threads should stop and then stop
	coord.join(threads)

        tf.train.shuffle_batch是將隊列裏的數據打亂順序使用n_threads個線程,batch_size大小的形式讀取出來,capacity是整個隊列的容量,min_after_deque代表參與順序打亂的程度,參數越大代表數據越混亂。在本代碼中,由於各個類別已經分好,大概都是5000張,而在製作tfrecords的時候是按順序存儲的,所以使用tf.train.shuffle_batch來打亂順序,但是如果batch_size設置太小,那很大概率上每個batch_size的圖像數據的類別都是一樣的,造成過擬合,所以本次將batch_size設置成2000,這樣效果比較明顯,設置成1000也可以,或者在處理數據的時候提前將數據打亂,或者有其他方法歡迎下方討論。

tf.summary.scalar('cost fuction',loss) 
merged_summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter('log',sess.graph)
summary_str,_l,_o = sess.run([merged_summary_op,loss,opti])
summary_writer.add_summary(summary_str,i)


         以上幾行代碼的搭配是將loss的變化存儲在log文件中,當訓練完成之後使用以下命令進行圖形化展示,有些情況下直接使用tensorboard命令不一定管用,此時可以找到tensorboard.py目錄,使用第二種命令進行展示.最後打開瀏覽器,訪問終端提示的網址即可。
         tensorboard --logdir=log
         python 目錄/tensorboard.py --logdir=log


         關於evaluate部分,在訓練過程中可以使用一部分數據集來驗證模型的準確率,本程序將驗證集合測試集視爲相同。

         coord = tf.train.Coordinator()#創建一個協調器,用於管理線程,發生錯誤時及時關閉線程
         threads = tf.train.start_queue_runners(sess=sess,coord=coord)#各個線程開始讀取數據,這一句如果沒有,整個網絡將被掛起
         coord.request_stop()#某個線程數據讀取完或發生錯誤請求停止
         coord.join(threads)#所有線程都請求停止後關閉線程
         以上幾行代碼的搭配是線程的開啓和關閉過程,後面兩句如果不存在,當讀取過程出現某些錯誤(如Outofrange)時,程序將不會正常關閉等,詳細情況大家可以查閱一下其他資料。

          tf.train.Saver(max_to_keep=None).save(sess,os.path.join('model','model.ckpt'))
          tf.train.Saver(max_to_keep=None).restore(sess,path)
          以上兩句是模型的保存和恢復,max_to_keep=None這個參數是保存最新的或者加載最新的模型。

          accuracy_np = sess.run([accuracy])
          以上是關於輸出,如果希望得到某個輸出A,那麼只要使用A_out = sess.run([A])即可
 
參考博客:1. http://blog.csdn.net/hjimce/article/details/51899683

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章