超分辨率重構之SRCNN整理總結(七)

         到此爲止關於超分重建的理論部分八成已經作結,關於這個tensorflow版本的SRCNN的代碼解讀不知道究竟需要寫到什麼程度纔可以完美收官。大家也都明白,這個東西若寫太細,略顯冗雜;若寫太粗,略顯不夠明析。反正吧,儘可能的寫清楚寫明細。下面是我的GitHub代碼倉庫:https://github.com/XiaoYunChaos,關於這篇的代碼隨後完整作結後我會上傳至倉庫,供大家討論學習,歡迎star哦!

SRCNN(tensorflow)詳解分析

  • 【1】首先,介紹一下項目結構:

              main.py 定義訓練和測試參數,此後由設定的參數進行訓練或測試。

    model.py是模型文件以類的方式實現

    utils.py是用來封裝項目中的函數作爲函數池

    psnr.py是用來做評價函數的,功能就是進行計算評價指標

              checkpoint文件夾是用來保訓練模型,即chekpoint的路徑

              sample文件夾是樣本路徑

              Train文件夾是訓練集路徑

              Test文件夾是測試集路徑,包含Set5與Set14

        在看懂代碼前,一定要明白一件事就是我們每一次訓練實際上是訓練圖片的大小和輸出圖片等的大小等參數的設置。項目除了一般的預處理操作,還需要將圖片分割,最後的訓練完還做實驗的時候還需要將圖片結合起來。

  • 【2】main.py

        功能:定義訓練和測試參數,包括:batchSize、學習率、步長stride、訓練、測試等。

函數運行開啓:

if __name__ == '__main__':
    # main()
    tf.app.run()

隨後tf.app運行,此時涉及相關參數:

flags = tf.app.flags
#第一個是參數名稱,第二個參數是默認值,第三個是參數描述
flags.DEFINE_integer("epoch", 15000, "訓練多少波Number of epoch [15000]")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("batch_size", 128, "batch size")
#一開始將batch size設爲128和64,不僅參數初始loss很大,而且往往一段時間後訓練就發散
#batch中每個樣本產生梯度競爭可能比較激烈,所以導致了收斂過慢
#後來就改回了128
flags.DEFINE_integer("image_size", 33, "圖像使用的尺寸 The size of image to use [33]")
flags.DEFINE_integer("label_size", 21, "label_製作的尺寸 The size of label to produce [21]")
#學習率文中設置爲 前兩層1e-4 第三層1e-5
#SGD+指數學習率10-2作爲初始
flags.DEFINE_float("learning_rate", 1e-4, "學習率 The learning rate of gradient descent algorithm [1e-4]")
flags.DEFINE_integer("c_dim", 1, "圖像維度 Dimension of image color. [1]")
flags.DEFINE_integer("scale", 3, "sample的scale大小 The size of scale factor for preprocessing input image [3]")
#stride訓練採用14,測試採用21
flags.DEFINE_integer("stride", 14, "步長爲14或者21 The size of stride to apply input image [14]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "名字 Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("sample_dir", "sample", "名字 Name of sample directory [sample]")
flags.DEFINE_boolean("is_train", True, "True for training, False for testing [True]")#訓練
#flags.DEFINE_boolean("is_train", False, "True for training, False for testing")#測試
FLAGS = flags.FLAGS
#第一句是賦值,將前面的一系列參數賦值給FLAGS。
#第二句是創建了一個打印的類,這樣就可以調用pp的函數了。
pp = pprint.PrettyPrinter()

此時需要注意這些參數:

  • epoch:迭代次數
  • batch_size:批處理參數
  • image_size:圖像大小
  • label_size:高分辨率圖像大小,即真實標籤的大小
  • learning_rate:學習率
  • c_dim:圖像顏色維度
  • scale:縮放倍數
  • stride:卷積步長
  • checkpoint_dir:模型保存路徑
  • sample_dir:樣本路徑
  • is_train:是否訓練
     
  • 【3】main函數

CPU版本:

def main(_): #CPU版本
  pp.pprint(flags.FLAGS.__flags)
  #路徑檢查,沒有則創建
  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
  #tf的相關參數傳入及srcnn模型訓練或測試
  with tf.Session() as sess:  
    #new出一個類對象,這個對象你可以理解爲這個三層神經網絡
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim, 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)
    #訓練模型
    srcnn.train(FLAGS)

 

GPU版本:

