最近在學習SRCNN,閱讀代碼做好筆記
代碼下載鏈接https://github.com/tegg89/SRCNN-Tensorflow
下面開始
from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os
flags = tf.app.flags
flags.DEFINE_integer("epoch", 2000,"訓練多少波")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
#一開始將batch size設爲128和64,不僅參數初始loss很大,而且往往一段時間後訓練就發散
#batch中每個樣本產生梯度競爭可能比較激烈,所以導致了收斂過慢
#後來改回了128
flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("image_size", 33, "圖像使用的尺寸")
flags.DEFINE_integer("label_size", 21, "label_製作的尺寸")
#學習率文中設置爲 前兩層1e-4 第三層1e-5
#SGD+指數學習率10-2作爲初始
flags.DEFINE_float("learning_rate", 1e-2, "學習率")
flags.DEFINE_integer("c_dim", 1, "圖像維度")
flags.DEFINE_integer("scale", 3, "sample的scale大小")
#stride訓練採用14,測試採用21
flags.DEFINE_integer("stride", 21 , "步長爲14或者21")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "checkpoint directory名字")
flags.DEFINE_string("sample_dir", "sample", "sample directory名字")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing")#測試
#flags.DEFINE_boolean("is_train", True, "True for training, False for testing")#訓練
FLAGS = flags.FLAGS
pp = pprint.PrettyPrinter()
def main(_):
pp.pprint(flags.FLAGS.__flags)
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
with tf.Session() as sess:
srcnn = SRCNN(sess,
image_size=FLAGS.image_size,
label_size=FLAGS.label_size,
batch_size=FLAGS.batch_size,
c_dim=FLAGS.c_dim,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
srcnn.train(FLAGS)
if __name__ == '__main__':
tf.app.run()