如何使用tensorflow 提取特徵圖 ?
》首先模型訓練結束後,我們會得到關於【3】【4】檢查點的四個文件:.data, .index, .meta, checkpoint;
其中.meta是圖結構,也就是神經網絡的結構,在訓練過程中圖結構水不不變的,保存一次
即可。實現:saver = tf.train.Saver(), saver.save(less, ‘model-name’,write_meta_graph=False);
.data是模型權重,偏置,操作等數值。
.index是主要保存.data數據中對應名字。
即.index 與 .data構成了鍵值對。
checkpoint 保存是在訓練過程中所有中間節點上保存模型的名稱。第一行保存最後一次保存的模型的名稱。
》主要利用.data, .index, .meta文件提取特徵圖。基於【2】,過程如下
- 加載模型和權重
- 創建圖
- 加載測試數據
- 設置輸入,輸出張量。在設置的時候,需要知道對應的操作的名字,可以 通過以下代碼將操作名稱打印出來。
for op in graph.get_operations():
print(op.name)
—————————————————————————————————————————————————
代碼如下:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import os
import datetime
import SimpleITK as sitk
import numpy as np
image_path = "D:\KITS\Code\\vnet-tensorflow-master\extract_feature_map\TestPatch\image.nii.gz"
save_path = "D:\KITS\Code\\vnet-tensorflow-master\extract_feature_map\data_preprocess_vnet"
# select gpu devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # e.g. "0,1,2", "0,2"
# tensorflow app flags
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_path','D:\KITS\same_space\data_preprocess\modelVNet2\\tmp\ckpt\checkpoint-24336.meta',
"""Path to saved models""")
tf.app.flags.DEFINE_string('checkpoint_path','D:\KITS\same_space\data_preprocess\modelVNet2\\tmp\ckpt\checkpoint-24336',
"""Directory of saved checkpoints""")
def trucateImage(image_np, low_value=-79, high_value=304):
image_np = image_np.clip(min=low_value, max=high_value)
image_np = image_np - 101
image_np = np.true_divide(image_np, 76.9)
return image_np
def evaluate():
"""evaluate the vnet model by stepwise moving along the 3D image"""
# restore model grpah
tf.reset_default_graph()
# 從.meta文件加載模型
imported_meta = tf.train.import_meta_graph(FLAGS.model_path)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
graph = tf.get_default_graph()
# set input
input_ = graph.get_tensor_by_name('images_placeholder:0')
print('input shape:',input_.shape)
# set output
output_ = graph.get_tensor_by_name('vnet/vnet/encoder/level_2/conv_2/result:0')
with tf.Session(config=config) as sess:
print("{}: Start evaluation...".format(datetime.datetime.now()))
# 從checkpoint-24336(.index, .data)文件加載權重
imported_meta.restore(sess, FLAGS.checkpoint_path)
print("{}: Restore checkpoint success".format(datetime.datetime.now()))
image = sitk.ReadImage(image_path)
image_np = sitk.GetArrayFromImage(image).astype(np.float32)
image_np = trucateImage(image_np)
patch = np.expand_dims(image_np, axis=0)
patch = np.expand_dims(patch, axis=-1)
patch_pd = sess.run(output_,feed_dict={input_: patch})
np.save(os.path.join(save_path, "encoder_l2_conv_2.npy"), patch_pd)
print("Finish inferencing ",patch_pd.shape)
# for op in graph.get_operations():
# print(op.name)
def main():
evaluate()
if __name__=='__main__':
main()
—————————————————————————————————————————————————
參考文獻:
【1】https://blog.csdn.net/qq_41185868/article/details/82903223
【2】https://murphypei.github.io/blog/2019/08/tensorflow-show-layer.html
【3】https://blog.csdn.net/u014090429/article/details/93487539
【4】https://www.cnblogs.com/azheng333/p/6972619.html
—————————————————————————————————————————————————
sess.run(['predicted_label/prediction:0','softmax/softmax:0'], feed_dict={
'images_placeholder:0': batch,
爲什麼傳入參數是這種形式:———:0, 後面還跟着一個:0? 0 表示 batch 中的第一個,如果 batch 是 1 就是全部結果了