def main(_): #GPU版本:
  pp.pprint(flags.FLAGS.__flags)
  #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  #主函數驗證路徑是否存在,如果不存在就創造一個
  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    
    #sess = tf.Session()
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim,
                  #圖像維度 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)

    srcnn.train(FLAGS)
    print(srcnn.train(FLAGS))

        GPU版本與CPU版本代碼理解無多大區別,就是在項目部署上可能不一樣,GPU的存在有什麼好處呢,說白了就是模型訓練加速器,可以更快更高效的將模型訓練出來,對於GPU的相關筆記隨後再做解釋吧,你只要把CPU代碼理解了,其他的都是錦上添花。

        上述main函數可以說是已經將項目框架跑完了,隨後就是一些細節上的理解和處理了。

  • 【4】model.py

from utils import (
  read_data, 
  input_setup, 
  imsave,
  merge
)

import time
import os
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

try:
  xrange
except:
  xrange = range

class SRCNN(object):

  def __init__(self, 
               sess, 
               image_size=33,
               label_size=21, 
               batch_size=128,
               c_dim=1, 
               checkpoint_dir=None, 
               sample_dir=None):

    self.sess = sess
    self.is_grayscale = (c_dim == 1)
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size

    self.c_dim = c_dim

    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir
    self.build_model()
#搭建網絡
  def build_model(self):   #三層網絡結構
    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
    #第一層CNN:對輸入圖片的特徵提取。(9 x 9 x 64卷積核)
    #第二層CNN:對第一層提取的特徵的非線性映射(1 x 1 x 32卷積核)
    #第三層CNN:對映射後的特徵進行重建,生成高分辨率圖像(5 x 5 x 1卷積核)
    #權重    
    self.weights = {
      #論文中爲提高訓練速度的設置 n1=32 n2=16
      'w1': tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
      'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
      'w3': tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
    }
    self.biases = {
      'b1': tf.Variable(tf.zeros([64]), name='b1'),
      'b2': tf.Variable(tf.zeros([32]), name='b2'),
      'b3': tf.Variable(tf.zeros([1]), name='b3')
    }

    self.pred = self.model()
    # Loss function (MSE)以MSE爲損失函數
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
    #主函數調用(訓練或測試)
    self.saver = tf.train.Saver()
#訓練
  def train(self, config):
    if config.is_train:#判斷是否爲訓練(main傳入)
      input_setup(self.sess, config)
    else:
      nx, ny = input_setup(self.sess, config)
	#訓練爲checkpoint下train.h5
    #測試爲checkpoint下test.h5
    if config.is_train:     
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
    else:
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")
	#訓練數據標籤
    train_data, train_label = read_data(data_dir)
	#讀取.h5文件(由測試和訓練決定)
    # Stochastic gradient descent with the standard backpropagation
    self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)

    tf.global_variables_initializer().run()
    
    counter = 0
    start_time = time.time()

    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")
	#訓練
    if config.is_train:
      print("Training...")

      for ep in xrange(config.epoch):#迭代次數的循環
      	#以batch爲單元
        # Run by batch images
        batch_idxs = len(train_data) // config.batch_size
        for idx in xrange(0, batch_idxs):
          batch_images = train_data[idx*config.batch_size : (idx+1)*config.batch_size]
          batch_labels = train_label[idx*config.batch_size : (idx+1)*config.batch_size]

          counter += 1
          _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels})

          if counter % 10 == 0:#10的倍數的step顯示
            print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
              % ((ep+1), counter, time.time()-start_time, err))

          if counter % 500 == 0:#500的倍數step儲存
            self.save(config.checkpoint_dir, counter)
	#測試
    else:
      print("Testing...")

      result = self.pred.eval({self.images: train_data, self.labels: train_label})

      result = merge(result, [nx, ny])
      result = result.squeeze()#除去size爲1的維度
      #result= exposure.adjust_gamma(result, 1.07)#調暗一些
      image_path = os.path.join(os.getcwd(), config.sample_dir)
      image_path = os.path.join(image_path, "test_image.png")
      imsave(result, image_path)

  def model(self):
  #strides在官方定義中是一個一維具有四個元素的張量,其規定前後必須爲1,所以我們可以改的是中間兩個數,中間兩個數分別代表了水平滑動和垂直滑動步長值。
    conv1 = tf.nn.relu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'])
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, self.weights['w2'], strides=[1,1,1,1], padding='VALID') + self.biases['b2'])
    conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1,1,1,1], padding='VALID') + self.biases['b3']
    return conv3

  def save(self, checkpoint_dir, step):
    model_name = "SRCNN.model"
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)#再一次確定路徑爲 checkpoint->srcnn_21下

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name), #文件名爲SRCNN.model-迭代次數
                    global_step=step)

  def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
