使用tensorflow 提取特征图(How to use tensorflow to extract feature map from checkpoint)

如何使用tensorflow 提取特征图 ?

》首先模型训练结束后,我们会得到关于【3】【4】检查点的四个文件:.data, .index, .meta, checkpoint;

其中.meta是图结构,也就是神经网络的结构,在训练过程中图结构水不不变的,保存一次

即可。实现:saver  = tf.train.Saver(), saver.save(less, ‘model-name’,write_meta_graph=False);

.data是模型权重,偏置,操作等数值。

.index是主要保存.data数据中对应名字。

即.index 与 .data构成了键值对。

checkpoint 保存是在训练过程中所有中间节点上保存模型的名称。第一行保存最后一次保存的模型的名称。

》主要利用.data, .index, .meta文件提取特征图。基于【2】,过程如下

  1. 加载模型和权重
  2. 创建图
  3. 加载测试数据
  4. 设置输入,输出张量。在设置的时候,需要知道对应的操作的名字,可以 通过以下代码将操作名称打印出来。
  for op in graph.get_operations():
            print(op.name)

—————————————————————————————————————————————————

代码如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import os
import datetime
import SimpleITK as sitk
import numpy as np

image_path = "D:\KITS\Code\\vnet-tensorflow-master\extract_feature_map\TestPatch\image.nii.gz"
save_path = "D:\KITS\Code\\vnet-tensorflow-master\extract_feature_map\data_preprocess_vnet"
# select gpu devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # e.g. "0,1,2", "0,2"

# tensorflow app flags
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('model_path','D:\KITS\same_space\data_preprocess\modelVNet2\\tmp\ckpt\checkpoint-24336.meta',
    """Path to saved models""")

tf.app.flags.DEFINE_string('checkpoint_path','D:\KITS\same_space\data_preprocess\modelVNet2\\tmp\ckpt\checkpoint-24336',
    """Directory of saved checkpoints""")


def trucateImage(image_np, low_value=-79, high_value=304):
    image_np = image_np.clip(min=low_value, max=high_value)
    image_np = image_np - 101
    image_np = np.true_divide(image_np, 76.9)
    return image_np

def evaluate():
    """evaluate the vnet model by stepwise moving along the 3D image"""
    # restore model grpah
    tf.reset_default_graph()
    # 从.meta文件加载模型
    imported_meta = tf.train.import_meta_graph(FLAGS.model_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    graph = tf.get_default_graph()
    # set input
    input_ = graph.get_tensor_by_name('images_placeholder:0')
    print('input shape:',input_.shape)

    # set output
    output_ = graph.get_tensor_by_name('vnet/vnet/encoder/level_2/conv_2/result:0')

    with tf.Session(config=config) as sess:
        print("{}: Start evaluation...".format(datetime.datetime.now()))
        # 从checkpoint-24336(.index, .data)文件加载权重
        imported_meta.restore(sess, FLAGS.checkpoint_path)
        print("{}: Restore checkpoint success".format(datetime.datetime.now()))

        image = sitk.ReadImage(image_path)
        image_np = sitk.GetArrayFromImage(image).astype(np.float32)
        image_np = trucateImage(image_np)

        patch = np.expand_dims(image_np, axis=0)
        patch = np.expand_dims(patch, axis=-1)
        patch_pd = sess.run(output_,feed_dict={input_: patch})

        np.save(os.path.join(save_path, "encoder_l2_conv_2.npy"), patch_pd)
        print("Finish inferencing ",patch_pd.shape)
        # for op in graph.get_operations():
        #     print(op.name)

def main():
    evaluate()
if __name__=='__main__':
    main()

—————————————————————————————————————————————————

 

参考文献:

【1】https://blog.csdn.net/qq_41185868/article/details/82903223

【2】https://murphypei.github.io/blog/2019/08/tensorflow-show-layer.html

【3】https://blog.csdn.net/u014090429/article/details/93487539

【4】https://www.cnblogs.com/azheng333/p/6972619.html

【5】https://machinelearningmastery.com/how-to-visualize-filters-and-feature-maps-in-convolutional-neural-networks/

 

—————————————————————————————————————————————————

 

sess.run(['predicted_label/prediction:0','softmax/softmax:0'], feed_dict={

'images_placeholder:0': batch, 

为什么传入参数是这种形式:———:0 后面还跟着一个:0? 0 表示 batch 中的第一个,如果 batch 是 1 就是全部结果了

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章