Understanding the pixelShuffler code in the Deconv (upsampling) block of the SRGAN generator

Assume the input is a tensor of shape [b_s,h,w,256] (b_s is short for batch_size):

import tensorflow as tf

def pixelShuffler(inputs, scale=2):
    size = tf.shape(inputs)  # dynamic shape: (b_s, h, w, 256)
    batch_size = size[0]     # batch_size = b_s
    h = size[1]
    w = size[2]
    c = inputs.get_shape().as_list()[-1]  # static channel count: c = 256

    # Get the target channel size
    channel_target = c // (scale * scale)  # 256 // 4 = 64
    channel_factor = c // channel_target   # 4

    shape_1 = [batch_size, h, w, channel_factor // scale, channel_factor // scale]  # [b_s,h,w,2,2]
    shape_2 = [batch_size, h * scale, w * scale, 1]     # [b_s,hx2,wx2,1]

    # Reshape and transpose for periodic shuffling for each channel
    # Split inputs along axis 3 (the 4th dimension) into 64 tensors of shape [b_s,h,w,4].
    # input_split is a Python list of 64 tensors, not a single 5-D array.
    input_split = tf.split(inputs, channel_target, axis=3)

    # Concatenate the 64 results of shape [b_s,hx2,wx2,1] along axis 3 => [b_s,hx2,wx2,64]
    output = tf.concat([phaseShift(x, scale, shape_1, shape_2) for x in input_split], axis=3)

    return output  # [b_s,hx2,wx2,64]


def phaseShift(inputs, scale, shape_1, shape_2):
    # Tackle the condition when the batch is None
    # (shape_1 and shape_2 use the dynamic batch size from tf.shape)
    X = tf.reshape(inputs, shape_1)    # [b_s,h,w,4] => [b_s,h,w,2,2]
    X = tf.transpose(X, [0, 1, 3, 2, 4])    # swap axes 2 and 3 (the 3rd and 4th dimensions) => [b_s,h,2,w,2]

    return tf.reshape(X, shape_2)    # [b_s,h,2,w,2] => [b_s,hx2,wx2,1]
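To check what phaseShift and pixelShuffler actually compute, here is a minimal NumPy re-implementation. It assumes NumPy is an acceptable stand-in for TensorFlow, since np.reshape, np.transpose, np.split and np.concatenate follow the same row-major semantics as their tf counterparts. The names phase_shift_np and pixel_shuffler_np are hypothetical. On a small input it verifies that every output pixel obeys output[b, y*s+i, x*s+j, k] = input[b, y, x, k*s*s + i*s + j]. In other words, each group of s*s consecutive channels is unfolded into an s×s block of a single output channel.

```python
import numpy as np

def phase_shift_np(x, scale):
    # Hypothetical NumPy mirror of phaseShift; x is one group of shape [b, h, w, scale*scale]
    b, h, w, _ = x.shape
    x = x.reshape(b, h, w, scale, scale)       # [b, h, w, s, s]
    x = x.transpose(0, 1, 3, 2, 4)             # swap axes 2 and 3 -> [b, h, s, w, s]
    return x.reshape(b, h * scale, w * scale, 1)

def pixel_shuffler_np(inputs, scale=2):
    # Hypothetical NumPy mirror of pixelShuffler
    c = inputs.shape[-1]
    channel_target = c // (scale * scale)              # number of output channels
    groups = np.split(inputs, channel_target, axis=3)  # list of [b, h, w, s*s]
    return np.concatenate([phase_shift_np(g, scale) for g in groups], axis=3)

# Small input: b=1, h=2, w=3, c=8 (so 2 output channels), values 0..47
inp = np.arange(1 * 2 * 3 * 8).reshape(1, 2, 3, 8)
out = pixel_shuffler_np(inp, scale=2)
print(out.shape)  # (1, 4, 6, 2)

# Check the index rule for every output element
s = 2
ok = all(
    out[0, r * s + i, q * s + j, k] == inp[0, r, q, k * s * s + i * s + j]
    for r in range(2) for q in range(3)
    for i in range(s) for j in range(s) for k in range(2)
)
print(ok)  # True
```

The transpose is the key step. Without it, a plain reshape from [b,h,w,4] to [b,2h,2w,1] would place the four channel values of one pixel side by side in a single row instead of in a 2×2 block.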

Some of the TensorFlow functions used above:

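tf.split(value, num_or_size_splits, axis=0, num=None, name='split'): axis selects the dimension to split (counting from 0). When num_or_size_splits is an integer, that dimension must be evenly divisible by it, and the result is a Python list of that many tensors.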
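The following sketch illustrates the split step using NumPy's np.split. This is an assumption: it is a stand-in for tf.split, and with an integer count it also requires an even division and returns a list. The sketch shows that splitting a [1,2,2,256] array along axis 3 into 64 pieces yields a plain list of [1,2,2,4] arrays. Only an explicit np.stack turns that list into a 5-D [b_s,h,w,4,64] array.

```python
import numpy as np

# NumPy stand-in for inputs of shape [b_s, h, w, 256] with b_s=1, h=w=2
inputs = np.zeros((1, 2, 2, 256), dtype=np.float32)

# Split axis 3 (the 4th dimension, counted from 0) into 64 equal pieces
pieces = np.split(inputs, 64, axis=3)
print(type(pieces).__name__, len(pieces))  # list 64
print(pieces[0].shape)                     # (1, 2, 2, 4)

# Only stacking on a new axis produces a single 5-D array
print(np.stack(pieces, axis=-1).shape)     # (1, 2, 2, 4, 64)

# An integer count that does not divide the axis evenly is rejected
try:
    np.split(inputs, 3, axis=3)            # 256 is not divisible by 3
    uneven_failed = False
except ValueError:
    uneven_failed = True
print(uneven_failed)                       # True
```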

tf.reshape(tensor, shape, name=None): the new shape must contain the same total number of elements as the original, otherwise an error is raised.
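The sketch below demonstrates the element-count rule with NumPy's reshape, which is assumed here as a stand-in for tf.reshape because both use row-major order. It reshapes a 16-element [1,2,2,4] group into [1,2,2,2,2], which is the same move phaseShift makes. It also lets one dimension be inferred with -1 (tf.reshape accepts -1 as well) and shows that a shape with a different element count is rejected. TensorFlow raises its own error type in that case rather than NumPy's ValueError.

```python
import numpy as np

# NumPy stand-in: reshape keeps row-major element order, like tf.reshape
x = np.arange(16).reshape(1, 2, 2, 4)     # one [b_s,h,w,4] group with 16 elements

grouped = x.reshape(1, 2, 2, 2, 2)        # still 16 elements -> allowed
print(grouped.shape)                      # (1, 2, 2, 2, 2)
print(grouped[0, 0, 0].tolist())          # [[0, 1], [2, 3]]  (4 channels -> 2x2)

inferred = x.reshape(1, 4, -1, 1)         # -1 is inferred as 16 / (1*4*1) = 4
print(inferred.shape)                     # (1, 4, 4, 1)

try:
    x.reshape(1, 3, 3, 1)                 # 9 elements != 16 -> rejected
    reshape_failed = False
except ValueError:
    reshape_failed = True
print(reshape_failed)                     # True
```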

tf.concat([tensor1, tensor2, tensor3, ...], axis): axis=0 concatenates along dimension 0, and axis=1 concatenates along dimension 1. For example, with axis=0, (2,3)+(2,3)=(4,3); with axis=1, (2,3)+(2,3)=(2,6).
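NumPy's np.concatenate follows the same axis rule as tf.concat. The sketch below reproduces the two shape examples above, using NumPy as a stand-in rather than TensorFlow itself. It also shows the final step of pixelShuffler with b_s=1 and h=w=2, where 64 tensors of shape [1,4,4,1] are joined on axis 3.

```python
import numpy as np

a = np.ones((2, 3))
b = np.zeros((2, 3))
print(np.concatenate([a, b], axis=0).shape)   # (4, 3)
print(np.concatenate([a, b], axis=1).shape)   # (2, 6)

# pixelShuffler's last step: 64 tensors of shape [b_s,2h,2w,1] joined on axis 3
parts = [np.zeros((1, 4, 4, 1)) for _ in range(64)]
print(np.concatenate(parts, axis=3).shape)    # (1, 4, 4, 64)
```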
