TensorFlow模型保存和載入方法彙總

目錄

 

一、TensorFlow常規模型加載方法

保存模型

加載模型

1.不加載圖結構,只加載參數

  2.加載圖結構和參數

  3.簡化版本

二、TensorFlow二進制模型加載方法

三、二進制模型製作

四、從圖上讀取張量

從二進制模型加載張量

從當前圖中獲取對應張量

從圖中獲取節點信息


一、TensorFlow常規模型加載方法

保存模型

tf.train.Saver()類,.save(sess, ckpt文件目錄)方法

參數名稱 功能說明 默認值
var_list Saver中存儲變量集合 全局變量集合
reshape 加載時是否恢復變量形狀 True
sharded 是否將變量輪循放在所有設備上 True
max_to_keep 保留最近檢查點個數 5
restore_sequentially 是否按順序恢復變量,模型較大時順序恢復內存消耗小 True

 

var_list是字典形式{變量名字符串: 變量符號},相對應的restore也根據同樣形式的字典將ckpt中的字符串對應的變量加載給程序中的符號。

如果Saver給定了字典作爲加載方式,則按照字典來,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否則每個變量尋找自己的name屬性在ckpt中的對應值進行加載。

加載模型

當我們基於checkpoint文件(ckpt)加載參數時,實際上我們使用Saver.restore取代了initializer的初始化

checkpoint文件會記錄保存信息,通過它可以定位最新保存的模型:

1

2

ckpt = tf.train.get_checkpoint_state('./model/')

print(ckpt.model_checkpoint_path)

 

.meta文件保存了當前圖結構

.data文件保存了當前參數名和值

.index文件保存了輔助索引信息

.data文件可以查詢到參數名和參數值,使用下面的命令可以查詢保存在文件中的全部變量{名:值}對,

1

2

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file(os.path.join(savedir,savefile),None,True)

tf.train.import_meta_graph函數給出model.ckpt-n.meta的路徑後會加載圖結構,並返回saver對象

1

ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函數會返回加載默認圖的saver對象,saver對象初始化時可以指定變量映射方式,根據名字映射變量

1

saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) 

saver.restore函數給出model.ckpt-n的路徑後會自動尋找參數名-值文件進行加載

1

2

saver.restore(sess,'./model/model.ckpt-0')

saver.restore(sess,ckpt.model_checkpoint_path)

1.不加載圖結構,只加載參數

由於實際上我們參數保存的都是Variable變量的值,所以其他的參數值(例如batch_size)等,我們在restore時可能希望修改,但是圖結構在train時一般就已經確定了,所以我們可以使用tf.Graph().as_default()新建一個默認圖(建議使用上下文環境),利用這個新圖修改和變量無關的參值大小,從而達到目的。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

'''

使用原網絡保存的模型加載到自己重新定義的圖上

可以使用python變量名加載模型,也可以使用節點名

'''

import AlexNet as Net

import AlexNet_train as train

import random

import tensorflow as tf

 

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

 

