PyTorch: CNN Flatten Operation Visualized - Tensor Batch Processing for Deep Learning

Flatten operation for a batch of image inputs to a CNN

Flattening an entire tensor

To flatten a tensor, we need it to have at least two axes. This ensures that we are starting with something that is not already flat.
But what if we only want to flatten specific axes within the tensor? This is typically required when working with CNNs.
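Here is a minimal sketch of flattening an entire tensor. It uses a small rank-2 tensor whose values are arbitrary and not taken from this post. flatten() collapses both axes into a single axis, and the elements keep their row-by-row order.

```python
import torch

# A rank-2 tensor (2 rows x 3 columns); values chosen only for illustration
m = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

flat = m.flatten()  # collapse every axis into a single axis

print(m.shape)     # torch.Size([2, 3])
print(flat.shape)  # torch.Size([6])
print(flat)        # tensor([1, 2, 3, 4, 5, 6])
```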

Flattening specific axes of a tensor

Tensor inputs to a convolutional neural network typically have 4 axes: one for the batch size, one for the color channels, and one each for height and width.

  • (Batch Size, Channels, Height, Width)
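The sketch below makes the (Batch Size, Channels, Height, Width) convention concrete. The sizes are arbitrary assumptions and do not come from this post: a hypothetical batch of 2 RGB images, each 28 x 28.

```python
import torch

# Hypothetical CNN input: 2 images, 3 color channels, 28 x 28 pixels each
x = torch.zeros(2, 3, 28, 28)

# Unpack the four axes in (Batch Size, Channels, Height, Width) order
batch_size, channels, height, width = x.shape

print(x.ndim)                               # 4
print(batch_size, channels, height, width)  # 2 3 28 28
```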

Suppose we have the following three tensors.

import torch

t1 = torch.tensor([
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1]
])

t2 = torch.tensor([
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2]
])

t3 = torch.tensor([
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3]
])

Each of these has a shape of 4 x 4, so we have three rank-2 tensors. For our purposes here, we’ll consider these to be three 4 x 4 images that we’ll use to create a batch that can be passed to a CNN.
Remember, batches are represented using a single tensor, so we’ll need to combine these three tensors into a single larger tensor that has three axes instead of two.

> t = torch.stack((t1, t2, t3))
> t.shape

torch.Size([3, 4, 4])
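torch.stack() is what adds the new batch axis here. The sketch below contrasts it with torch.cat(), which joins tensors along an axis they already have. It uses float stand-ins for t1, t2 and t3 built with torch.ones(). Stacking gives the same result as adding a leading axis to each tensor with unsqueeze(0) and then concatenating.

```python
import torch

# Float stand-ins for the three 4 x 4 "images"
t1 = torch.ones(4, 4)
t2 = torch.ones(4, 4) * 2
t3 = torch.ones(4, 4) * 3

stacked = torch.stack((t1, t2, t3))     # new axis 0 -> shape (3, 4, 4)
concatenated = torch.cat((t1, t2, t3))  # existing axis 0 -> shape (12, 4)

# stack == cat after giving each tensor a leading axis of length 1
manual = torch.cat((t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)))

print(stacked.shape)                 # torch.Size([3, 4, 4])
print(concatenated.shape)            # torch.Size([12, 4])
print(torch.equal(stacked, manual))  # True
```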

The axis with a length of 3 represents the batch size, while the axes of length 4 represent the height and width, respectively. This is what the output for this tensor representation of the batch looks like.

> t
tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2]],

        [[3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3]]])

At this point, we have a rank-3 tensor that contains a batch of three 4 x 4 images. All we need to do now to get this tensor into the form a CNN expects is to add an axis for the color channels. We effectively have an implicit single color channel for each of these image tensors, so in practice, these would be grayscale images.
A CNN will expect to see an explicit color channel axis, so let’s add one by reshaping this tensor.

> t = t.reshape(3,1,4,4)
> t
tensor(
[
    [
        [
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1]
        ]
    ],
    [
        [
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2]
        ]
    ],
    [
        [
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3]
        ]
    ]
])
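Reshaping with hard-coded sizes works. The same length-1 channel axis can also be inserted with unsqueeze(), which avoids repeating the batch, height and width sizes. Below is a minimal sketch that rebuilds the batch with torch.full() instead of the literal tensors above.

```python
import torch

# Rebuild the rank-3 batch of three 4 x 4 "images" filled with 1, 2 and 3
batch = torch.stack([torch.full((4, 4), v, dtype=torch.long) for v in (1, 2, 3)])

via_reshape = batch.reshape(3, 1, 4, 4)  # explicit target shape
via_unsqueeze = batch.unsqueeze(1)       # insert a length-1 axis at index 1

print(via_unsqueeze.shape)                      # torch.Size([3, 1, 4, 4])
print(torch.equal(via_reshape, via_unsqueeze))  # True
```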

The first axis has 3 elements, and each element of the first axis represents an image. For each image, we have a single color channel on the channel axis. Each of these channels contains 4 arrays, and each array holds 4 numbers (scalar components).
Here is the first image.

> t[0]
tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]])

Here is the first color channel of the first image.

> t[0][0]
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

Here is the first row of pixels in the first color channel of the first image.

> t[0][0][0]
tensor([1, 1, 1, 1])

Here is the first pixel value in the first row of the first color channel of the first image.

> t[0][0][0][0]
tensor(1)
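The chained indexing above, t[0][0][0][0], indexes one axis at a time. PyTorch also accepts a single comma-separated index that selects the same element, and .item() converts a one-element tensor to a plain Python number. The sketch rebuilds t with torch.full() as a stand-in for the literal tensors.

```python
import torch

# Same (3, 1, 4, 4) batch as above: image i is filled with the value i + 1
t = torch.stack([torch.full((4, 4), v, dtype=torch.long) for v in (1, 2, 3)]).reshape(3, 1, 4, 4)

chained = t[0][0][0][0]   # one axis per indexing step
combined = t[0, 0, 0, 0]  # all four axes in a single index

print(torch.equal(chained, combined))  # True
print(combined.item())                 # 1
print(t[2, 0].shape)                   # torch.Size([4, 4]), channel 0 of the third image
```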

Flattening the tensor batch

Let’s flatten the whole thing first just to see what it will look like.

> t.reshape(1,-1)[0] # Method 1
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.reshape(-1) # Method 2
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.view(t.numel()) # Method 3
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.flatten() # Method 4
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
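All four methods produce the same rank-1 tensor of 3 x 1 x 4 x 4 = 48 elements. The quick check below confirms this, with t rebuilt via torch.full() as before.

```python
import torch

t = torch.stack([torch.full((4, 4), v, dtype=torch.long) for v in (1, 2, 3)]).reshape(3, 1, 4, 4)

results = [
    t.reshape(1, -1)[0],  # Method 1: reshape to one row, then take that row
    t.reshape(-1),        # Method 2: let PyTorch infer the single length
    t.view(t.numel()),    # Method 3: view with the total element count
    t.flatten(),          # Method 4: dedicated flatten method
]

for r in results:
    print(r.shape, torch.equal(r, results[0]))  # torch.Size([48]) True
```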

For the flattened batch to work well inside our CNN, we need to flatten each image while still keeping the batch axis. Flattening the whole tensor, as above, merges all three images into a single axis. Instead, we want to flatten the color channel axis together with the height and width axes. These are the axes that need to be flattened: (C,H,W)
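The sketch below shows the difference. Flattening everything loses the image boundaries, while reshaping to (batch_size, -1) keeps one row per image. As before, t is rebuilt with torch.full().

```python
import torch

t = torch.stack([torch.full((4, 4), v, dtype=torch.long) for v in (1, 2, 3)]).reshape(3, 1, 4, 4)

everything = t.flatten()               # (48,): image boundaries are lost
per_image = t.reshape(t.shape[0], -1)  # (3, 16): batch axis kept, C*H*W merged

print(everything.shape)  # torch.Size([48])
print(per_image.shape)   # torch.Size([3, 16])
print(per_image[1])      # row 1 holds only the second image's 16 values (all 2s)
```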

Flattening specific axes of a tensor

Starting from dimension 1, all subsequent dimensions are flattened, while dimension 0 is kept unchanged.

> t.flatten(start_dim=1).shape
torch.Size([3, 16])

> t.flatten(start_dim=1)
tensor(
[
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
]
)

The start_dim parameter tells the flatten() method which axis to start the flatten operation at. The 1 here is an index, so it refers to the second axis, which is the color channel axis.
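Two related options appear in the sketch below.

- flatten() also takes an end_dim argument, which defaults to -1 (the last axis).
- torch.nn.Flatten is a layer whose defaults are start_dim=1 and end_dim=-1, so it keeps the batch axis the same way flatten(start_dim=1) does.

```python
import torch
import torch.nn as nn

t = torch.stack([torch.full((4, 4), v, dtype=torch.long) for v in (1, 2, 3)]).reshape(3, 1, 4, 4)

# Flatten from axis 1 through the last axis: (3, 1, 4, 4) -> (3, 16)
a = t.flatten(start_dim=1)

# Flatten only the height and width axes: (3, 1, 4, 4) -> (3, 1, 16)
b = t.flatten(start_dim=2, end_dim=3)

# nn.Flatten defaults to start_dim=1, end_dim=-1, matching `a`
layer = nn.Flatten()
c = layer(t)

print(a.shape, b.shape, c.shape)  # torch.Size([3, 16]) torch.Size([3, 1, 16]) torch.Size([3, 16])
print(torch.equal(a, c))          # True
```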

Flattening an RGB Image

If we flatten an RGB image, each color channel will be flattened first. Then, the flattened channels will be lined up side by side on a single axis of the tensor. Let’s look at an example in code.

# A note on torch.ones()
> torch.ones(2,2)
tensor([[1., 1.],
        [1., 1.]])


> torch.ones(1,2,2)
tensor([[[1., 1.],
         [1., 1.]]])

> torch.ones(2, 2, 2)
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
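The point of this note appears to be that every size passed to torch.ones() becomes an axis. That leading axis of length 1 changes what torch.cat() does along dim=0, as the sketch below shows. Channels shaped (1, 2, 2) are joined into a new stack of channels, while (2, 2) matrices just get their rows appended.

```python
import torch

print(torch.ones(2, 2).ndim)     # 2 axes: (2, 2)
print(torch.ones(1, 2, 2).ndim)  # 3 axes: a leading axis of length 1
print(torch.ones(2, 2, 2).ndim)  # 3 axes: two 2 x 2 matrices

# Concatenating along dim 0 behaves differently depending on that leading axis
with_axis = torch.cat((torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0)
without_axis = torch.cat((torch.ones(2, 2), torch.ones(2, 2)), dim=0)

print(with_axis.shape)     # torch.Size([2, 2, 2]): two channels of 2 x 2
print(without_axis.shape)  # torch.Size([4, 2]): rows appended to one matrix
```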

We’ll build an example RGB image tensor with a height of two and a width of two.

r = torch.ones(1,2,2)
g = torch.ones(1,2,2) + 1
b = torch.ones(1,2,2) + 2

img = torch.cat(
    (r,g,b)
    ,dim=0
)

> img.shape
torch.Size([3, 2, 2])
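Because r, g and b each already carry a leading axis of length 1, torch.cat() along dim=0 produces the (3, 2, 2) channel layout. An alternative is sketched below: build rank-2 channels and let torch.stack() create the channel axis. This is not the method used in the post, just an equivalent way to get the same tensor.

```python
import torch

# Channels built with a leading length-1 axis and concatenated (as in the post)
r = torch.ones(1, 2, 2)
g = torch.ones(1, 2, 2) + 1
b = torch.ones(1, 2, 2) + 2
img = torch.cat((r, g, b), dim=0)

# Alternative: rank-2 channels, with torch.stack adding the channel axis
img_stacked = torch.stack((torch.ones(2, 2), torch.ones(2, 2) + 1, torch.ones(2, 2) + 2))

print(img.shape, img_stacked.shape)   # torch.Size([3, 2, 2]) torch.Size([3, 2, 2])
print(torch.equal(img, img_stacked))  # True
```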

> img
tensor([
    [
        [1., 1.]
        ,[1., 1.]
    ]
    ,[
        [2., 2.]
        , [2., 2.]
    ],
    [
        [3., 3.]
        ,[3., 3.]
    ]
])

> img.flatten(start_dim=0)
tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])

> img.flatten(start_dim=1)
tensor([
    [1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]
])
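Inside a CNN, an RGB image would normally arrive as part of a (B, C, H, W) batch. The sketch below assumes a hypothetical batch made of two copies of img. flatten(start_dim=1) then lines the three flattened channels up side by side within each image, giving 3 x 2 x 2 = 12 values per row.

```python
import torch

r = torch.ones(1, 2, 2)
g = torch.ones(1, 2, 2) + 1
b = torch.ones(1, 2, 2) + 2
img = torch.cat((r, g, b), dim=0)  # (3, 2, 2)

# Hypothetical batch of two identical RGB images: (2, 3, 2, 2)
batch = torch.stack((img, img))

flat = batch.flatten(start_dim=1)  # keep the batch axis, merge C, H, W

print(batch.shape)  # torch.Size([2, 3, 2, 2])
print(flat.shape)   # torch.Size([2, 12])
print(flat[0])      # tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
```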