SRCNN 代碼解讀【main.py】

最近在學習SRCNN,閱讀代碼做好筆記
代碼下載鏈接https://github.com/tegg89/SRCNN-Tensorflow
下面開始

from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os
flags = tf.app.flags
flags.DEFINE_integer("epoch", 2000,"訓練多少波")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
#一開始將batch size設爲128和64,不僅參數初始loss很大,而且往往一段時間後訓練就發散
#batch中每個樣本產生梯度競爭可能比較激烈,所以導致了收斂過慢
#後來改回了128
flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("image_size", 33, "圖像使用的尺寸")
flags.DEFINE_integer("label_size", 21, "label_製作的尺寸")
#學習率文中設置爲 前兩層1e-4 第三層1e-5
#SGD+指數學習率10-2作爲初始
flags.DEFINE_float("learning_rate", 1e-2, "學習率")
flags.DEFINE_integer("c_dim", 1, "圖像維度")
flags.DEFINE_integer("scale", 3, "sample的scale大小")
#stride訓練採用14,測試採用21
flags.DEFINE_integer("stride", 21 , "步長爲14或者21")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "checkpoint directory名字")
flags.DEFINE_string("sample_dir", "sample", "sample directory名字")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing")#測試
#flags.DEFINE_boolean("is_train", True, "True for training, False for testing")#訓練
FLAGS = flags.FLAGS
pp = pprint.PrettyPrinter()
def main(_):
  pp.pprint(flags.FLAGS.__flags)
  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
  with tf.Session() as sess:
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim, 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)
    srcnn.train(FLAGS)
if __name__ == '__main__':
  tf.app.run()
發佈了41 篇原創文章 · 獲贊 9 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章