SRCNN Code with Annotations

Analysis of the SRCNN training workflow: https://download.csdn.net/download/baixue0729/12451414

main.py

from model import SRCNN
from utils import input_setup

import numpy as np
import tensorflow as tf

import pprint
import os

'''
Define the training and testing parameters
(including the batch size used with SGD, the learning rate, the stride, and whether to train or test);
training or testing then proceeds with these settings.
'''

flags = tf.app.flags  # command-line flags, so parameters can be set when running from the command line
# first argument: flag name; second: default value; third: description
flags.DEFINE_integer("epoch", 15000, "Number of epoch [15000]")
# parameters are updated once per batch
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("image_size", 33, "The size of image to use [33]")
# Because the convolutions use no padding, three layers shrink the 33x33 input to 33-(9-1)-(1-1)-(5-1)=21, so the central 21x21 patch of the image is used as the label.
flags.DEFINE_integer("label_size", 21, "The size of label to produce [21]")
# The paper sets the learning rate to 1e-4 for the first two layers and 1e-5 for the third (see the per-layer sketch after model.py); was SGD with an exponential decay starting from 1e-2 intended here?
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of gradient descent algorithm [1e-4]")
# number of image color channels
flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
# scale factor for bicubic interpolation
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
# stride for cropping sub-images: 14 for training (adjustable; smaller means more training patches); 21 for testing so the 21x21 output patches tile exactly, since image_size 33 - (image_size 33 - label_size 21) = 21
flags.DEFINE_integer("stride", 14, "The size of stride to apply input image [14]")
# name of the checkpoint directory where model checkpoints are saved
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
# name of the sample directory
flags.DEFINE_string("sample_dir", "sample", "Name of sample directory [sample]")
flags.DEFINE_boolean("is_train", True, "True for training, False for testing [True]")
FLAGS = flags.FLAGS
# Testing: stride=21, is_train=False. Training: stride=14, is_train=True. (See the usage example after this file.)

# create a PrettyPrinter instance so its printing methods can be called via pp
pp = pprint.PrettyPrinter()

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)  # create the "checkpoint" directory under the current path
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)  # create the "sample" directory under the current path

  with tf.Session() as sess:
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim, 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)   # create an SRCNN object; __init__ is called automatically

    srcnn.train(FLAGS)
    
if __name__ == '__main__':   # not executed when this module is imported by another script
  tf.app.run()
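
Since every setting is a tf.app.flags flag, training and testing differ only in the flags passed on the command line. A minimal usage sketch (assuming the Train/ and Test/ image folders read by utils.prepare_data are in place):

python main.py --is_train=True --stride=14    # training (the default flags)
python main.py --is_train=False --stride=21   # testing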

model.py

from utils import (
  read_data, 
  input_setup, 
  imsave,
  merge
)

import time
import os
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

try:
  xrange
except:
  xrange = range

class SRCNN(object):
  # object initialization
  def __init__(self, 
               sess, 
               image_size=33,
               label_size=21, 
               batch_size=128,
               c_dim=1, 
               checkpoint_dir=None, 
               sample_dir=None):

    self.sess = sess
    self.is_grayscale = (c_dim == 1)
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size

    self.c_dim = c_dim  # number of image channels (1, the Y channel)

    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir
    self.build_model()

  '''
  Three convolutions, with kernel sizes 9, 1, 5 and output channels 64, 32, 1:
    Layer 1: feature extraction from the input patch (9 x 9 kernels, 64 filters)
    Layer 2: non-linear mapping of the extracted features (1 x 1 kernels, 32 filters)
    Layer 3: reconstruction of the mapped features into the high-resolution patch (5 x 5 kernels, 1 filter)
  With VALID padding the 33x33 input shrinks to 25x25 after layer 1, stays 25x25 after layer 2,
  and becomes 21x21 after layer 3, matching label_size.
  '''
  def build_model(self):
    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')

    self.weights = {
      # Kernel: f1 x f1 x c x n1. c is the number of input channels (the paper uses the Y channel of YCbCr, so c=1); f1=9; n1 is this layer's output depth, 64
      'w1': tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
      'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
      'w3': tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
    }
    self.biases = {
      'b1': tf.Variable(tf.zeros([64]), name='b1'),
      'b2': tf.Variable(tf.zeros([32]), name='b2'),
      'b3': tf.Variable(tf.zeros([1]), name='b3')
    }

    self.pred = self.model()  # build the network; model() returns the result of the three convolutions

    # Loss function: mean squared error (MSE)
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))

    # Create a Saver for saving/restoring checkpoints (used by both training and testing)
    self.saver = tf.train.Saver()

  def train(self, config):
    # data preparation
    if config.is_train:
      input_setup(self.sess, config)  # read the image files, cut them into sub-images, and save them in h5 format
    else:
      nx, ny = input_setup(self.sess, config)

    # training reads checkpoint/train.h5; testing reads checkpoint/test.h5
    if config.is_train:     
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
    else:
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")

    train_data, train_label = read_data(data_dir)

    # Build the optimizer, then initialize all variables, the counter, and the timer.
    # SGD (stochastic gradient descent with standard backpropagation) is used; Adam is an alternative, but SGD is said to work better here (unverified).
    self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)

    tf.initialize_all_variables().run()
    
    counter = 0  # step counter
    start_time = time.time()  # timer

    # load previously trained parameters, if any
    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")

    # train the model, saving a checkpoint every 500 steps
    if config.is_train:
      print("Training...")

      for ep in xrange(config.epoch):  # for each epoch
        # Run by batch images
        batch_idxs = len(train_data) // config.batch_size  # number of batches per epoch
        for idx in xrange(0, batch_idxs):  # update once per batch
          batch_images = train_data[idx*config.batch_size : (idx+1)*config.batch_size]
          batch_labels = train_label[idx*config.batch_size : (idx+1)*config.batch_size]

          counter += 1
          _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels})

          if counter % 10 == 0:  # log every 10 updates
            print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
              % ((ep+1), counter, time.time()-start_time, err))

          if counter % 500 == 0:  # save a checkpoint every 500 updates
            self.save(config.checkpoint_dir, counter)  # resolve the path and save the session

    else:
      print("Testing...")

      result = self.pred.eval({self.images: train_data, self.labels: train_label})

      result = merge(result, [nx, ny])
      result = result.squeeze()  # squeeze removes dimensions of size 1
      image_path = os.path.join(os.getcwd(), config.sample_dir)
      image_path = os.path.join(image_path, "test_image.png")
      imsave(result, image_path)

  def model(self):
    '''
    Pass the image through three convolutions, each with stride 1.
    Convolution plus bias; the first two layers have a ReLU activation, the last layer has none.
    :return: the output of the final convolution

    In the official definition, strides is a 1-D tensor with four elements
    whose first and last entries must be 1, so only the middle two can be changed;
    they are the horizontal and vertical stride, respectively.
    '''
    conv1 = tf.nn.relu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'])
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, self.weights['w2'], strides=[1,1,1,1], padding='VALID') + self.biases['b2'])
    conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1,1,1,1], padding='VALID') + self.biases['b3']
    return conv3

  # resolve the path and save the session
  def save(self, checkpoint_dir, step):
    model_name = "SRCNN.model"
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)  # save path: checkpoint/srcnn_<label_size>

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    '''
    Write a checkpoint file containing all trainable variables of the current model to this directory;
    saver.restore() can later reload the parameters to continue training or to run tests.
    '''
    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),
                    global_step=step)

  # restore the session; return True on success, False otherwise
  def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  # get the checkpoint state from the directory
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)  # file name of the checkpoint
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))  # reload the model parameters to continue training or to test
        return True
    else:
        return False
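
The comments on the learning_rate flag note that the paper uses 1e-4 for the first two layers and 1e-5 for the third, while train() above applies a single rate to every variable. Below is a minimal sketch (not part of the original code) of how per-layer rates could be wired up with two GradientDescentOptimizer instances, assuming the variable names w1/w2/w3 and b1/b2/b3 defined in build_model():

import tensorflow as tf

def build_train_op(loss, lr_front=1e-4, lr_last=1e-5):
  # Split the trainable variables into the last layer (w3, b3) and the rest.
  all_vars = tf.trainable_variables()
  front_vars = [v for v in all_vars if v.op.name not in ('w3', 'b3')]
  last_vars = [v for v in all_vars if v.op.name in ('w3', 'b3')]
  # One gradient computation, two optimizers with different learning rates.
  grads = tf.gradients(loss, front_vars + last_vars)
  opt_front = tf.train.GradientDescentOptimizer(lr_front)
  opt_last = tf.train.GradientDescentOptimizer(lr_last)
  train_front = opt_front.apply_gradients(list(zip(grads[:len(front_vars)], front_vars)))
  train_last = opt_last.apply_gradients(list(zip(grads[len(front_vars):], last_vars)))
  return tf.group(train_front, train_last)

self.train_op = build_train_op(self.loss) could then replace the single-optimizer line in train().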

utils.py

"""
Scipy version > 0.18 is needed, due to the 'mode' option of the scipy.misc.imread function.
"""

import os
import glob  # glob: file path pattern matching, similar to the shell
import h5py  # h5py: read or create HDF5 datasets and groups
import random
import matplotlib.pyplot as plt

from PIL import Image  # for loading images as YCbCr format
import scipy.misc  # used here mainly to save arrays as images
import scipy.ndimage  # used here for image interpolation
import numpy as np

import tensorflow as tf

try:
  xrange  # raises NameError on Python 3, handled below
except:
  xrange = range
  
FLAGS = tf.app.flags.FLAGS  # command-line flags

def read_data(path):
  """
  Read h5 format data file
  Reads the data and label arrays from the h5 file and converts them to np.array.
  Args:
    path: file path of desired file
    data: '.h5' file format that contains train data values
    label: '.h5' file format that contains train label values
  """
  with h5py.File(path, 'r') as hf:  # read the h5 data file (for training or testing)
    data = np.array(hf.get('data'))
    label = np.array(hf.get('label'))
    return data, label

def preprocess(path, scale=3):
  """
  Preprocess a single image file; input_ is the low-resolution input and label_ the high-resolution target.
    (1) Read the original image in YCbCr format (grayscale by default)
    (2) Normalize (convert uint8 values in 0-255 to the range 0-1)
    (3) Apply bicubic interpolation
  Args:
    path: file path of desired file
    input_: image applied bicubic interpolation (low-resolution)
    label_: image with original resolution (high-resolution)
  """
  image = imread(path, is_grayscale=True)  # n-dimensional array, read as a grayscale image (assumed single-channel)
  label_ = modcrop(image, scale)  # crop the width and height to multiples of scale, e.g. scale=3: (197, 176) -> (195, 174)

  # Must be normalized
  image = image / 255.    # has no effect: image is not used below
  label_ = label_ / 255.

  '''
  scipy.ndimage.interpolation.zoom(
      input,           # input array
      zoom,            # zoom factor along each axis
      output=None,     # array to place the output in, or the dtype of the returned array
      order=3,         # spline interpolation order (0-5): 3 = cubic, 1 = bilinear, 0 = nearest
      mode='constant', # how points outside the input boundaries are filled ('constant', 'nearest', 'reflect' or 'wrap')
      cval=0.0,        # value used for points outside the boundaries when mode='constant'
      prefilter=True)  # whether to prefilter the input with spline_filter before interpolation; if False, the input is assumed to be already filtered
  '''
  input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False)  # bicubic downscale by a factor of scale: (195, 174) -> (65, 58)
  input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False)  # bicubic upscale by a factor of scale: (65, 58) -> (195, 174)

  return input_, label_   # low-resolution, high-resolution (195, 174)

def prepare_data(sess, dataset):
  """
  Return the list of image paths.
  Args:
    dataset: choose train dataset or test dataset
    
    For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
  """
  if FLAGS.is_train:
    filenames = os.listdir(dataset)
    data_dir = os.path.join(os.getcwd(), dataset)  # os.getcwd(): current working directory
    data = glob.glob(os.path.join(data_dir, "*.png"))  # or "*.bmp"
  else:
    data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set")   # os.sep: the platform's path separator, so \ vs / need not be considered
    data = glob.glob(os.path.join(data_dir, "*.png"))

  return data

def make_data(sess, data, label):
  """
  Make the input data an h5 file.
  Writes data and label to checkpoint/train.h5 or checkpoint/test.h5 via h5py's create_dataset.
  Depending on 'is_train' (flag value), the save path changes.
  """
  if FLAGS.is_train:
    savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5') # h5 file save path
  else:
    savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')

  with h5py.File(savepath, 'w') as hf:
    hf.create_dataset('data', data=data)   # create an HDF5 dataset named "data"
    hf.create_dataset('label', data=label) # create an HDF5 dataset named "label"

def imread(path, is_grayscale=True):
  """
  Read an image given its path.
  Default is grayscale, and the image is read in YCbCr format as described in the paper.
  """
  if is_grayscale:
    return scipy.misc.imread(path, flatten=True, mode='L').astype(np.float)  # flatten converts color images to grayscale
  else:
    return scipy.misc.imread(path, mode='L').astype(np.float)    #mode='YCbCr'

def modcrop(image, scale=3):
  """
  Crop the image so that its width and height are divisible by scale.
  To scale the original image down and up, there must be no remainder in the scaling operation.

  We take the height (and width) modulo the scale factor,
  then subtract that remainder from the original height (and width),
  so there is no remainder even after the scaling operation.
  """
  if len(image.shape) == 3:
    h, w, _ = image.shape       #eg: image.shape=(197, 176, 3), scale=3
    h = h - np.mod(h, scale)    #h=195
    w = w - np.mod(w, scale)    #w=174
    image = image[0:h, 0:w, :]
  else:
    h, w = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w]
  return image

def input_setup(sess, config):
  """
  Read the image files, cut the input and label images into small sub-images, and save them in h5 format.
  """
  # Load data path
  if config.is_train:
    data = prepare_data(sess, dataset="Train")  #圖的位置列表
  else:
    data = prepare_data(sess, dataset="Test")

  sub_input_sequence = []
  sub_label_sequence = []
  padding = abs(config.image_size - config.label_size) / 2  #(33-21)/2=6

  if config.is_train:
    for i in xrange(len(data)):  # for each image
      input_, label_ = preprocess(data[i], config.scale)  # low-resolution, high-resolution

      if len(input_.shape) == 3:
        h, w, _ = input_.shape  #195,174,1
      else:
        h, w = input_.shape

      for x in range(0, h-config.image_size+1, config.stride):   # (0~163); stride is 21 at test time so the patches can be stitched back; the +1 makes the last valid start position h-image_size reachable
        for y in range(0, w-config.image_size+1, config.stride): # (0~142)
          # (x ~ x+33, y ~ y+33): crop a 33x33 sub-image
          sub_input = input_[x:x+config.image_size, y:y+config.image_size]
          # (x+6 ~ x+6+21, y+6 ~ y+6+21): crop a 21x21 sub-label, discarding the border
          sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size]

          # Make channel value: reshape each patch to image_size x image_size x 1, i.e. add an explicit single-channel dimension
          sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  # image_size=33, i.e. 33*33*1
          sub_label = sub_label.reshape([config.label_size, config.label_size, 1])  # label_size=21, i.e. 21*21*1

          sub_input_sequence.append(sub_input)  # append the sub-image to the input list
          sub_label_sequence.append(sub_label)  # append the sub-label to the label list

  else:  # testing has one less loop: only the first image is processed
    input_, label_ = preprocess(data[0], config.scale)

    if len(input_.shape) == 3:
      h, w, _ = input_.shape
    else:
      h, w = input_.shape

    # Numbers of sub-images along the height and width are needed to compute the merge operation.
    nx = ny = 0 
    for x in range(0, h-config.image_size+1, config.stride):
      nx += 1; ny = 0
      for y in range(0, w-config.image_size+1, config.stride):
        ny += 1
        sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
        sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]
        
        sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
        sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

        sub_input_sequence.append(sub_input)
        sub_label_sequence.append(sub_label)

  """
  len(sub_input_sequence) : the number of sub_input (33 x 33 x ch) in one image
  (sub_input_sequence[0]).shape : (33, 33, 1)
  """
  # Convert the lists to numpy arrays so they can be stored in h5 format.
  arrdata = np.asarray(sub_input_sequence)  # input set, shape [?, 33, 33, 1], where ? is the total number of sub-images
  arrlabel = np.asarray(sub_label_sequence) # label set, shape [?, 21, 21, 1]

  make_data(sess, arrdata, arrlabel)  # write the h5 file; this produces the training (or test) data

  # nx and ny are the number of sub-images along the height and the width, respectively
  if not config.is_train:
    return nx, ny
    
def imsave(image, path):  # save the image
  return scipy.misc.imsave(path, image)

def merge(images, size):
  """
  Stitch the sub-images in a batch back into one image.
  images: a batch of sub-images
  size: [number of sub-images along the height, number of sub-images along the width]
  %: modulo (remainder); //: floor division; /: float division
  """
  h, w = images.shape[1], images.shape[2]
  img = np.zeros((h*size[0], w*size[1], 1))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    img[j*h:j*h+h, i*w:i*w+w, :] = image

  return img
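
As a quick sanity check of merge() (a standalone illustration, not part of the original utils.py; it assumes merge is in scope, e.g. via from utils import merge): six constant-valued 21x21x1 patches arranged as nx=2 rows and ny=3 columns should stitch into a (42, 63, 1) mosaic, with patch idx landing at row idx // ny and column idx % ny.

import numpy as np

# six constant-valued 21x21x1 patches, laid out as 2 rows (nx) by 3 columns (ny)
patches = np.stack([np.full((21, 21, 1), i, dtype=np.float32) for i in range(6)])
mosaic = merge(patches, [2, 3])
print(mosaic.shape)                       # (42, 63, 1)
print(mosaic[0, 0, 0], mosaic[0, 30, 0])  # 0.0 (patch 0), 1.0 (patch 1 to its right)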