假设 inputs 是形状为 [b_s, h, w, 256] 的张量(b_s 为 batch_size 的简写):
def pixelShuffler(inputs, scale=2):
    """Upsample a feature map by moving channel data into the spatial axes.

    Every group of ``scale * scale`` consecutive channels is rearranged into
    a single output channel that is ``scale`` times taller and wider.  For
    example, with ``scale=2`` an input of shape [b, h, w, 256] yields an
    output of shape [b, h*2, w*2, 64].

    Args:
        inputs: 4-D tensor indexed as [batch, height, width, channels].  The
            size of the last axis must be statically known and be a positive
            multiple of ``scale * scale``.
        scale: Integer upscaling factor, at least 1.

    Returns:
        Tensor of shape
        [batch, height * scale, width * scale, channels // (scale * scale)].

    Raises:
        ValueError: If ``scale`` is smaller than 1, or if the channel count
            is unknown or not a positive multiple of ``scale * scale``.
    """
    if scale < 1:
        raise ValueError("scale must be >= 1, got %r" % (scale,))
    block = scale * scale

    # Static channel count; checked up front so a bad configuration fails
    # with a clear message instead of a ZeroDivisionError / TypeError here or
    # a reshape error further down.
    c = inputs.get_shape().as_list()[-1]
    if c is None or c < block or c % block != 0:
        raise ValueError(
            "the channel dimension must be statically known and a positive "
            "multiple of scale * scale (%d), got %r" % (block, c)
        )

    # batch/height/width are read from tf.shape (a tensor) rather than from
    # the static shape, so they may be unknown (None) when the graph is built.
    size = tf.shape(inputs)
    batch_size = size[0]
    h = size[1]
    w = size[2]

    # Number of output channels, e.g. 256 // 4 = 64.
    channel_target = c // block

    # Per-group shapes handed to phaseShift:
    #   [b, h, w, scale*scale] is viewed as [b, h, w, scale, scale] ...
    shape_1 = [batch_size, h, w, scale, scale]
    #   ... and finally collapsed to [b, h*scale, w*scale, 1].
    shape_2 = [batch_size, h * scale, w * scale, 1]

    # tf.split returns a Python list of `channel_target` tensors, each of
    # shape [b, h, w, scale*scale] (not a single 5-D tensor).
    input_split = tf.split(inputs, channel_target, axis=3)

    # Shuffle every group independently, then stack the single-channel
    # results back along the channel axis -> [b, h*scale, w*scale, channel_target].
    output = tf.concat(
        [phaseShift(x, scale, shape_1, shape_2) for x in input_split],
        axis=3,
    )
    return output
def phaseShift(inputs, scale, shape_1, shape_2):
    """Rearrange one channel group into a single, spatially larger channel.

    Args:
        inputs: Tensor to reshape; as called from pixelShuffler this is one
            slice of shape [b, h, w, scale*scale].
        scale: Upscaling factor.  Not used in the body; the caller has
            already folded it into ``shape_1`` and ``shape_2``.
        shape_1: Intermediate 5-D shape; pixelShuffler passes
            [b, h, w, scale, scale].
        shape_2: Final shape; pixelShuffler passes [b, h*scale, w*scale, 1].

    Returns:
        ``inputs`` reshaped to ``shape_1``, with axes 2 and 3 swapped, then
        reshaped to ``shape_2``.
    """
    # Both target shapes are supplied by the caller (which builds them from
    # tf.shape), which is how the "batch is None" case is handled here.
    grid = tf.reshape(inputs, shape_1)  # e.g. [b, h, w, 4] -> [b, h, w, 2, 2]
    # Swap axes 2 and 3: [b, h, w, 2, 2] -> [b, h, 2, w, 2].
    interleaved = tf.transpose(grid, perm=[0, 1, 3, 2, 4])
    # Merge (h, 2) and (w, 2) pairs: [b, h, 2, w, 2] -> [b, h*2, w*2, 1].
    return tf.reshape(interleaved, shape_2)
其中用到的一些tensorflow的方法:
tf.split(value, num_or_size_splits, axis=0, num=None, name='split'):axis 的数值代表沿哪个维度切分(从0开始计数);num_or_size_splits 为整数时表示沿该维度等分成多少份,返回值是由各份张量组成的列表。
tf.reshape(tensor, shape, name=None):形状变换的原则是数组元素的总个数不能发生改变,否则就会报错。
tf.concat([tensor1, tensor2, tensor3, ...], axis):沿指定维度拼接张量。axis=0 代表在第0个维度拼接,axis=1 代表在第1个维度拼接。例如两个形状为 (2,3) 的张量:axis=0 时拼接结果为 (4,3);axis=1 时拼接结果为 (2,6)。