TensorFlow 使用google inception_v3進行圖片訓練例程測試(2019_07)

1.起始,因爲TensorFlow優化模型使用方法,引入了tensorflow hub,讓使用更簡單,但導致後果就是以前的教程基本不能使用,官方例程因爲hub的模型基本是使用URL模式加載,剛好上不去,這就尷尬了,需要自己慢慢摸索,對萌新及其不友好,本篇主要記錄本次調試過程,方便後人繞坑。本次使用google 已經訓練好的模型inception_v3,然後對最後一層進行重新訓練,以滿足我們需要的分類要求

2.環境要求

安裝TensorFlow,CPU和GPU版本都可以,GPU比較快而已,CPU直接使用最新的即可,目前最新的GPU版本對應的cuda_10.0 和 tensorflow-gpu 1.14.0,使用cuda_10.1會出現TF無法使用cuda的問題,懷疑是Anaconda沒有及時同步導致。

3.準備工作

(1)前往https://github.com/tensorflow/tensorflow ,下載對應的TensorFlow源碼。

(2)安裝hub, pip install tensorflow-hub ,並前往https://github.com/tensorflow/hub ,下載對應的tensorflow-hub源碼

(3)準備好自己需要分類的圖片,按類型劃分好文件名字,我這裏使用的是官方提供的數據集 flower_photos,需要的自己去下載,不用科學上網

(4).在下載下來的hub源碼中找到hub-master\examples\image_retraining文件夾,運行retrain.py,開始訓練。不能科學上網的會在這裏被卡住,我這裏提供一個野生方法,本地化模型,更改模型爲本地加載,參考連接https://zhuanlan.zhihu.com/p/64069911。

下載模型文件,示例如下:

模型路徑:https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3

下載模型路徑:https://storage.googleapis.com/tfhub-modules/google/imagenet/inception_v3/feature_vector/3.tar.gz

下載後模型需要解壓纔可以正常使用,然後,運行腳本開始訓練

python H:\tf_py\hub-master\examples\image_retraining\retrain.py ^
--image_dir H:\tf_py\image_retrain\flower_photos\flower_photos ^
--tfhub_module H:\tf_py\image_retrain\inception\3 ^
--saved_model_dir H:\tf_py\image_retrain\inception\4 
pause

4.檢測訓練好的模型

因爲我使用的是鮮花( 玫瑰 鬱金香 向日葵 雛菊 蒲公)的分類訓練,所以我去百度下了很多這種類型的圖片進行測試。

不幸的是TensorFlow上面的測試例程,因爲移植等問題,已經對不上這個模型的測試例程了,於是我自己碼了一個心塞。

示例代碼:

import tensorflow as tf
import tensorflow_hub as hub
import os
import re
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from tensorflow.keras import layers


saved_model_dir = 'image_retrain/inception/4'

label_lookup_path = 'image_retrain/output_labels.txt' 
image_path = 'image_retrain/test/'

class NodeLookup(object):
    def __init__(self):
        self.node_lookup = self.load(label_lookup_path)
        
    def load(self,label_lookup_path):
        proto_as_ascii_lines = tf.gfile.GFile(label_lookup_path).readlines()
        node_id_to_name = {}
        #一行一行讀取數據
        for uid,line in enumerate(proto_as_ascii_lines):
            #去掉換行符
            line = line.strip('\n')
            node_id_to_name[uid] = line
        return node_id_to_name
    
    #傳入分類編號1-1000返回分類名稱
    def id_to_string(self,node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]

with tf.Session() as sess:
# 如果不知道模型具體信息 可以使用saved_model_cli.py 查看該模型的輸入 輸出數據格式要求以及關鍵的Signature簽名
#python H:\tf_py\tensorflow-master\tensorflow\python\tools\saved_model_cli.py show --dir  ....\mode\ --all
    meta_graph_def = tf.saved_model.loader.load(sess,["serve"], saved_model_dir)
    graph = tf.get_default_graph()
    oputs = sess.graph.get_tensor_by_name('final_result:0')
    input_image = sess.graph.get_tensor_by_name('Placeholder:0')
    
        #遍歷目錄
    for root,dirs,files in os.walk(image_path):
        for file in files:
            #載入圖片            
            image_data = Image.open(os.path.join(root,file)).resize([299,299])
            image_data_array = np.array(image_data)/255.0
            image_data_shape = np.reshape(image_data_array,[299,299,3])
            
            #傳入圖片不能是tensor類型 這裏使用np轉化成矩陣數組格式
            #原因出在tf.reshape(),因爲網絡訓練時用placeholder定義了輸入格式,所以輸入不能用tensor,
            #而tf.reshape()返回結果就是一個tensor了,所以輸入會報錯。
            predictions = sess.run(oputs,{input_image:[image_data_shape]})
            predictions = np.squeeze(predictions) #轉化爲一維數據
            
            image_path = os.path.join(root,file)
            print(image_path)
            plt.imshow(image_data_shape)
            plt.axis('off')
            plt.show()

            #排序 取概率最大的5個值 然後倒序
            top_k = predictions.argsort()[-5:][::-1]
            node_lookup = NodeLookup()
            for node_id in top_k:
                #獲取分類名稱
                human_string = node_lookup.id_to_string(node_id)
                #獲取分類置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' %(human_string,score))
            print()



運行結果:

 

附錄一下錯誤調試,

GPU的童鞋需要注意,訓練的時候很容易出現cudnn錯誤,解決方法如下:

1.cudnn創建錯誤,環境沒錯的話就是顯卡內存出錯了,修改爲按需分配
Problem:Could not create cudnn handle: CUDNN_STATUS_ALLOC_FAILED

      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True
      session = tf.Session(config=config)

2.如果不清楚的模型的輸入輸出,使用saved_model_cli.py 可以解決很多問題,我被這個輸入數據卡了3天,才找到這個解決方案。

3.訓練和測試時很容易出現莫名其妙的錯誤,這個時候最好重啓一下python服務,或者刪除緩存文件,否則你會崩潰的

4.如果能科學上網,儘量科學上網把,太折騰人了

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章