TF flags的簡介

1、TF flags的簡介

1、flags可以幫助我們通過命令行來動態的更改代碼中的參數。Tensorflow 使用flags定義命令行參數的方法。ML的模型中有大量需要tuning的超參數,所以此方法,迎合了需要一種靈活的方式對代碼某些參數進行調整的需求
(1)、比如,在這個py文件中,首先定義了一些參數,然後將參數統一保存到變量FLAGS中,相當於賦值,後邊調用這些參數的時候直接使用FLAGS參數即可
(2)、基本參數類型有三種flags.DEFINE_integer、flags.DEFINE_float、flags.DEFINE_boolean。
(3)、第一個是參數名稱,第二個參數是默認值,第三個是參數描述

2、使用過程

#第一步,調用flags = tf.app.flags,進行定義參數名稱,並可給定初值、參數說明
#第二步,flags參數直接賦值
#第三步,運行tf.app.run()

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('name', 'default', 'name of the model')
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')
tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
tf.flags.DEFINE_float('train_keep_prob', 0.5, 'dropout rate during training')
tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file')
tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train')
tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps')
tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps')
tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number')

示例如下:

import tensorflow as tf
#取上述代碼中一部分進行實驗
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')

#通過print()確定下面內容的功能
FLAGS = tf.flags.FLAGS #FLAGS保存命令行參數的數據
FLAGS._parse_flags() #將其解析成字典存儲到FLAGS.__flags中
print(FLAGS.__flags)

print(FLAGS.num_seqs)

print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

遇到問題可以參考:相關解決辦法

微信號
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章