TensorFlow程序分析(profile)實戰

導入必要的包

import os
import tempfile

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

建立模型

batch_size = 100

# placeholder
inputs = tf.placeholder(tf.float32, [batch_size, 784])
targets = tf.placeholder(tf.float32, [batch_size, 10])

# model
fc_1_out = tf.layers.dense(inputs, 500, activation=tf.nn.sigmoid)
fc_2_out = tf.layers.dense(fc_1_out, 784, activation=tf.nn.sigmoid)
logits = tf.layers.dense(fc_2_out, 10, activation=None)

# loss + train_op
loss = tf.losses.softmax_cross_entropy(onehot_labels=targets, logits=logits)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

加載數據,並獲取程序運行數據

# load data
mnist_save_dir = os.path.join(tempfile.gettempdir(), 'MNIST_data')
mnist = input_data.read_data_sets(mnist_save_dir, one_hot=True)

# get tracing data
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  
  # 創建Profiler實例作爲記錄、處理、顯示數據的主體
  profiler = tf.profiler.Profiler(graph=sess.graph)
  
  # 設置trace_level,這樣才能蒐集到包含GPU硬件在內的最全統計數據
  run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  # 創建RunMetadata實例,用於在每次sess.run時彙總統計數據
  run_metadata = tf.RunMetadata()
  
  for i in range(10):
    batch_input, batch_target = mnist.train.next_batch(batch_size)
    feed_dict = {inputs: batch_input,
                 targets: batch_target}
    _ = sess.run(train_op,
                 feed_dict=feed_dict,
                 options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                 run_metadata=run_metadata)
    
    # 將當前step的統計數據添加到Profiler實例中
    profiler.add_step(step=i, run_meta=run_metadata)

統計模型的參數量

## 統計參數量
opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
param_stats = profiler.profile_name_scope(options=opts)
# 總參數量
print('總參數:', param_stats.total_parameters)
# 各scope參數量
for x in param_stats.children:
  print(x.name, 'scope參數:', x.total_parameters)

統計模型的浮點運算數

# 統計運算量
opts = tf.profiler.ProfileOptionBuilder.float_operation()
float_stats = profiler.profile_operations(opts)
# 總參數量
print('總浮點運算數:', float_stats.total_float_ops)

統計模型的內存、耗時情況

# 統計模型內存和耗時情況
builder = tf.profiler.ProfileOptionBuilder
opts = builder(builder.time_and_memory())
#opts.with_step(1)
opts.with_timeline_output('timeline.json')
opts = opts.build()

#profiler.profile_name_scope(opts) # 只能保存單step的timeline
profiler.profile_graph(opts) # 保存各個step的timeline

給出使用profile工具給出建議

opts = {'AcceleratorUtilizationChecker': {},
        'ExpensiveOperationChecker': {},
        'JobChecker': {},
        'OperationChecker': {}}
profiler.advise(opts)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章