with tf.Graph().as_default() as g:

 

    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])

    y = Net.inference_1(x, N_CLASS=5, train=False)

 

    with tf.Session() as sess:

        # 程序前面得有 Variable 供 save or restore 纔不報錯

        # 否則會提示沒有可保存的變量

        saver = tf.train.Saver()

 

        ckpt = tf.train.get_checkpoint_state('./model/')

        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()

        img = sess.run(tf.expand_dims(tf.image.resize_images(

            tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))

 

        if ckpt and ckpt.model_checkpoint_path:

            print(ckpt.model_checkpoint_path)

            saver.restore(sess,'./model/model.ckpt-0')

            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

            res = sess.run(y, feed_dict={x: img})

            print(global_step,sess.run(tf.argmax(res,1)))

  2.加載圖結構和參數

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

'''

直接使用使用保存好的圖

無需加載python定義的結構,直接使用節點名稱加載模型

由於節點形狀已經定下來了,所以有不便之處,placeholder定義batch後單張傳會報錯

現階段不推薦使用,以後如果理解深入了可能會找到使用方法

'''

import AlexNet_train as train

import random

import tensorflow as tf

 

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

 

 

ckpt = tf.train.get_checkpoint_state('./model/')                          # 通過檢查點文件鎖定最新的模型

saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')   # 載入圖結構,保存在.meta文件中

 

with tf.Session() as sess:

    saver.restore(sess,ckpt.model_checkpoint_path)                        # 載入參數,參數保存在兩個文件中,不過restore會自己尋找

 

    img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()

    img = sess.run(tf.image.resize_images(

        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))

    imgs = []

    for i in range(128):

       imgs.append(img)

    print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))

 

    '''

    img = sess.run(tf.expand_dims(tf.image.resize_images(

        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))

    print(img)

    imgs = []

    for i in range(128):

        imgs.append(img)

    print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),

                   feed_dict={'Placeholder:0':img}))

注意,在所有兩種方式中都可以通過調用節點名稱使用節點輸出張量,節點.name屬性返回節點名稱。

  3.簡化版本

1

2

3

4

5

6

7

8

9

10

11

12

# 連同圖結構一同加載

ckpt = tf.train.get_checkpoint_state('./model/')

saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')

with tf.Session() as sess:

    saver.restore(sess,ckpt.model_checkpoint_path)

             

# 只加載數據,不加載圖結構,可以在新圖中改變batch_size等的值

# 不過需要注意,Saver對象實例化之前需要定義好新的圖結構,否則會報錯

saver = tf.train.Saver()

with tf.Session() as sess:

    ckpt = tf.train.get_checkpoint_state('./model/')

    saver.restore(sess,ckpt.model_checkpoint_path)

二、TensorFlow二進制模型加載方法

這種加載方法一般是對應網上各大公司已經訓練好的網絡模型進行修改的工作

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

# 新建空白圖

self.graph = tf.Graph()

# 空白圖列爲默認圖

with self.graph.as_default():

    # 二進制讀取模型文件

    with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:

        # 新建GraphDef文件,用於臨時載入模型中的圖

        graph_def = tf.GraphDef()

        # GraphDef加載模型中的圖

        graph_def.ParseFromString(f.read())

        # 在空白圖中加載GraphDef中的圖

        tf.import_graph_def(graph_def,name='')

        # 在圖中獲取張量需要使用graph.get_tensor_by_name加張量名

        # 這裏的張量可以直接用於session的run方法求值了

        # 補充一個基礎知識,形如'conv1'是節點名稱,而'conv1:0'是張量名稱,表示節點的第一個輸出張量

        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)

        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name   in self.layer_operation_names]

 

上面兩篇都使用了二進制加載模型的方式

三、二進制模型製作

這節是關於tensorflow的Freezing,字面意思是冷凍,可理解爲整合合併;整合什麼呢,就是將模型文件和權重文件整合合併爲一個文件,主要用途是便於發佈。

tensorflow在訓練過程中,通常不會將權重數據保存的格式文件裏(這裏我理解是模型文件),反而是分開保存在一個叫checkpoint的檢查點文件裏,當初始化時,再通過模型文件裏的變量Op節點來從checkoupoint文件讀取數據並初始化變量。這種模型和權重數據分開保存的情況,使得發佈產品時不是那麼方便,我們可以將tf的圖和參數文件整合進一個後綴爲pb的二進制文件中,由於整合過程回將變量轉化爲常量,所以我們在日後讀取模型文件時不能夠進行訓練,僅能向前傳播,而且我們在保存時需要指定節點名稱。

將圖變量轉換爲常量的API:tf.graph_util.convert_variables_to_constants

轉換後的graph_def對象轉換爲二進制數據(graph_def.SerializeToString())後,寫入pb即可。

1

2

3

4

5

6

7

8

9

10

11

12

13

import tensorflow as tf

 

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')

v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')

result = v1 + v2

 

saver = tf.train.Saver()

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    saver.save(sess, './tmodel/test_model.ckpt')

    gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), ['add'])

with tf.gfile.GFile('./tmodel/model.pb', 'wb') as f:

    f.write(gd.SerializeToString())

我們可以直接查看gd:

node {
  name: "v1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 1.0
      }
    }
  }
}
……
node {
  name: "add"
  op: "Add"
  input: "v1/read"
  input: "v2/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
 

四、從圖上讀取張量

上面的代碼實際上已經包含了本小節的內容,但是由於從圖上讀取特定的張量是如此的重要,所以我仍然單獨的補充上這部分的內容。

無論如何,想要獲取特定的張量我們必須要有張量的名稱圖的句柄,比如 'import/pool_3/_reshape:0' 這種,有了張量名和圖,索引就很簡單了。

從二進制模型加載張量

第二小節的代碼很好的展示了這種情況

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # 瓶頸層輸出張量名稱

JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'  # 輸入層張量名稱

MODEL_DIR = './inception_dec_2015'  # 模型存放文件夾

MODEL_FILE = 'tensorflow_inception_graph.pb'  # 模型名

 

 

# 加載模型

# with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),'rb') as f:   # 閱讀器上下文

with open(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:  # 閱讀器上下文

    graph_def = tf.GraphDef()  # 生成圖

    graph_def.ParseFromString(f.read())  # 圖加載模型

# 加載圖上節點張量(按照句柄理解)

bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(  # 從圖上讀取張量,同時導入默認圖

    graph_def,

    return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

從當前圖中獲取對應張量

這個就是很普通的情況,從我們當前操作的圖中獲取某個張量,用於feed啦或者用於輸出等操作,API也很簡單,用法如下:

g.get_tensor_by_name('import/pool_3/_reshape:0')

 g表示當前圖句柄,可以簡單的使用 g = tf.get_default_graph() 獲取。

從圖中獲取節點信息

有的時候我們對於模型中的節點並不夠了解,此時我們可以通過圖句柄來查詢圖的構造:

1

2

g = tf.get_default_graph()

print(g.as_graph_def().node)

這個操作將返回圖的構造結構。從這裏,對比前面的代碼,我們也可以瞭解到:graph_def 實際就是圖的結構信息存儲形式,我們可以將之還原爲圖(二進制模型加載代碼中展示了),也可以從圖中將之提取出來(本部分代碼)。[Ref]


查看TensorFlow中checkpoint內變量的幾種方法

查看ckpt中變量的方法有三種:

  1. 在有model的情況下,使用tf.train.Saver進行restore
  2. 使用tf.train.NewCheckpointReader直接讀取ckpt文件,這種方法不需要model。
  3. 使用tools裏的freeze_graph來讀取ckpt

注意:

  1. 如果模型保存爲.ckpt的文件,則使用該文件就可以查看.ckpt文件裏的變量。ckpt路徑爲 model.ckpt
  2. 如果模型保存爲.ckpt-xxx-data (圖結構)、.ckpt-xxx.index (參數名)、.ckpt-xxx-meta (參數值)文件,則需要同時擁有這三個文件纔行。並且ckpt的路徑爲 model.ckpt-xxx

1. 基於model來讀取ckpt文件裏的變量

1.首先建立model
2.從ckpt中恢復變量

1

2

3

4

5

6

7

8

9

10

with tf.Graph().as_default() as g:

  #建立model

  images, labels = cifar10.inputs(eval_data=eval_data)

  logits = cifar10.inference(images)

  top_k_op = tf.nn.in_top_k(logits, labels, 1)

  #從ckpt中恢復變量

  sess = tf.Session()

  saver = tf.train.Saver() #saver = tf.train.Saver(...variables...) # 恢復部分變量時,只需要在Saver裏指定要恢復的變量

  save_path = 'ckpt的路徑'

  saver.restore(sess, save_path) # 從ckpt中恢復變量

注意:基於model來讀取ckpt中變量時,model和ckpt必須匹配。

2. 使用tf.train.NewCheckpointReader直接讀取ckpt文件裏的變量,使用tools.inspect_checkpoint裏的print_tensors_in_checkpoint_file函數打印ckpt裏的東西

1

2

3

4

5

6

7

8

#使用NewCheckpointReader來讀取ckpt裏的變量

from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join(model_dir, "model.ckpt")

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReader

var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:

  print("tensor_name: ", key)

  #print(reader.get_tensor(key))

1

2

3

4

5

6

7

8

#使用print_tensors_in_checkpoint_file打印ckpt裏的內容

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

 

print_tensors_in_checkpoint_file(file_name, #ckpt文件名字

                 tensor_name, # 如果爲None,則默認爲ckpt裏的所有變量

                 all_tensors, # bool 是否打印所有的tensor,這裏打印出的是tensor的值,一般不推薦這裏設置爲False

                 all_tensor_names) # bool 是否打印所有的tensor的name

#上面的打印ckpt的內部使用的是pywrap_tensorflow.NewCheckpointReader所以,掌握NewCheckpointReader纔是王道

3.使用tools裏的freeze_graph來讀取ckpt

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

from tensorflow.python.tools import freeze_graph

 

freeze_graph(input_graph, #=some_graph_def.pb

       input_saver,

       input_binary,

       input_checkpoint, #=model.ckpt

       output_node_names, #=softmax

       restore_op_name,

       filename_tensor_name,

       output_graph, #='./tmp/frozen_graph.pb'

       clear_devices,

       initializer_nodes,

       variable_names_whitelist='',

       variable_names_blacklist='',

       input_meta_graph=None,

       input_saved_model_dir=None,

       saved_model_tags='serve',

       checkpoint_version=2)

#freeze_graph_test.py講述了怎麼使用freeze_grapg。

使用freeze_graph可以將圖和ckpt進行合併。[Ref]


一般情況下,我們得到一個模型後都想知道模型裏面的張量,下面分別從ckpt模型和pb模型中讀取裏面的張量名字。
1.讀取ckpt模型裏面的張量

首先,ckpt模型需包含以下文件,一個都不能少


然後編寫代碼,將所有張量的名字都保存到tensor_name_list_ckpt.txt文件中

import tensorflow as tf

#直接讀取圖的結構,不需要手動重新定義 
meta_graph = tf.train.import_meta_graph("model.ckpt.meta")

with tf.Session()as sess:
	meta_graph.restore(sess,"D:/Face_recognition_github/20180402-114759/model.ckpt")

	tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
	with open("tensor_name_list_ckpt.txt",'a+')as f:
		for tensor_name in tensor_name_list:
			f.write(tensor_name+"\n")
			# print(tensor_name,'\n')
		f.close()

 

運行結果截圖(部分)


2.讀取pb模型裏面的張量

需要一個pb文件

編寫代碼

import tensorflow as tf

model_path = "D:/Face_recognition_github/20180402-114759/20180402-114759.pb"

with tf.gfile.FastGFile(model_path,'rb')as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def,name='')

    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    with open('tensor_name_list_pb.txt','a')as t:
        for tensor_name in tensor_name_list:
            t.write(tensor_name+'\n')
            print(tensor_name,'\n')
        t.close()

順便再查看pb模型裏面的張量的屬性(ckpt模型的操作類似),保存到txt文件中[Ref]

import tensorflow as tf

model_path = "/home/boss/Study/face_recognition_flask/20180402-114759/model.pb"

with tf.gfile.FastGFile(model_path,'rb')as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def,name='')

    # tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    # with open('tensor_name_list_pb.txt','a')as t:
    #     for tensor_name in tensor_name_list:
    #         t.write(tensor_name+'\n')
    #         print(tensor_name,'\n')
    #     t.close()
    with tf.Session()as sess:
        op_list = sess.graph.get_operations()
        with open("model裏面張量的屬性.txt",'a+')as f:
            for index,op in enumerate(op_list):
                f.write(str(op.name)+"\n")                   #張量的名稱
                f.write(str(op.values())+"\n")              #張量的屬性

運行結果截圖(部分)


用於獲得一個pb文件的所有節點名稱 

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 18 18:31:13 2018
1、model_dir爲模型路徑文件夾,model_name爲模型名稱(自定義非如alexnet等訓練實際名稱)
2、寫入到模型路徑下的result.txt文件內
@author: Mr_dogyang
"""
 
import tensorflow as tf
import os
 
model_dir = 'D:\\TensorFlow\\MyTensorFlow\\MyTensorFlow\\slim\\satellite'
model_name = 'inception_v3_frozen_graph.pb'
 
# 讀取並創建一個圖graph來存放Google訓練好的Inception_v3模型(函數)
def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        # 使用tf.GraphDef()定義一個空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')
 
# 創建graph
create_graph()
 
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
result_file = os.path.join(model_dir, 'result.txt') 
with open(result_file, 'w+') as f:
    for tensor_name in tensor_name_list:
        f.write(tensor_name+'\n')

Tensorflow學習教程------下載圖像識別模型inceptionV3

# coding: utf-8
 
import tensorflow as tf
import os
import tarfile
import requests
 

#inception模型下載地址
inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
 
#模型存放地址
inception_pretrain_model_dir = "inception_model"
if not os.path.exists(inception_pretrain_model_dir):
    os.makedirs(inception_pretrain_model_dir)
     
#獲取文件名,以及文件路徑
filename = inception_pretrain_model_url.split('/')[-1]
filepath = os.path.join(inception_pretrain_model_dir, filename)
 
#下載模型
if not os.path.exists(filepath):
    print("download: ", filename)
    r = requests.get(inception_pretrain_model_url, stream=True)
    with open(filepath, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
print("finish: ", filename)
#解壓文件
tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)
  
#模型結構存放文件
log_dir = 'inception_log'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
 
#classify_image_graph_def.pb爲google訓練好的模型
inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
with tf.Session() as sess:
    #創建一個圖來存放google訓練好的模型
    with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    #保存圖的結構
    writer = tf.summary.FileWriter(log_dir, sess.graph)
    writer.close()

[Ref]


用tensorflow神經網絡實現一個簡易的圖片分類器 

這篇文章我們將用 CIFAR-10數據集做一個很簡易的圖片分類器。 在 CIFAR-10數據集包含了60,000張圖片。在此數據集中,有10個不同的類別,每個類別中有6,000個圖像。每幅圖像的大小爲32 x 32像素。雖然這麼小的尺寸通常給人類識別正確的類別帶來了困難,但它實際上是對計算機模型的簡化並且減少了分析圖像所需的計算。

                                                                                     CIFAR-10數據集

我們可以通過輸入模型的大量數字序列將這些圖像輸入到我們的模型中。每個像素由三個浮點數標識,這三個浮點數表示該像素的紅色,綠色和藍色值(RGB值)。所以每個圖像有32 x 32 x 3 = 3,072 個值0.

使用非常大的卷積神經網絡可以實現高質量的結果,你可以在這個連接中學習Rodrigo Benenson’s page

 

下載CIFAR-10數據集,網址:Python version of the dataset, 並把他安裝在我們分類器代碼所在的文件夾下

 

 先上源代碼

 模型的源代碼:

複製代碼

import numpy as np
import tensorflow as tf
import time
import data_helpers
beginTime = time.time()


batch_size = 100
learning_rate = 0.005
max_steps = 1000

data_sets = data_helpers.load_data()


# Define input placeholders
images_placeholder = tf.placeholder(tf.float32, shape=[None, 3072])
labels_placeholder = tf.placeholder(tf.int64, shape=[None])

# Define variables (these are the values we want to optimize)
weights = tf.Variable(tf.zeros([3072, 10]))
biases = tf.Variable(tf.zeros([10]))

# Define the classifier's result
logits = tf.matmul(images_placeholder, weights) + biases

# Define the loss function
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                     labels=labels_placeholder))

# Define the training operation
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Operation comparing prediction with true label
correct_prediction = tf.equal(tf.argmax(logits, 1), labels_placeholder)

# Operation calculating the accuracy of our predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


with tf.Session() as sess:
    # Initialize variables
    sess.run(tf.global_variables_initializer())

    # Repeat max_steps times
    for i in range(max_steps):

        # Generate input data batch
        indices = np.random.choice(data_sets['images_train'].shape[0], batch_size)
        images_batch = data_sets['images_train'][indices]
        labels_batch = data_sets['labels_train'][indices]

        # Periodically print out the model's current accuracy
        if i % 100 == 0:
            train_accuracy = sess.run(accuracy, feed_dict={
                images_placeholder: images_batch, labels_placeholder: labels_batch})
            print('Step {:5d}: training accuracy {:g}'.format(i, train_accuracy))

        # Perform a single training step
        sess.run(train_step, feed_dict={images_placeholder: images_batch,
                                        labels_placeholder: labels_batch})

    # After finishing the training, evaluate on the test set
    test_accuracy = sess.run(accuracy, feed_dict={
        images_placeholder: data_sets['images_test'],
        labels_placeholder: data_sets['labels_test']})
    print('Test accuracy {:g}'.format(test_accuracy))

endTime = time.time()
print('Total time: {:5.2f}s'.format(endTime - beginTime))

複製代碼

處理數據集的代碼

複製代碼

import numpy as np
import pickle
import sys


def load_CIFAR10_batch(filename):
    '''load data from single CIFAR-10 file'''

    with open(filename, 'rb') as f:
        if sys.version_info[0] < 3:
            dict = pickle.load(f)
        else:
            dict = pickle.load(f, encoding='latin1')
        x = dict['data']
        y = dict['labels']
        x = x.astype(float)
        y = np.array(y)
    return x, y


def load_data():
    '''load all CIFAR-10 data and merge training batches'''

    xs = []
    ys = []
    for i in range(1, 6):
        filename = 'cifar-10-batches-py/data_batch_' + str(i)
        X, Y = load_CIFAR10_batch(filename)
        xs.append(X)
        ys.append(Y)

    x_train = np.concatenate(xs)
    y_train = np.concatenate(ys)
    del xs, ys

    x_test, y_test = load_CIFAR10_batch('cifar-10-batches-py/test_batch')

    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck']

    # Normalize Data
    mean_image = np.mean(x_train, axis=0)
    x_train -= mean_image
    x_test -= mean_image

    data_dict = {
        'images_train': x_train,
        'labels_train': y_train,
        'images_test': x_test,
        'labels_test': y_test,
        'classes': classes
    }
    return data_dict


def reshape_data(data_dict):
    im_tr = np.array(data_dict['images_train'])
    im_tr = np.reshape(im_tr, (-1, 3, 32, 32))
    im_tr = np.transpose(im_tr, (0, 2, 3, 1))
    data_dict['images_train'] = im_tr
    im_te = np.array(data_dict['images_test'])
    im_te = np.reshape(im_te, (-1, 3, 32, 32))
    im_te = np.transpose(im_te, (0, 2, 3, 1))
    data_dict['images_test'] = im_te
    return data_dict



def gen_batch(data, batch_size, num_iter):
    data = np.array(data)
    index = len(data)
    for i in range(num_iter):
        index += batch_size
        if (index + batch_size > len(data)):
            index = 0
            shuffled_indices = np.random.permutation(np.arange(len(data)))
            data = data[shuffled_indices]
        yield data[index:index + batch_size]


def main():
    data_sets = load_data()
    print(data_sets['images_train'].shape)
    print(data_sets['labels_train'].shape)
    print(data_sets['images_test'].shape)
    print(data_sets['labels_test'].shape)


if __name__ == '__main__':
    main()

複製代碼

 

首先我們導入了tensorflow numpy time 以及自己寫的data_help包

time是爲了計算整個代碼的運行時間。 data_help是將數據集做成我們訓練用的數據結構

data_help中的load_data()會把60000張的CIFAR數據集分成兩塊:500000張的訓練集和100000張的測試集,具體來說他會返回這樣的一個包含如下內容的字典

  • images_train: 訓練集。一個500000張 包含3072(32x32像素點x3顏色通道)值
  • labels_train: 訓練集的50,000個標籤(每個標籤在0到9之間,代表訓練圖像所屬的10個類別中的哪一個)
  • images_test: 測試集(10,000 by 3,072)
  • labels_test: 測試集的10,000個標籤
  • classes: 10個文本標籤,用於將數字類值轉換爲單詞(例如0代表'plane',1代表'car')

 然後我們就可以開始建立我們的模型了

先頂兩個tensroflow的佔位符 這些佔位符不包含任何數據,但僅指定輸入數據的類型和形狀:

  images_placeholder = tf.placeholder(tf.float32, shape=[None, 3072]) 
  labels_placeholder = tf.placeholder(tf.int64, shape=[None])    #值得注意的是,這邊的Dtype是int 還有shape是沒有維度的(一維的)

然後我們定義偏置和權重

  weights = tf.Variable(tf.zeros([3072, 10]))
  biases = tf.Variable(tf.zeros([10]))

我們的輸入由3,072個浮點數組成,但我們尋找的輸出是10個不同的整數值之一,代表一個類別。我們如何從3,072個值到單個值?

我們採用的簡單方法是分別查看每個像素。對於每個像素和每個可能的類別,我們想知道該像素的顏色是增加還是減少屬於特定類別的概率。例如,如果第一個像素是紅色 - 並且如果汽車的圖像通常具有紅色的第一個像素,那麼我們希望汽車類別的分數增加。我們通過將紅色通道值乘以正數並將其添加到汽車類別得分來實現此目的。

同樣,如果馬圖像在位置1很少有紅色像素,我們希望該分數降低。這意味着乘以小數或負數並將結果添加到馬匹得分中。對於10個類別中的每個類別,我們在每個像素上重複此步驟,然後總結所有3,072個值以獲得單個總分。這是我們的3,072像素值的總和,由該類別的3,072參數權重加權。這裏的最終結果是我們將得到10個分數 - 每個類別一個。最高分給我們分類。

使用矩陣,我們可以大大簡化用於將像素值與權重值相乘並總結結果的方案。我們用3,072維向量表示單個圖像。如果我們將此向量乘以3,072 x 10權重矩陣,則結果是一個10維矩陣,其中包含我們想要的加權和。

 

3,072 x 10矩陣中的實際值是模型參數。但是,如果它們是隨機的並且毫無意義,那麼輸出也將是。在這裏,我們可以看到訓練數據的值,它準備模型以最終自己確定參數值。 

在上面的兩行代碼中,我們通知TensorFlow 3,072 x 10加權參數矩陣 - 所有這些參數在開始時都具有初始值0。我們還定義了第二個參數:包含偏差的10維數組。偏差不直接與圖像數據相互作用,而是加到加權和 - 每個分數的起點。想象一個全黑圖像:所有像素值都是0,因此它的所有類別得分都是0(與權重矩陣中的值無關)。偏見允許我們從非零類別分數開始。 

訓練方案的工作原理如下:首先,我們輸入訓練數據並讓模型使用當前參數值進行預測。使用正確的類別對該預測進行比較,並且該比較的數值結果稱爲損失。損失值越小,類別預測越接近正確的類別 - 反之亦然。目的是儘量減少損失。但在我們看一下損失最小化之前,讓我們來看看如何計算損失。

  # Define loss function
  loss=tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels_placeholder))

 TensorFlow通過提供處理所有這些的功能來處理我們的所有細節。然後,我們可以將logits中包含的模型預測與labels_placeholder(正確的類別標籤)進行比較。 sparse_softmax_cross_entropy_with_logits()的輸出是每個圖像的損失值。最後,我們計算所有輸入圖像的平均損失值。

tf.nn.sparse_softmax_cross_entropy_with_logits()這個函數的功能就是計算labels和logits之間的交叉熵(cross entropy)。

複製代碼

import tensorflow as tf

input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
output = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=input_data, labels=[0, 2])
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(output))
# [ 1.36573195  0.93983102]

複製代碼

 

 

 這邊順便介紹一下tf.nn.softmax_cross_entopy_with_logits()

複製代碼

tf.nn.softmax_cross_entropy_with_logits(
    _sentinel=None,
    labels=None,
    logits=None,
    dim=-1,
    name=None
)

複製代碼

第一個參數基本不用。此處不說明。
第二個參數label的含義就是一個分類標籤,所不同的是,這個label是分類的概率,比如說[0.2,0.3,0.5],labels的每一行必須是一個概率分佈。

現在來說明第三個參數logits,logit本身就是是一種函數,它把某個概率p從[0,1]映射到[-inf,+inf](即正負無窮區間)。這個函數的形式化描述爲:logit=ln(p/(1-p))。
我們可以把logist理解爲原生態的、未經縮放的,可視爲一種未歸一化的log 概率,如是[4, 1, -2]

於是,Softmax的工作則是,它把一個系列數從[-inf, +inf] 映射到[0,1],除此之外,它還把所有參與映射的值累計之和等於1,變成諸如[0.95, 0.05, 0]的概率向量。這樣一來,經過Softmax加工的數據可以當做概率來用。

也就是說,logits是作爲softmax的輸入。經過softmax的加工,就變成“歸一化”的概率(設爲q),然後和labels代表的概率分佈(設爲q),於是,整個函數的功能就是前面的計算labels(概率分佈p)和logits(概率分佈q)之間的交叉熵

(1)如果labels的每一行是one-hot表示,也就是隻有一個地方爲1(或者說100%),其他地方爲0(或者說0%),還可以使用tf.sparse_softmax_cross_entropy_with_logits()。之所以用100%和0%描述,就是讓它看起來像一個概率分佈。
(2)tf.nn.softmax_cross_entropy_with_logits()函數已經過時 (deprecated),它在TensorFlow未來的版本中將被去除。取而代之的是

tf.nn.softmax_cross_entropy_with_logits_v2()。

(3)參數labels,logits必須有相同的形狀 [batch_size, num_classes] 和相同的類型(float16, float32, float64)中的一種,否則交叉熵無法計算。

(4)tf.nn.softmax_cross_entropy_with_logits 函數內部的 logits 不能進行縮放,因爲在這個工作會在改函數內部進行(注意函數名稱中的 softmax ,它負責完成原始數據的歸一化),如果 logits 進行了縮放,那麼反而會影響計算正確性。
 

-------------------------------------------------------------------------------------------------------------------------------------------

最後,我們計算所有輸入圖像的平均損失值。

# Define training operation
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

如何改變參數值以減少損失? TensorFlow在這裏發光,使用一種稱爲自動微分的技術,它根據參數值計算損耗的梯度。它計算每個參數對總體損失的影響,以及減少或增加少量用於減少損失的程度。它試圖通過遞歸調整所有參數值來提高準確性。完成此步驟後,將使用下一個圖像組重新啓動該過程。 

TensorFlow包含各種優化技術,用於將梯度信息轉換爲參數的更新。對於本教程中的目的,我們選擇簡單的梯度下降選項,該選項僅檢查模型的當前狀態以確定如何更新參數,而不考慮先前的參數值。 

對輸入圖像進行分類,將預測與正確的類別進行比較,計算損失以及調整參數值的過程重複了很多次。計算持續時間和成本會隨着更大,更復雜的模型而迅速升級,但我們這裏的簡單模型不需要太多耐心或高性能設備就能看到有意義的結果。 

我們代碼中的下兩行(下面)採取精度測量。沿維度1的logg的argmax返回具有最高分數的類別的索引,這是類別標籤預測。這些標籤通過tf.equal()與正確的類別類別標籤進行比較,後者返回一個布爾值向量 - 它被轉換爲浮點值(0或1),其平均值是正確預測圖像的分數。

# Operation comparing prediction with true label

correct_prediction = tf.equal(tf.argmax(logits, 1), labels_placeholder)

# Operation calculating the accuracy of our predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

現在我們已經定義了TensorFlow圖,我們可以運行它。該圖可在sess變量中訪問(見下文)。我們立即初始化之前創建的變量。現在,變量定義初始值已分配給變量。 

迭代訓練過程開始並重復max_steps次。

# Run the TensorFlow graph
with tf.Session() as sess:

# Initialize variables
sess.run(tf.initialize_all_variables())

# Repeat max_steps times
for i in range(max_steps):

 接下來的幾行代碼隨機從訓練數據中選擇一些圖像:

# Generate batch of input data
indices = np.random.choice(data_sets['images_train'].shape[0], batch_size)
images_batch = data_sets['images_train'][indices]
 labels_batch = data_sets['labels_train'][indices]

上面的第一行代碼選擇0和訓練集大小之間的batch_size隨機索引。然後通過選擇這些索引處的圖像和類別標籤來構建批次。 

來自訓練數據的結果圖像和類別組稱爲批次。批量大小表示執行參數更新步驟的頻率。首先,我們平均特定批次中所有圖像的損失,然後通過梯度下降更新參數。 

如果不是在批處理後停止並對訓練集中的所有圖像進行分類,我們將能夠計算真正的平均損失和真正的梯度而不是使用批處理時的估計。但是每個參數更新步驟需要更多的計算。在另一個極端,我們可以將批量大小設置爲1,並在每個圖像後執行參數更新。這將導致更頻繁的更新,但更新將更加不穩定,並且往往不會朝着正確的方向前進。通常,在這兩個極端之間的某種方法可以最快地改善結果。通常最好選擇儘可能大的批量大小,同時仍然能夠將所有變量和中間結果放入內存中。 

每100次迭代,檢查訓練數據批次的當前準確度。

# Periodically print out the model's current accuracy
if i % 100 == 0:
 train_accuracy = sess.run(accuracy, feed_dict={
   images_placeholder: images_batch, labels_placeholder: labels_batch})
 print('Step {:5d}: training accuracy {:g}'.format(i, train_accuracy))

這是訓練循環中最重要的一行,我們建議模型執行單個訓練步驟:

# Perform the training step
sess.run(train_step, feed_dict={images_placeholder: images_batch,
 labels_placeholder: labels_batch})

 已經在TensorFlow圖形定義中提供了所有數據。 TensorFlow知道梯度下降更新取決於損失的值,而損失的值又取決於logits,後者取決於權重,偏差和實際輸入批次。 

現在只需將批量訓練數據輸入模型,這是通過提供一個飼料字典來完成的,其中當前的訓練數據批次被分配給上面定義的佔位符。 

培訓結束後,我們轉而在測試集上運行模型。由於這是模型第一次遇到測試集,因此圖像對模型來說是全新的。 

記住,目標是評估訓練有素的模型處理未知數據的能力

# After finishing the training, evaluate on the test set
test_accuracy = sess.run(accuracy, feed_dict={
 images_placeholder: data_sets['images_test'],
 labels_placeholder: data_sets['labels_test']})
print('Test accuracy {:g}'.format(test_accuracy))

最後一行打印了培訓和運行模型的持續時間。

endTime = time.time()
print('Total time: {:5.2f}s'.format(endTime - beginTime))

[Ref]


 

Google_BERT
1. 模型文件轉換

    .indel是對應模型的索引文件,保存.data文件與.meta文件中圖的結構關係;

    .data-00000-of-00001文件:保存Tensorflow每個變量的取值,存儲格式SSTable,(key, value)列表;

    .meta文件保存tensorflow計算圖的網絡結構,MetaGraph元圖,以protocal buffer格式保存

將tensorflow的ckpt模型轉爲pb文件, 需要知道網絡的輸出節點名稱, 如果不指定輸出節點名稱, 程序就不知道該freeze哪些節點, 就沒有辦法保存模型.

 
1.1 獲取模型中節點名稱

    # function: get the node name of ckpt model
    from tensorflow.python import pywrap_tensorflow
    # checkpoint_path = 'model.ckpt-xxx'
    checkpoint_path = './uncased_L-12_H-768_A-12/bert_model.ckpt'
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)

tensorflow獲取模型節點名稱及將.ckpt轉爲.pb文件
1.2 將ckpt模型轉換爲pb模型

    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    from tensorflow.python.platform import gfile
     
    def freeze_graph(ckpt, output_graph):
        output_node_names = 'bert/encoder/layer_7/output/dense/kernel'
        # saver = tf.train.import_meta_graph(ckpt+'.meta', clear_devices=True)
        saver = tf.compat.v1.train.import_meta_graph(ckpt+'.meta', clear_devices=True)
        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()
     
        with tf.Session() as sess:
            saver.restore(sess, ckpt)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=output_node_names.split(',')
            )
            with tf.gfile.GFile(output_graph, 'wb') as fw:
                fw.write(output_graph_def.SerializeToString())
            print ('{} ops in the final graph.'.format(len(output_graph_def.node)))
     
    ckpt = '/home/jie/gitdir/ckpt_pb/uncased_L-12_H-768_A-12/bert_model.ckpt'
    pb   = '/home/jie/gitdir/ckpt_pb/bert_model.pb'
     
    if __name__ == '__main__':
        freeze_graph(ckpt, pb)

 
1.3 查看.ckpt文件保存的tensor信息

    import os
    from tensorflow.python import pywrap_tensorflow
     
    # code for finall ckpt
    checkpoint_path = "./uncased_L-12_H-768_A-12/bert_model.ckpt"
    # Read data from checkpoint file
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    # Print tensor name and values
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))

[reference]
1.4 查看.pb文件的所有tensor

    # params: pb_file_direction
    import argparse
    import tensorflow as tf
     
    def print_tensors(pb_file):
        print('Model File: {}\n'.format(pb_file))
        # read pb into graph_def
        with tf.gfile.GFile(pb_file, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
     
        # import graph_def
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def)
     
        # print operations
        for op in graph.get_operations():
            print(op.name + '\t' + str(op.values()))
     
     
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument("--pb_file", type=str, required=True, help="Pb file")
        args = parser.parse_args()
        print_tensors(args.pb_file)

 

 
2. 模型文件可視化
2.1 ckpt模型可視化

 
2.2 pb模型可視化

1. 從pb文件中恢復計算圖

    import tensorflow as tf
    # path of pb file
    model = './bert_model.pb'
    # graph = tf.get_default_graph()
    graph = tf.compat.v1.get_default_graph()
    graph_def = graph.as_graph_def()
    graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
    tf.import_graph_def(graph_def, name='graph')
    # summaryWriter = tf.summary.FileWriter('log/', graph)
    summaryWriter = tf.compat.v1.summary.FileWriter('log/', graph)

2. Tensorboard查看計算圖

tensorboard --logdir ./log/

Tensorflow之pb文件分析

3. 打印pb模型的tensor info

    # coding:utf-8
    import tensorflow as tf
    from tensorflow.python.platform import gfile
     
    tf.reset_default_graph()  # 重置計算圖
    output_graph_path = '1.pb'
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        output_graph_def = tf.GraphDef()
        # 獲得默認的圖
        graph = tf.get_default_graph()
        with gfile.FastGFile(output_graph_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
            # 得到當前圖有幾個操作節點
            print("%d ops in the final graph." % len(output_graph_def.node))
     
            tensor_name = [tensor.name for tensor in output_graph_def.node]
            print(tensor_name)
            print('---------------------------')
            # 在log_graph文件夾下生產日誌文件,可以在tensorboard中可視化模型
            # summaryWriter = tf.summary.FileWriter('log_graph/', graph)
     
            for op in graph.get_operations():
                # print出tensor的name和值
                print(op.name, op.values())

查看TensorFlow的pb模型文件並使用TensorBoard可視化

[Ref]


tf.train.get_checkpoint_state函數通過checkpoint文件找到模型文件名。

tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)

該函數返回的是checkpoint文件CheckpointState proto類型的內容,其中有model_checkpoint_path和all_model_checkpoint_paths兩個屬性。其中model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths則有未被刪除的所有tensorflow模型文件的文件名。

下圖是在訓練過程中生成的幾個模型文件列表:


以下是測試程序裏的部分代碼:

    with tf.Session() as sess:            
                ckpt=tf.train.get_checkpoint_state('Model/')
                print(ckpt)
                if ckpt and ckpt.all_model_checkpoint_paths:
                    #加載模型
                    #這一部分是有多個模型文件時,對所有模型進行測試驗證
                    for path in ckpt.all_model_checkpoint_paths:
                        saver.restore(sess,path)                
                        global_step=path.split('/')[-1].split('-')[-1]
                        accuracy_score=sess.run(accuracy,feed_dict=validate_feed)
                        print("After %s training step(s),valisation accuracy = %g"%(global_step,accuracy_score))
                    '''
                    #對最新的模型進行測試驗證
                    saver.restore(sess,ckpt.model_checkpoint_paths)                
                    global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score=sess.run(accuracy,feed_dict=validate_feed)
                    print("After %s training step(s),valisation accuracy = %g"%(global_step,accuracy_score))
                    '''
                else:
                    print('No checkpoint file found')
                    return
            #time.sleep(eval_interval_secs)
            return

在上面代碼中,通過tf.train.get_checkpoint_state函數得到的相關模型文件名如下:


對所有模型進行測試,得到:

[Ref]


核心代碼如下:

[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

實例代碼:(加載了Inceptino_v3的模型,並獲取該模型所有節點的名稱)

    # -*- coding: utf-8 -*-
     
    import tensorflow as tf
    import os
     
    model_dir = 'C:/Inception_v3'
    model_name = 'output_graph.pb'
     
    # 讀取並創建一個圖graph來存放訓練好的 Inception_v3模型(函數)
    def create_graph():
        with tf.gfile.FastGFile(os.path.join(
                model_dir, model_name), 'rb') as f:
            # 使用tf.GraphDef()定義一個空的Graph
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            # Imports the graph from graph_def into the current default Graph.
            tf.import_graph_def(graph_def, name='')
     
    # 創建graph
    create_graph()
     
    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    for tensor_name in tensor_name_list:
        print(tensor_name,'\n')

輸出結果:

mixed_8/tower/conv_1/batchnorm/moving_variance

mixed_8/tower/conv_1/batchnorm

r_1/mixed/conv_1/batchnorm

.

.

.

mixed_10/tower_1/mixed/conv_1/CheckNumerics

mixed_10/tower_1/mixed/conv_1/control_dependency

mixed_10/tower_1/mixed/conv_1

pool_3

pool_3/_reshape/shape

pool_3/_reshape

input/BottleneckInputPlaceholder
.
.
.
.
final_training_ops/weights/final_weights

final_training_ops/weights/final_weights/read

final_training_ops/biases/final_biases

final_training_ops/biases/final_biases/read

final_training_ops/Wx_plus_b/MatMul

final_training_ops/Wx_plus_b/add

final_result

由於結果太長了,就省略了一些。

如果不想這樣print輸出也可以將其寫入txt 查看。

寫入txt代碼如下:

    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
     
    txt_path = './txt/節點名稱'
    full_path = txt_path+ '.txt'
     
    for tensor_name in tensor_name_list:
        name = tensor_name + '\n'
        file = open(full_path,'a+')
    file.write(name)
    file.close()

參考鏈接:

TensorFlow學習筆記:獲取以來模型全部張量名稱

Tensorflow:如何通過名稱獲得張量?

Ref

[Ref]:tensorflow中讀取模型中保存的值, tf.train.NewCheckpointReader;

[Ref]https://blog.csdn.net/u014568072/article/details/85281769

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章