SRCNN Code with Annotations

Analysis of the SRCNN training pipeline: https://download.csdn.net/download/baixue0729/12451414

main.py

from model import SRCNN
from utils import input_setup

import numpy as np
import tensorflow as tf

import pprint
import os

'''
Define the training and testing parameters
(batch size when using SGD, learning rate, stride, and train/test mode);
training or testing then proceeds with the configured values.
'''

flags = tf.app.flags  # tf.app.flags lets these parameters be set on the command line
# each flag takes a name, a default value, and a description
flags.DEFINE_integer("epoch", 15000, "Number of epoch [15000]")
# parameters are updated once per batch
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("image_size", 33, "The size of image to use [33]")
# no padding is applied in the convolutions, so three layers shrink the feature map to 21x21; the central 21*21 crop of the image is therefore used as the label
flags.DEFINE_integer("label_size", 21, "The size of label to produce [21]")
# The paper uses SGD with a learning rate of 1e-4 for the first two layers and 1e-5 for the third; this implementation uses a single 1e-4 for all layers.
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of gradient descent algorithm [1e-4]")
# number of image color channels
flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
# scale factor for the bicubic interpolation
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
# Sub-image stride: 14 for training (tunable; a smaller stride yields more training patches); 21 for testing, because image_size(33) - (image_size(33) - label_size(21)) = 21, so the output patches tile without overlap
flags.DEFINE_integer("stride", 14, "The size of stride to apply input image [14]")
# name of the checkpoint directory, where model state is saved
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
# name of the sample (output image) directory
flags.DEFINE_string("sample_dir", "sample", "Name of sample directory [sample]")
flags.DEFINE_boolean("is_train", True, "True for training, False for testing [True]")
FLAGS = flags.FLAGS
# Testing: stride=21, is_train=False. Training: stride=14, is_train=True.

# create a pretty-printer instance so its methods can be called via pp
pp = pprint.PrettyPrinter()

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)  # create the "checkpoint" folder under the current directory
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)  # create the "sample" folder under the current directory

  with tf.Session() as sess:
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim, 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)   # construct an SRCNN object; __init__ builds the model graph

    srcnn.train(FLAGS)
    
if __name__ == '__main__':   # not executed when this module is imported by another script
  tf.app.run()
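
Since the parameters are all defined through tf.app.flags, any of them can be overridden on the command line. A minimal usage sketch (assuming the script above is saved as main.py):

python main.py                               # train with the defaults (is_train=True, stride=14)
python main.py --is_train=False --stride=21  # test: non-overlapping stride so the output patches can be stitched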

model.py

from utils import (
  read_data, 
  input_setup, 
  imsave,
  merge
)

import time
import os
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

try:
  xrange
except:
  xrange = range

class SRCNN(object):
  # object initialization
  def __init__(self, 
               sess, 
               image_size=33,
               label_size=21, 
               batch_size=128,
               c_dim=1, 
               checkpoint_dir=None, 
               sample_dir=None):

    self.sess = sess
    self.is_grayscale = (c_dim == 1)
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size

    self.c_dim = c_dim  # number of image channels (1 here: only the Y channel is used)

    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir
    self.build_model()

  '''
  Three convolutions, with kernel sizes 9, 1, 5 and output channels 64, 32, 1:
    layer 1: feature extraction from the input patch (9 x 9 conv, 64 filters)
    layer 2: non-linear mapping of the extracted features (1 x 1 conv, 32 filters)
    layer 3: reconstruction of the high-resolution patch (5 x 5 conv, 1 filter)
  '''
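  # With padding='VALID' and stride 1, each layer shrinks the feature map by
  # (kernel_size - 1): 33 - 8 = 25 after conv1, 25 - 0 = 25 after conv2, and
  # 25 - 4 = 21 after conv3, which is why label_size defaults to 21.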
  def build_model(self):
    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')

    self.weights = {
      # kernel: f1*f1*c*n1, where c is the number of input channels (the paper uses the Y channel of YCbCr, so c=1), f1=9, and n1=64 is the output depth
      'w1': tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
      'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
      'w3': tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
    }
    self.biases = {
      'b1': tf.Variable(tf.zeros([64]), name='b1'),
      'b2': tf.Variable(tf.zeros([32]), name='b2'),
      'b3': tf.Variable(tf.zeros([1]), name='b3')
    }

    self.pred = self.model()  # model() builds the graph and returns the result of the three convolutions

    # Loss function: mean squared error (MSE)
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
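    # Since PSNR = 10*log10(MAX^2 / MSE), minimizing the MSE directly
    # maximizes PSNR, the metric the SRCNN paper reports.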

    # create a Saver variable for checkpointing (main() then drives training or testing)
    self.saver = tf.train.Saver()

  def train(self, config):
    # data preparation
    if config.is_train:
      input_setup(self.sess, config)  # read the image files, cut them into sub-images, and save them in h5 format
    else:
      nx, ny = input_setup(self.sess, config)

    # training reads train.h5 under checkpoint; testing reads test.h5 under checkpoint
    if config.is_train:     
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
    else:
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")

    train_data, train_label = read_data(data_dir)

    # build the optimizer; variables, a counter, and a timer are initialized below
    # SGD (stochastic gradient descent with standard backpropagation), as in the paper; Adam is an alternative, though SGD reportedly gives better results here (to be verified)
    self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)

    tf.initialize_all_variables().run()  # deprecated in later TF 1.x; tf.global_variables_initializer() is the modern equivalent
    
    counter = 0  # step counter
    start_time = time.time()  # timer

    # load previously trained parameters, if a checkpoint exists
    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")

    # train the model, saving a checkpoint every 500 updates
    if config.is_train:
      print("Training...")

      for ep in xrange(config.epoch):  # for each epoch
        # Run by batch images
        batch_idxs = len(train_data) // config.batch_size  # number of batches per epoch
        for idx in xrange(0, batch_idxs):  # one parameter update per batch
          batch_images = train_data[idx*config.batch_size : (idx+1)*config.batch_size]
          batch_labels = train_label[idx*config.batch_size : (idx+1)*config.batch_size]

          counter += 1
          _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels})

          if counter % 10 == 0:  # log every 10 updates
            print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
              % ((ep+1), counter, time.time()-start_time, err))

          if counter % 500 == 0:  # save every 500 updates
            self.save(config.checkpoint_dir, counter)  # resolve the path and save the session

    else:
      print("Testing...")

      result = self.pred.eval({self.images: train_data, self.labels: train_label})

      result = merge(result, [nx, ny])
      result = result.squeeze()  # squeeze removes the size-1 dimensions
      image_path = os.path.join(os.getcwd(), config.sample_dir)
      image_path = os.path.join(image_path, "test_image.png")
      imsave(result, image_path)

  def model(self):
    '''
    Pass the image through three convolutions, all with stride 1.
    Each layer is convolution plus bias; the first two layers have a ReLU
    activation, the last layer has none.
    :return: the result of the final convolution

    In the official definition, strides is a 1-D tensor with four elements
    whose first and last entries must be 1, so only the two middle values
    can be changed: the horizontal and vertical stride.
    '''
    conv1 = tf.nn.relu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'])
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, self.weights['w2'], strides=[1,1,1,1], padding='VALID') + self.biases['b2'])
    conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1,1,1,1], padding='VALID') + self.biases['b3']
    return conv3

  # resolve the checkpoint path and save the session
  def save(self, checkpoint_dir, step):
    model_name = "SRCNN.model"
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)  # saved under checkpoint/srcnn_<label_size>

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    '''
    Write a checkpoint file containing all trainable variables of the current
    model into the folder; saver.restore() can later reload the parameters to
    continue training or to run on test data.
    '''
    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),
                    global_step=step)

  # load the session; returns True on success, False otherwise
  def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  # fetch the checkpoint state from the folder
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)  # file name of the checkpoint
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))  # reload the parameters to continue training or to test
        return True
    else:
        return False

utils.py

"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
由于scipy.misc.imread函数的'mode'选项,需要Scipy版本> 0.18
"""

import os
import glob  # glob: shell-style file path pattern matching
import h5py  # h5py: read and create HDF5 datasets and groups
import random
import matplotlib.pyplot as plt

from PIL import Image  # for loading images as YCbCr format
import scipy.misc  # used here to save arrays as images
import scipy.ndimage  # image processing (interpolation)
import numpy as np

import tensorflow as tf

try:
  xrange  # Python 3 compatibility: xrange no longer exists, fall back to range
except:
  xrange = range
  
FLAGS = tf.app.flags.FLAGS  # the command-line flags defined in main.py

def read_data(path):
  """
  Read h5 format data file
  读取h5文件中的data和label数据,需要将其转换成np.array格式
  Args:
    path: file path of desired file
    data: '.h5' file format that contains train data values
    label: '.h5' file format that contains train label values
  """
  with h5py.File(path, 'r') as hf:  # open the h5 data file (used for training or testing)
    data = np.array(hf.get('data'))
    label = np.array(hf.get('label'))
    return data, label

def preprocess(path, scale=3):
  """
  处理图片,input_,label_分别是输入和输出的图片,对应低分辨率和高分辨率
  Preprocess single image file 
    (1) Read original image as YCbCr format (and grayscale as default)
    (2) Normalize
    (3) Apply image file with bicubic interpolation
  预处理单个图像文件
     (1)以YCbCr格式读取原始图像(默认为灰度)
     (2)归一化  (将0-255的uint8型数据转换到0-1之间)
     (3)应用三次插值的图像文件
  Args:
    path: file path of desired file
    input_: image applied bicubic interpolation (low-resolution)
    label_: image with original resolution (high-resolution)
  """
  image = imread(path, is_grayscale=True)  # n-D array, grayscale; with flatten=True the result is 2-D (H, W)
  label_ = modcrop(image, scale)  # crop the image so its sides divide evenly by scale, e.g. scale=3: (197,176)->(195,174)

  # Must be normalized
  image = image / 255.    # image is not used again after this point
  label_ = label_ / 255.

  '''
  scipy.ndimage.interpolation.zoom(
      input,           # the input array
      zoom,            # zoom factor along each axis
      output=None,     # array to place the output in, or the dtype of the returned array
      order=3,         # order of the spline interpolation (0-5): 3 = cubic, 1 = bilinear, 0 = nearest
      mode='constant', # how points outside the input boundaries are filled ('constant', 'nearest', 'reflect' or 'wrap')
      cval=0.0,        # value used for points outside the boundaries when mode='constant'
      prefilter=True)  # whether to pre-filter the input with spline_filter; if False, the input is assumed to be already filtered
  '''
  input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False)  # bicubic downscale by scale: (195,174)->(65,58)
  input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False)  # bicubic upscale by scale: (65,58)->(195,174)

  return input_, label_   # low-resolution input, high-resolution label, both (195,174)
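
# A minimal usage sketch of preprocess (hypothetical path; shapes assume a
# 197x176 source image and scale=3):
#   input_, label_ = preprocess('Train/t1.bmp', scale=3)
#   label_.shape == (195, 174)   # modcropped original (high-resolution)
#   input_.shape == (195, 174)   # same size, blurred by the zoom down and up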

def prepare_data(sess, dataset):
  """
  得到图片路径list
  Args:
    dataset: choose train dataset or test dataset
    
    For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
  """
  if FLAGS.is_train:
    filenames = os.listdir(dataset)
    data_dir = os.path.join(os.getcwd(), dataset)  # os.getcwd(): current working directory
    data = glob.glob(os.path.join(data_dir, "*.png"))  # the original dataset is *.bmp; adjust the pattern to the data at hand
  else:
    data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set")   # os.sep: the platform's path separator, so \ vs / need not be handled by hand
    data = glob.glob(os.path.join(data_dir, "*.png"))

  return data

def make_data(sess, data, label):
  """
  制作h5文件,将data(checkpoint下的train.h5 或test.h5)利用h5的create_dataset 写入
  Make input data as h5 file format
  Depending on 'is_train' (flag value), savepath would be changed.
  将输入数据设置为h5文件格式
  根据“ is_train”(标志值),保存路径将被更改。
  """
  if FLAGS.is_train:
    savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5') # h5 file save path
  else:
    savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')

  with h5py.File(savepath, 'w') as hf:
    hf.create_dataset('data', data=data)   # create an HDF5 dataset named "data"
    hf.create_dataset('label', data=label) # create an HDF5 dataset named "label"

def imread(path, is_grayscale=True):
  """
  Read image using its path.
  Default value is gray-scale, and image is read by YCbCr format as the paper said.
  使用其路径读取图像。
  默认值为灰度,如论文所述,图像以YCbCr格式读取。
  """
  if is_grayscale:
    return scipy.misc.imread(path, flatten=True, mode='L').astype(np.float)  # flatten=True collapses the color layers into a single grayscale layer
  else:
    return scipy.misc.imread(path, mode='L').astype(np.float)    # note: for the paper's color setting this would be mode='YCbCr'

def modcrop(image, scale=3):
  """
  将图片规整到可以被scale整除的宽高。
  To scale down and up the original image, first thing to do is to have no remainder while scaling operation.
  
  We need to find modulo of height (and width) and scale factor.
  Then, subtract the modulo from height (and width) of original image size.
  There would be no remainder even after scaling operation.
  要按比例放大和缩小原始图像,首先要做的是在缩放操作时没有余数。 
   我们需要找到高度(和宽度)和比例因子的模。
   然后,从原始图像尺寸的高度(和宽度)中减去模。
   即使在缩放操作之后也将没有剩余。
  """
  if len(image.shape) == 3:
    h, w, _ = image.shape       #eg: image.shape=(197, 176, 3), scale=3
    h = h - np.mod(h, scale)    #h=195
    w = w - np.mod(w, scale)    #w=174
    image = image[0:h, 0:w, :]
  else:
    h, w = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w]
  return image

def input_setup(sess, config):
  """
  就是把输入和输出图片,切成一小块存起来
  Read image files and make their sub-images and saved them as a h5 file format.
  读取图像文件并制作其子图像,并将其保存为h5文件格式。
  """
  # Load data path
  if config.is_train:
    data = prepare_data(sess, dataset="Train")  #图的位置列表
  else:
    data = prepare_data(sess, dataset="Test")

  sub_input_sequence = []
  sub_label_sequence = []
  padding = abs(config.image_size - config.label_size) / 2  #(33-21)/2=6

  if config.is_train:
    for i in xrange(len(data)):  # for each image
      input_, label_ = preprocess(data[i], config.scale)  # low-resolution input, high-resolution label

      if len(input_.shape) == 3:
        h, w, _ = input_.shape  # e.g. (195, 174, 1)
      else:
        h, w = input_.shape

      for x in range(0, h-config.image_size+1, config.stride):   # x in 0..h-33 (stride 14 for training, 21 for testing so the patches can be stitched); the +1 makes the last start position h-image_size reachable, since range's end is exclusive
        for y in range(0, w-config.image_size+1, config.stride): # y in 0..w-33
          # crop a [33 x 33] sub-image at (x ~ x+33, y ~ y+33)
          sub_input = input_[x:x+config.image_size, y:y+config.image_size]
          # crop the matching [21 x 21] label at (x+6 ~ x+6+21, y+6 ~ y+6+21), discarding the border
          sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size]

          # Make channel value: reshape each patch to image_size*image_size*1, making the single channel axis explicit
          sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  # 33*33*1
          sub_label = sub_label.reshape([config.label_size, config.label_size, 1])  # 21*21*1

          sub_input_sequence.append(sub_input)  # append the patch to the input list
          sub_label_sequence.append(sub_label)  # append the patch to the label list

  else:  # the test path drops the outer loop and processes only the first image
    input_, label_ = preprocess(data[0], config.scale)

    if len(input_.shape) == 3:
      h, w, _ = input_.shape
    else:
      h, w = input_.shape

    # Numbers of sub-images in height and width of image are needed to compute merge operation.
    nx = ny = 0 
    for x in range(0, h-config.image_size+1, config.stride):
      nx += 1; ny = 0
      for y in range(0, w-config.image_size+1, config.stride):
        ny += 1
        sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
        sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]
        
        sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
        sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

        sub_input_sequence.append(sub_input)
        sub_label_sequence.append(sub_label)

  """
  len(sub_input_sequence) : the number of sub_input (33 x 33 x ch) in one image
  (sub_input_sequence[0]).shape : (33, 33, 1)
  """
  # Convert the lists to numpy arrays so they can be stored in h5 format.
  arrdata = np.asarray(sub_input_sequence)  # input set, shape [?, 33, 33, 1]; ? is the total number of sub-images
  arrlabel = np.asarray(sub_label_sequence) # label set, shape [?, 21, 21, 1]

  make_data(sess, arrdata, arrlabel)  # write the h5 file; this is what produces the training (or test) data

  # nx and ny are the number of sub-images along the height and the width, respectively
  if not config.is_train:
    return nx, ny
    
def imsave(image, path):  # save the image
  return scipy.misc.imsave(path, image)

def merge(images, size):
  """
  将一个batch内的图片拼接在一起
  images:一个batch图片,
  size:第一个参数是高度上有几张图片,第二个宽度上有几张图片
  %:求模运算,取余。
  //:取整,返回不大于结果的一个最大的整数
  /:浮点数除法
  """
  h, w = images.shape[1], images.shape[2]
  img = np.zeros((h*size[0], w*size[1], 1))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    img[j*h:j*h+h, i*w:i*w+w, :] = image

  return img
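
# A concrete check of the stitching (dummy batch; nx=2 rows and ny=3 columns
# are assumed values):
#   merge(np.zeros((6, 21, 21, 1)), [2, 3]).shape == (42, 63, 1)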