假設input是 [b_s,h,w,256] 形狀的數組(batch_size簡寫爲b_s):
def pixelShuffler(inputs, scale=2):
    """Upsample a feature map by moving channel blocks into space.

    Every group of ``scale * scale`` consecutive channels is turned into one
    output channel whose height and width are ``scale`` times larger
    (sub-pixel / periodic-shuffling upsampling).

    Args:
        inputs: 4-D tensor indexed as [batch, height, width, channels]. The
            channel count is read from the static shape, so it has to be
            known when the graph is built and be a positive multiple of
            ``scale * scale``.
        scale: Integer upsampling factor used for both height and width.

    Returns:
        Tensor of shape
        [batch, height * scale, width * scale, channels // (scale * scale)].

    Raises:
        ValueError: If ``scale`` is smaller than 1, or if the channel count
            is unknown or is not a positive multiple of ``scale * scale``.
    """
    if scale < 1:
        raise ValueError("scale must be >= 1, got %r" % (scale,))

    block = scale * scale
    c = inputs.get_shape().as_list()[-1]  # static channel count, e.g. 256

    # Fail fast with a clear message. Without this check a too-small channel
    # count ends in ZeroDivisionError, and a non-multiple only fails later
    # inside the reshape in phaseShift.
    if c is None or c < block or c % block != 0:
        raise ValueError(
            "channel dimension must be a known positive multiple of "
            "scale * scale (%r), got %r" % (block, c)
        )

    # Batch, height and width come from the dynamic shape so the function
    # still works when they are unknown (None) in the static shape.
    size = tf.shape(inputs)
    batch_size, h, w = size[0], size[1], size[2]

    # Number of output channels, e.g. 256 // 4 = 64.
    channel_target = c // block

    shape_1 = [batch_size, h, w, scale, scale]       # [b, h, w, r, r]
    shape_2 = [batch_size, h * scale, w * scale, 1]  # [b, h*r, w*r, 1]

    # tf.split returns a Python list of `channel_target` tensors, each of
    # shape [b, h, w, r*r] (not a single 5-D array).
    input_split = tf.split(inputs, channel_target, axis=3)

    # Shuffle every group on its own, then stack the single-channel results
    # back along the channel axis => [b, h*r, w*r, channel_target].
    output = tf.concat(
        [phaseShift(x, scale, shape_1, shape_2) for x in input_split],
        axis=3,
    )
    return output
def phaseShift(inputs, scale, shape_1, shape_2):
    """Reshape, swap axes 2 and 3, and reshape again.

    Args:
        inputs: Tensor to rearrange; it must be reshapeable to ``shape_1``.
        scale: Not used in the body; kept so existing callers keep working.
        shape_1: 5-element intermediate shape. pixelShuffler passes
            [batch, h, w, r, r].
        shape_2: Final shape. pixelShuffler passes [batch, h * r, w * r, 1].

    Returns:
        The rearranged tensor with shape ``shape_2``.
    """
    # Both shapes are supplied by the caller rather than read from the
    # static shape here, which is how a batch size of None is handled.
    # With pixelShuffler's shapes: [b, h, w, r*r] -> [b, h, w, r, r],
    # then axes 2 and 3 are swapped -> [b, h, r, w, r].
    interleaved = tf.transpose(tf.reshape(inputs, shape_1), perm=[0, 1, 3, 2, 4])

    # Merging (h, r) and (w, r) gives [b, h*r, w*r, 1].
    return tf.reshape(interleaved, shape_2)
其中用到的一些tensorflow的方法:
tf.split(value, num_or_size_splits, axis=0, num=None, name='split'):傳入axis的數值就代表切割哪個維度(從0開始計數)。
tf.reshape(tensor,shape,name=None):形狀發生變化的原則是數組元素的個數不能發生改變,否則就會報錯。
tf.concat([tensor1, tensor2, tensor3,...], axis):axis=0 代表在第0個維度拼接;axis=1 代表在第1個維度拼接。如 axis=0時(2,3)+(2,3)=(4,3);axis=1時(2,3)+(2,3)=(2,6)。