Understanding the pixelShuffler code used for Deconv (upsampling) in the SRGAN generator

Assume the input is an array of shape [b_s,h,w,256] (batch_size is abbreviated as b_s):

import tensorflow as tf

def pixelShuffler(inputs, scale=2):
    size = tf.shape(inputs)  # dynamic shape: size is (b_s,h,w,256)
    batch_size = size[0]     # batch_size = b_s
    h = size[1]              
    w = size[2]
    c = inputs.get_shape().as_list()[-1]  # c = 256

    # Get the target channel size
    channel_target = c // (scale * scale)  # 256/4 = 64
    channel_factor = c // channel_target   # 4

    shape_1 = [batch_size, h, w, channel_factor // scale, channel_factor // scale] #[b_s,h,w,2,2]
    shape_2 = [batch_size, h * scale, w * scale, 1]     #[b_s,hx2,wx2,1]

    # Reshape and transpose for periodic shuffling for each channel
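    # Split inputs along the 4th dimension (axis=3) into a Python list of 64 tensors,
    # each of shape [b_s,h,w,4]; piece k holds channels 4k..4k+3.
    # (This is a list of 4-D tensors, not a single 5-D [b_s,h,w,4,64] array.)
    input_split = tf.split(inputs, channel_target, axis=3)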

    # Apply phaseShift to each piece ([b_s,h,w,4] => [b_s,hx2,wx2,1]), then
    # concatenate the 64 results along the 4th dimension => [b_s,hx2,wx2,64]
    output = tf.concat([phaseShift(x, scale, shape_1, shape_2) for x in input_split], axis=3)

    return output        # [b_s,hx2,wx2,64]


def phaseShift(inputs, scale, shape_1, shape_2):
    # Tackle the condition when the batch is None
    X = tf.reshape(inputs, shape_1)    # [b_s,h,w,4] => [b_s,h,w,2,2]
    X = tf.transpose(X, [0, 1, 3, 2, 4])    # swap the 3rd and 4th dimensions (axes 2 and 3) => [b_s,h,2,w,2]

    return tf.reshape(X, shape_2)    #[b_s,h,2,w,2] => [b_s,hx2,wx2,1]
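To check the shape comments above without a TensorFlow install, here is a minimal NumPy sketch that mirrors pixelShuffler and phaseShift step by step. It assumes np.split, np.reshape, np.transpose and np.concatenate behave like their tf.* counterparts here (integer split count, row-major reshape, same permutation argument). The function names phase_shift_np, pixel_shuffler_np and pixel_shuffler_vectorized are hypothetical and not part of the original code.

```python
import numpy as np

# NumPy stand-ins for the TensorFlow code above (assumption: np.split/np.reshape/
# np.transpose/np.concatenate follow the same row-major semantics as the tf.* versions).


def phase_shift_np(x, scale):
    """Mirror of phaseShift: [b, h, w, scale*scale] -> [b, h*scale, w*scale, 1]."""
    b, h, w, _ = x.shape
    x = x.reshape(b, h, w, scale, scale)  # 4 channels -> a 2x2 block (row a, col d)
    x = x.transpose(0, 1, 3, 2, 4)        # [b, h, w, a, d] -> [b, h, a, w, d]
    # output row = i*scale + a, output col = j*scale + d
    return x.reshape(b, h * scale, w * scale, 1)


def pixel_shuffler_np(x, scale=2):
    """Mirror of pixelShuffler: split into groups of scale*scale channels, shift each, concat."""
    channel_target = x.shape[-1] // (scale * scale)
    groups = np.split(x, channel_target, axis=3)  # list of [b, h, w, scale*scale]
    return np.concatenate([phase_shift_np(g, scale) for g in groups], axis=3)


def pixel_shuffler_vectorized(x, scale=2):
    """Same mapping without the Python loop (hypothetical helper, not in the original code)."""
    b, h, w, c = x.shape
    k = c // (scale * scale)
    x = x.reshape(b, h, w, k, scale, scale)  # channel index = k_idx*scale*scale + a*scale + d
    x = x.transpose(0, 1, 4, 2, 5, 3)        # [b, h, w, k, a, d] -> [b, h, a, w, d, k]
    return x.reshape(b, h * scale, w * scale, k)


# Smallest case: one pixel with 4 channels [0, 1, 2, 3] becomes one 2x2 block.
tiny = np.arange(4).reshape(1, 1, 1, 4)
print(pixel_shuffler_np(tiny)[0, :, :, 0].tolist())  # [[0, 1], [2, 3]]

# The shapes from the article: [b_s, h, w, 256] -> [b_s, 2h, 2w, 64]
x = np.random.rand(2, 3, 5, 256)
y = pixel_shuffler_np(x)
print(y.shape)                                          # (2, 6, 10, 64)
print(np.array_equal(y, pixel_shuffler_vectorized(x)))  # True
```

The vectorized version suggests that the split/concat loop amounts to a single reshape-transpose-reshape: output pixel (2i+a, 2j+d) in channel k takes input channel 4k + 2a + d at pixel (i, j). Note that this groups consecutive channels, which is not the ordering used by tf.nn.depth_to_space (there the position inside the 2x2 block comes from the high-order part of the channel index, i.e. channel (2a+d)*64 + k). The two differ only by a fixed channel permutation, which mainly matters when moving trained weights between implementations.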

Some of the TensorFlow functions used above:

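tf.split(value, num_or_size_splits, axis=0, num=None, name='split'): the value passed as axis selects which dimension is split (counting from 0). When num_or_size_splits is an integer, that dimension is split into that many equal pieces, returned as a Python list of tensors.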
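A small illustration of the split behaviour, using NumPy's np.split as a stand-in (assumption: for an integer section count it splits the same way as tf.split). The 8-channel array is a hypothetical shrunken version of the 256-channel input, so 2 splits here play the role of the 64 splits in pixelShuffler.

```python
import numpy as np

# Hypothetical input: batch 1, 2x3 spatial size, 8 channels (instead of 256)
x = np.arange(1 * 2 * 3 * 8).reshape(1, 2, 3, 8)

# NumPy stand-in for tf.split(x, 2, axis=3): split the 4th dimension into 2 equal parts
parts = np.split(x, 2, axis=3)

print(type(parts), len(parts))  # <class 'list'> 2
print(parts[0].shape)           # (1, 2, 3, 4)
# Each part is a block of consecutive channels: part k = channels 4k .. 4k+3
print(np.array_equal(parts[1], x[..., 4:8]))  # True
```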

tf.reshape(tensor, shape, name=None): a reshape must keep the total number of elements unchanged; otherwise an error is raised.
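The element-count rule can be seen with NumPy's reshape, used here as a stand-in (assumption: it follows the same rule; NumPy raises ValueError on a mismatch, while TensorFlow raises its own error type). Both libraries also accept -1 for one dimension, whose size is then inferred from the element count.

```python
import numpy as np

x = np.arange(24)            # 24 elements
a = x.reshape(2, 3, 4)       # 2*3*4 = 24: allowed
b = x.reshape(4, -1)         # -1 is inferred as 6, giving (4, 6)

try:
    x.reshape(5, 5)          # 5*5 = 25 != 24: not allowed
    ok = True
except ValueError:
    ok = False

print(a.shape, b.shape, ok)  # (2, 3, 4) (4, 6) False
```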

tf.concat([tensor1, tensor2, tensor3, ...], axis): axis=0 concatenates along dimension 0, axis=1 along dimension 1. For example, with axis=0, (2,3)+(2,3)=(4,3); with axis=1, (2,3)+(2,3)=(2,6).
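The shape arithmetic above, checked with np.concatenate as a stand-in for tf.concat (assumption: same axis semantics). The last line mimics the final step of pixelShuffler on a hypothetical 2x2 image upscaled to 4x6... more precisely 64 single-channel [1,4,6,1] pieces joined along axis 3.

```python
import numpy as np

t1 = np.zeros((2, 3))
t2 = np.ones((2, 3))

print(np.concatenate([t1, t2], axis=0).shape)  # (4, 3)
print(np.concatenate([t1, t2], axis=1).shape)  # (2, 6)

# Like the last step of pixelShuffler: 64 pieces of shape [1, 4, 6, 1] joined on axis 3
pieces = [np.zeros((1, 4, 6, 1)) for _ in range(64)]
print(np.concatenate(pieces, axis=3).shape)    # (1, 4, 6, 64)
```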
