使用tensorflow 提取特徵圖(How to use tensorflow to extract feature map from checkpoint)

如何使用tensorflow 提取特徵圖 ?

》首先模型訓練結束後,我們會得到關於【3】【4】檢查點的四個文件:.data, .index, .meta, checkpoint;

其中.meta是圖結構,也就是神經網絡的結構,在訓練過程中圖結構水不不變的,保存一次

即可。實現:saver  = tf.train.Saver(), saver.save(less, ‘model-name’,write_meta_graph=False);

.data是模型權重,偏置,操作等數值。

.index是主要保存.data數據中對應名字。

即.index 與 .data構成了鍵值對。

checkpoint 保存是在訓練過程中所有中間節點上保存模型的名稱。第一行保存最後一次保存的模型的名稱。

》主要利用.data, .index, .meta文件提取特徵圖。基於【2】,過程如下

  1. 加載模型和權重
  2. 創建圖
  3. 加載測試數據
  4. 設置輸入,輸出張量。在設置的時候,需要知道對應的操作的名字,可以 通過以下代碼將操作名稱打印出來。
  for op in graph.get_operations():
            print(op.name)

—————————————————————————————————————————————————

代碼如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import os
import datetime
import SimpleITK as sitk
import numpy as np

image_path = "D:\KITS\Code\\vnet-tensorflow-master\extract_feature_map\TestPatch\image.nii.gz"
save_path = "D:\KITS\Code\\vnet-tensorflow-master\extract_feature_map\data_preprocess_vnet"
# select gpu devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # e.g. "0,1,2", "0,2"

# tensorflow app flags
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('model_path','D:\KITS\same_space\data_preprocess\modelVNet2\\tmp\ckpt\checkpoint-24336.meta',
    """Path to saved models""")

tf.app.flags.DEFINE_string('checkpoint_path','D:\KITS\same_space\data_preprocess\modelVNet2\\tmp\ckpt\checkpoint-24336',
    """Directory of saved checkpoints""")


def trucateImage(image_np, low_value=-79, high_value=304):
    image_np = image_np.clip(min=low_value, max=high_value)
    image_np = image_np - 101
    image_np = np.true_divide(image_np, 76.9)
    return image_np

def evaluate():
    """evaluate the vnet model by stepwise moving along the 3D image"""
    # restore model grpah
    tf.reset_default_graph()
    # 從.meta文件加載模型
    imported_meta = tf.train.import_meta_graph(FLAGS.model_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    graph = tf.get_default_graph()
    # set input
    input_ = graph.get_tensor_by_name('images_placeholder:0')
    print('input shape:',input_.shape)

    # set output
    output_ = graph.get_tensor_by_name('vnet/vnet/encoder/level_2/conv_2/result:0')

    with tf.Session(config=config) as sess:
        print("{}: Start evaluation...".format(datetime.datetime.now()))
        # 從checkpoint-24336(.index, .data)文件加載權重
        imported_meta.restore(sess, FLAGS.checkpoint_path)
        print("{}: Restore checkpoint success".format(datetime.datetime.now()))

        image = sitk.ReadImage(image_path)
        image_np = sitk.GetArrayFromImage(image).astype(np.float32)
        image_np = trucateImage(image_np)

        patch = np.expand_dims(image_np, axis=0)
        patch = np.expand_dims(patch, axis=-1)
        patch_pd = sess.run(output_,feed_dict={input_: patch})

        np.save(os.path.join(save_path, "encoder_l2_conv_2.npy"), patch_pd)
        print("Finish inferencing ",patch_pd.shape)
        # for op in graph.get_operations():
        #     print(op.name)

def main():
    evaluate()
if __name__=='__main__':
    main()

—————————————————————————————————————————————————

 

參考文獻:

【1】https://blog.csdn.net/qq_41185868/article/details/82903223

【2】https://murphypei.github.io/blog/2019/08/tensorflow-show-layer.html

【3】https://blog.csdn.net/u014090429/article/details/93487539

【4】https://www.cnblogs.com/azheng333/p/6972619.html

【5】https://machinelearningmastery.com/how-to-visualize-filters-and-feature-maps-in-convolutional-neural-networks/

 

—————————————————————————————————————————————————

 

sess.run(['predicted_label/prediction:0','softmax/softmax:0'], feed_dict={

'images_placeholder:0': batch, 

爲什麼傳入參數是這種形式:———:0 後面還跟着一個:0? 0 表示 batch 中的第一個,如果 batch 是 1 就是全部結果了

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章