1. tensorflow保存了3個文件
model.ckpt-10000.data-00000-of-00001
model.ckpt-10000.index
model.ckpt-10000.meta
- 一般調用生成的模型,直接
model.ckpt-1000
這樣的格式即可 - data中存儲的是模型的變量值
- index 存儲的是tensor名稱
- meta 存儲的是graph結構,包括 GraphDef, SaverDef等,當存在meta file,我們可以不在文件中定義模型,也可以運行,而如果沒有meta file,我們需要定義好模型,再加載data file,得到變量值
2. 計算模型中的參數量
- keras是可以直接輸出每層的結構,並且在最後自動計算參數量
- 普通的tensorflow可以調用訓練生成的模型,計算參數量
from tensorflow.python import pywrap_tensorflow
import os
import numpy as np
checkpoint_path = os.path.join("models_pretrained/", "model.ckpt-82798")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
total_parameters = 0
for key in var_to_shape_map:#list the keys of the model
# print(key)
# print(reader.get_tensor(key))
shape = np.shape(reader.get_tensor(key)) #get the shape of the tensor in the model
shape = list(shape)
# print(shape)
# print(len(shape))
variable_parameters = 1
for dim in shape:
# print(dim)
variable_parameters *= dim
# print(variable_parameters)
total_parameters += variable_parameters
print(total_parameters)
-
計算模型的浮點運算量
指導方法
但是還沒有成功跑通,暫留! -
日誌輸出
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='1' # 這是默認的顯示等級,顯示所有信息
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只顯示 warning 和 Error
os.environ["TF_CPP_MIN_LOG_LEVEL"]='3' # 只顯示 Error