計算模型需要訓練的參數數目:
def count_trainable_vars():
    """Print and return the total number of trainable parameters.

    Walks every variable returned by ``tf.trainable_variables()``, multiplies
    the sizes of its dimensions to get that variable's element count, and
    sums those counts over all variables.

    Returns:
        The total parameter count (the same number that is printed).
    """
    total_parameters = 0
    for variable in tf.trainable_variables():
        # A variable's parameter count is the product of its dimension sizes;
        # starting from 1 makes a scalar (empty shape) count as one parameter.
        variable_parameters = 1
        for dim in variable.get_shape():
            # Shape entries may be Dimension objects (exposing ``.value``) or
            # plain ints, depending on the TensorFlow shape behavior in
            # effect; accept both rather than assuming ``.value`` exists.
            variable_parameters *= getattr(dim, "value", dim)
        total_parameters += variable_parameters
    print("Total number of trainable parameters-----------------------------------------: %d" % total_parameters)
    return total_parameters