#路徑爲checkpoint->srcnn_labelsize(21)
#加載路徑下的模型(.meta文件保存當前圖的結構; 
#.index文件保存當前參數名; .data文件保存當前參數值)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        #saver.restore()函數給出model.-n路徑後會自動尋找參數名-值文件進行加載
    
        return True
    else:
        return False

訓練方式:SGD的效果更好

  • 【5】utils.py
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""

import os
import glob#導入glob庫,作用是類似於系統的文件路徑匹配查詢
import h5py#h5py庫,主要用於讀取或創建datasets或groups
import random
import matplotlib.pyplot as plt

from PIL import Image  # for loading images as YCbCr format
import scipy.misc#該庫主要用於將數組保存成圖像形式
import scipy.ndimage#該庫用於圖像處理
import numpy as np

import tensorflow as tf

try:
  xrange#處理異常中斷
except:
  xrange = range
  
FLAGS = tf.app.flags.FLAGS#命令行參數傳遞

def read_data(path):#讀取.h5文件的data和label數據,轉化np.array格式
  """
  Read h5 format data file
  讀取h5格式數據文件,用於訓練或者測試
  參數:
    路徑: 文件
    data.h5 包含訓練輸入
    label.h5 包含訓練輸出
  Args:
    path: file path of desired file
    data: '.h5' file format that contains train data values
    label: '.h5' file format that contains train label values
  """
  with h5py.File(path, 'r') as hf:#讀取h5格式數據文件(用於訓練或測試)
    data = np.array(hf.get('data'))
    label = np.array(hf.get('label'))
    return data, label

def preprocess(path, scale=3):#定義預處理函數
#(1)讀取灰度圖像;
#(2)modcrop;
#(3)歸一化;
#(4)兩次bicubic interpolation

返回input_ ,label_

make_data(sess,data,label)**
作用:將data(checkpoint下的train.h5 或test.h5)利用h5的create_dataset 寫入
  """
  #對路徑下的image裁剪成scale整數倍,再對image縮小1/scale倍後,放大scale倍以得到低分辨率圖input_,調整尺寸後的image爲高分辨率圖label_
  #image = imread(path, is_grayscale=True)
  #label_ = modcrop(image, scale)
  Preprocess single image file 
    (1) Read original image as YCbCr format (and grayscale as default)
    (2) Normalize
    (3) Apply image file with bicubic interpolation

  Args:
    path: file path of desired file
    input_: image applied bicubic interpolation (low-resolution)
    label_: image with original resolution (high-resolution)
  """
  image = imread(path, is_grayscale=True)
  label_ = modcrop(image, scale)

  # Must be normalized
  image = image / 255.
  label_ = label_ / 255.

  input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False)
  input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False)

  return input_, label_

def prepare_data(sess, dataset):#作用:返回data是訓練集或測試集bmp格式的圖像
#(1)參數說明:dataset是train dataset 或 test dataset
#(2)glob.glob得到所有的訓練集或是測試集圖像
  """
  Args:
    dataset: choose train dataset or test dataset
    
    For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
  """
  if FLAGS.is_train:
    filenames = os.listdir(dataset)
    data_dir = os.path.join(os.getcwd(), dataset)
    data = glob.glob(os.path.join(data_dir, "*.bmp"))
    #(2)glob.glob得到所有的訓練集或是測試集圖像
  else:
  #確定測試數據集合的文件夾爲Set5
    data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
    data = glob.glob(os.path.join(data_dir, "*.bmp"))

  return data

def make_data(sess, data, label):
  """
  Make input data as h5 file format
  Depending on 'is_train' (flag value), savepath would be changed.
  """
  #把數據保存成.h5格式
  if FLAGS.is_train:
    savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
  else:
    savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')

  with h5py.File(savepath, 'w') as hf:
    hf.create_dataset('data', data=data)
    hf.create_dataset('label', data=label)

def imread(path, is_grayscale=True):#目的:讀取指定路徑的圖像
  """
  Read image using its path.
  Default value is gray-scale, and image is read by YCbCr format as the paper said.
  """
  #讀指定路徑的圖像
  if is_grayscale:
    return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
  else:
    return scipy.misc.imread(path, mode='YCbCr').astype(np.float)

def modcrop(image, scale=3):
#把圖像的長和寬都變成scale的倍數
  """
  To scale down and up the original image, first thing to do is to have no remainder while scaling operation.
  
  We need to find modulo of height (and width) and scale factor.
  Then, subtract the modulo from height (and width) of original image size.
  There would be no remainder even after scaling operation.
  """
  if len(image.shape) == 3:
    h, w, _ = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w, :]
  else:
    h, w = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w]
  return image
  #把result變爲和origin一樣的大小

