TensorFlow 使用google inception_v3进行图片训练例程测试(2019_07)

1.起始,因为TensorFlow优化模型使用方法,引入了tensorflow hub,让使用更简单,但导致后果就是以前的教程基本不能使用,官方例程因为hub的模型基本是使用URL模式加载,刚好上不去,这就尴尬了,需要自己慢慢摸索,对萌新及其不友好,本篇主要记录本次调试过程,方便后人绕坑。本次使用google 已经训练好的模型inception_v3,然后对最后一层进行重新训练,以满足我们需要的分类要求

2.环境要求

安装TensorFlow,CPU和GPU版本都可以,GPU比较快而已,CPU直接使用最新的即可,目前最新的GPU版本对应的cuda_10.0 和 tensorflow-gpu 1.14.0,使用cuda_10.1会出现TF无法使用cuda的问题,怀疑是Anaconda没有及时同步导致。

3.准备工作

(1)前往https://github.com/tensorflow/tensorflow ,下载对应的TensorFlow源码。

(2)安装hub, pip install tensorflow-hub ,并前往https://github.com/tensorflow/hub ,下载对应的tensorflow-hub源码

(3)准备好自己需要分类的图片,按类型划分好文件名字,我这里使用的是官方提供的数据集 flower_photos,需要的自己去下载,不用科学上网

(4).在下载下来的hub源码中找到hub-master\examples\image_retraining文件夹,运行retrain.py,开始训练。不能科学上网的会在这里被卡住,我这里提供一个野生方法,本地化模型,更改模型为本地加载,参考连接https://zhuanlan.zhihu.com/p/64069911。

下载模型文件,示例如下:

模型路径:https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3

下载模型路径:https://storage.googleapis.com/tfhub-modules/google/imagenet/inception_v3/feature_vector/3.tar.gz

下载后模型需要解压才可以正常使用,然后,运行脚本开始训练

python H:\tf_py\hub-master\examples\image_retraining\retrain.py ^
--image_dir H:\tf_py\image_retrain\flower_photos\flower_photos ^
--tfhub_module H:\tf_py\image_retrain\inception\3 ^
--saved_model_dir H:\tf_py\image_retrain\inception\4 
pause

4.检测训练好的模型

因为我使用的是鲜花( 玫瑰 郁金香 向日葵 雏菊 蒲公)的分类训练,所以我去百度下了很多这种类型的图片进行测试。

不幸的是TensorFlow上面的测试例程,因为移植等问题,已经对不上这个模型的测试例程了,于是我自己码了一个心塞。

示例代码:

import tensorflow as tf
import tensorflow_hub as hub
import os
import re
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from tensorflow.keras import layers


saved_model_dir = 'image_retrain/inception/4'

label_lookup_path = 'image_retrain/output_labels.txt' 
image_path = 'image_retrain/test/'

class NodeLookup(object):
    def __init__(self):
        self.node_lookup = self.load(label_lookup_path)
        
    def load(self,label_lookup_path):
        proto_as_ascii_lines = tf.gfile.GFile(label_lookup_path).readlines()
        node_id_to_name = {}
        #一行一行读取数据
        for uid,line in enumerate(proto_as_ascii_lines):
            #去掉换行符
            line = line.strip('\n')
            node_id_to_name[uid] = line
        return node_id_to_name
    
    #传入分类编号1-1000返回分类名称
    def id_to_string(self,node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]

with tf.Session() as sess:
# 如果不知道模型具体信息 可以使用saved_model_cli.py 查看该模型的输入 输出数据格式要求以及关键的Signature签名
#python H:\tf_py\tensorflow-master\tensorflow\python\tools\saved_model_cli.py show --dir  ....\mode\ --all
    meta_graph_def = tf.saved_model.loader.load(sess,["serve"], saved_model_dir)
    graph = tf.get_default_graph()
    oputs = sess.graph.get_tensor_by_name('final_result:0')
    input_image = sess.graph.get_tensor_by_name('Placeholder:0')
    
        #遍历目录
    for root,dirs,files in os.walk(image_path):
        for file in files:
            #载入图片            
            image_data = Image.open(os.path.join(root,file)).resize([299,299])
            image_data_array = np.array(image_data)/255.0
            image_data_shape = np.reshape(image_data_array,[299,299,3])
            
            #传入图片不能是tensor类型 这里使用np转化成矩阵数组格式
            #原因出在tf.reshape(),因为网络训练时用placeholder定义了输入格式,所以输入不能用tensor,
            #而tf.reshape()返回结果就是一个tensor了,所以输入会报错。
            predictions = sess.run(oputs,{input_image:[image_data_shape]})
            predictions = np.squeeze(predictions) #转化为一维数据
            
            image_path = os.path.join(root,file)
            print(image_path)
            plt.imshow(image_data_shape)
            plt.axis('off')
            plt.show()

            #排序 取概率最大的5个值 然后倒序
            top_k = predictions.argsort()[-5:][::-1]
            node_lookup = NodeLookup()
            for node_id in top_k:
                #获取分类名称
                human_string = node_lookup.id_to_string(node_id)
                #获取分类置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' %(human_string,score))
            print()



运行结果:

 

附录一下错误调试,

GPU的童鞋需要注意,训练的时候很容易出现cudnn错误,解决方法如下:

1.cudnn创建错误,环境没错的话就是显卡内存出错了,修改为按需分配
Problem:Could not create cudnn handle: CUDNN_STATUS_ALLOC_FAILED

      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True
      session = tf.Session(config=config)

2.如果不清楚的模型的输入输出,使用saved_model_cli.py 可以解决很多问题,我被这个输入数据卡了3天,才找到这个解决方案。

3.训练和测试时很容易出现莫名其妙的错误,这个时候最好重启一下python服务,或者删除缓存文件,否则你会崩溃的

4.如果能科学上网,尽量科学上网把,太折腾人了

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章