def input_setup(sess, config):#功能:讀取train set or test set ;做sub-images;保存成h5文件
  """
  Read image files and make their sub-images and saved them as a h5 file format.
  """
  #global nx#後加
  #global ny#後加
  #讀圖像集,製作子圖並保存爲h5文件格式
  # 讀取數據路徑
  # Load data path
  if config.is_train:
    data = prepare_data(sess, dataset="Train")
  else:
    data = prepare_data(sess, dataset="Test")

  sub_input_sequence = []
  sub_label_sequence = []
  padding = abs(config.image_size - config.label_size) / 2 # 6
#padding=0;#修改padding值,測試效果
  #訓練
  if config.is_train:
    for i in xrange(len(data)):#一幅圖作爲一個data
      input_, label_ = preprocess(data[i], config.scale)
#得到data[]的LR和HR圖input_和label_
      if len(input_.shape) == 3:
      if len(input_.shape) == 3:
        h, w, _ = input_.shape
      else:
        h, w = input_.shape
#把input_和label_分割成若干自圖sub_input和sub_label
      for x in range(0, h-config.image_size+1, config.stride):
        for y in range(0, w-config.image_size+1, config.stride):
          sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
          sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]

          # Make channel value
          sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
          #按image size大小重排 因此 imgae_size應爲33 而label_size應爲21
          sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

          sub_input_sequence.append(sub_input)
          #在sub_input_sequence末尾加sub_input中元素 但考慮爲空
          sub_label_sequence.append(sub_label)
          sub_label_sequence.append(sub_label)

  else:
  #測試
    input_, label_ = preprocess(data[2], config.scale)#測試圖片
    if len(input_.shape) == 3:
      h, w, _ = input_.shape
    else:
      h, w = input_.shape

    # Numbers of sub-images in height and width of image are needed to compute merge operation.
    nx = ny = 0 
    #自圖需要進行合併操作
    for x in range(0, h-config.image_size+1, config.stride):#x從0到h-33+1 步長stride(21)
      nx += 1; ny = 0
      for y in range(0, w-config.image_size+1, config.stride):#y從0到w-33+1 步長stride(21)
        ny += 1
        sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
        sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]
        
        sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
        sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

        sub_input_sequence.append(sub_input)
        sub_label_sequence.append(sub_label)

  """
  len(sub_input_sequence) : the number of sub_input (33 x 33 x ch) in one image
  (sub_input_sequence[0]).shape : (33, 33, 1)
  """
  # Make list to numpy array. With this transform
  # 上面的部分和訓練是一樣的
  arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1]
  arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1]

  make_data(sess, arrdata, arrlabel)

  if not config.is_train:#存成h5格式
    return nx, ny
    
def imsave(image, path):
  return scipy.misc.imsave(path, image)

def merge(images, size):
  h, w = images.shape[1], images.shape[2]#覺得下標應該是0,1
  img = np.zeros((h*size[0], w*size[1], 1))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    img[j*h:j*h+h, i*w:i*w+w, :] = image

  return img

        utils.py說明了就是一個函數池,注意下面函數就可以:

  • prepare_data(sess,dataset):返回data,data是訓練集或測試集中bmp格式的圖像。
  • input_setup(sess,config):讀取train set or test set ;做sub-images;保存成h5文件。
  • read_data(path):讀取.h5文件的data和label數據,轉化np.array格式。
  • preprocess(path,scale=3):(1)讀取灰度圖像;(2)modcrop;(3)歸一化;(4)兩次bicubic interpolation,返回input_ ,label_。即對路徑下的image裁剪成scale整數倍,再對image縮小1/scale倍後,放大scale倍以得到低分辨率圖input_,調整尺寸後的image爲高分辨率圖label_。
  • make_data(sess,data,label):將data保存爲h5格式的數據,保存到指定路徑,是通過create_dataset函數寫入的。
  • imread(path,is_grayscale=True):讀取指定路徑的圖像。
  • modcrop(image, scale=3) #把圖像的長和寬都變成scale的倍數。
  • modcrop_small(image) #把result變爲和origin一樣的大小(需要自己寫或參考其他)。
  • imsave(image,path):將scipy.misc.imsave封裝到imsave供自己使用。
  • merge(image,size):合併分割後的圖片。

到這裏差不多,代碼解讀基本完成。相信你看完之後也可以自己完成運行測試啦!

  • 【6】最後,再附一個項目運行基本流程:
  • 準備數據集(訓練集、測試集);
  • 訓練模型
  • 利用模型測試數據
  • 模型評價

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章