機器學習中的維度合併,切分轉換,不再傻傻分不清。(numpy 和 tensorflow 中 split,concat 等維度切分合並函數用法詳解)

對於機器學習來說,需要在各種維度之間操作,轉換,然而對於初學者來說,往往讓人難感到很疑惑,之前有大佬的博客以某個矩陣爲輸入數據,詳細演示了各種維度變換,但是缺乏一定的直觀性,一堆數字看起來很煩,本篇博客以幾張圖片爲例來講述numpy,tensorflow等API關於維度轉換的函數,轉換效果清晰可見。

準備數據

需要準備四張圖片,jpg格式就好,轉換爲224 × 224 的格式大小,圖片尺寸不滿足也沒關係,下面有代碼可以將圖片轉換爲對應大小。在項目文件夾目錄下建立一個 images 目錄,複製四張jpg 圖像加入此目錄。在images 上級目錄下新建一個Python文件。輸入以下代碼:

import os
from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
images=os.listdir("images")
image_set=[]
for img in images:
    image=Image.open(os.path.join("images",img))
    image=image.resize((224,224),Image.ANTIALIAS)
    image=np.asarray(image)
    image_set.append(image)
image_set=np.asarray(image_set)
print(image_set.shape)
print(type(image_set))

我們得到的輸出如下:

(4, 224, 224, 3)
<class 'numpy.ndarray'>

得到一個 numpy ndarray,維度大小分別是 (4,224,224,3),很好理解,這裏的4代表四站圖片,兩個224 分別代表圖片的長和寬,3代表通道數目,實際上,圖像有RGB三個通道組成。

我們要對這個 shape 爲 (4,224,224,3) 的數組來進行維度轉換,通過看圖片展示的效果來理解 各種維度轉換api 的用法。

numpy.vstack() , numpy.hstack() 以及 numpy.dstack()

這三個函數的用法比較簡單,先hstack 的效果圖,其它博客解釋到 hstack 是在水平方向上拼接維度,輸入一下代碼,結果如下:

import matplotlib.pyplot as plt
hstack_image=np.hstack(image_set)
plt.imshow(hstack_image)
print(

我們可以看到輸出如下的圖片和信息:
在這裏插入圖片描述
直觀的理解就是 hstack 函數會講將4張圖片按照行進行拼接,首先是第一張圖像的第一行,然後是第二張圖像的第一行,第三張圖像的第一行,第四張圖像的第一行,依次進行拼接,每一行的元素個數是之前的四倍(即列數是之前的四倍),依次拼接每一行,所以行數不會改變,可以看到拼接後的shape 爲 (224, 224 × 4,3) ,即將原來的四張圖像在一行顯示。

知道了hstack之後,就很容易知道 vstack了,這個相當於按照列進行拼接,拼接之後的圖片效果如下:

import matplotlib.pyplot as plt
vstack_image=np.vstack(image_set)
plt.imshow(vstack_image)
print(vstack_image.shape)

在這裏插入圖片描述

至於 dstack ,文檔上說是按照深度進行的拼接,即 將通道進行拼接,拼接之後通道數目變爲原來的四倍,因爲通道拼接之後無法展示,所以下面展示拼接之後的維度。
圖片展示第一張圖片,即只取出前三個通道 【0-2】。後面【3-5】爲第二張圖片,【6-8】爲第三張圖片,【9-11】爲第四張圖片,可以自行驗證。

import matplotlib.pyplot as plt
dstack_image=np.dstack(image_set)
print(dstack_image.shape)
plt.imshow(dstack_image[:,:,:3])

輸出如下:
在這裏插入圖片描述

numpy.concatenate() 函數的用法

前面提到的hstack ,vstack , dstack 可以看做是這個函數的特例。 通過傳入axis 參數,即可控制在哪一個維度上進行拼接。

如果 axis 的參數爲0 ,效果和 vstack 一致:

import matplotlib.pyplot as plt
dim_0=np.concatenate(image_set,axis=0)
print(dim_0.shape)
plt.imshow(dim_0)

在這裏插入圖片描述
這裏有一種一般的理解,對於拼接來說,一定是降維的過程,有一個維度被取消了,取消啥意思?就是原來的多個被拼接爲一個了,以這個 (4,224,224,3) 的四張圖片的numpy 數組來說意思就是拼接四張圖片爲一張圖片,我們可以通過觀察下面的代碼來猜測這個 axis 參數是什麼意思:

import matplotlib.pyplot as plt
dim_0=np.concatenate(image_set,axis=0)
dim_1=np.concatenate(image_set,axis=1)
dim_2=np.concatenate(image_set,axis=2)
print(dim_0.shape)
print(dim_1.shape)
print(dim_2.shape)

輸出結果如下:

(896, 224, 3)
(224, 896, 3)
(224, 224, 12)

我們可以看到,當axis 爲 0的時候,最終結果的第1個維度 是原來的4倍,axis 爲1的時候,最終結果的第2個維度爲原來的4倍,當axis=2的時候,最終結果的第三個維度爲原來的4倍。 所以這樣的話 axis 的作用就不言而喻了。

那我們考慮一下,如果想要 hstack 的效果,那麼圖片的大小應該是 (224,4*224,3) 這樣的話行高爲一張圖片的高度,寬度爲四張圖片拼接起來的寬度,效果如下:

import matplotlib.pyplot as plt
dim_0=np.concatenate(image_set,axis=0)
dim_1=np.concatenate(image_set,axis=1)
dim_2=np.concatenate(image_set,axis=2)
plt.imshow(dim_1)

在這裏插入圖片描述

numpy.split()

講完了維度的合併,現在要來看一下拆分了,同樣這個函數有三個參數:

  • array: 要拆分的數組
  • indices_or_sections 要拆分爲多少份或者要拆分的起始節點列表。
  • axis 要按照哪一個維度進行拆分

之前我看過別人的博客,甚至有一些大V的博客,都說axis 是按照行劃分,按照列劃分,這樣解釋讓人摸不着頭腦,最好的理解方式就是動手實驗。

首先,試驗一下axis=0 的情況:
我們先平均分爲兩份,看這兩份分別是什麼東西:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=0)
print(split_0.shape)
print(split_1.shape)

輸出結果:
(2, 224, 224, 3)
(2, 224, 224, 3)

可以看出,當axis 爲 0的時候,實際上就是將原來的一個集合(具有四張圖片)劃分爲兩個集合(每一個集合有兩張圖片),那麼易得當axis 爲 0 的時候就是對 第一維度的劃分 原來是(4,224,224,3) 劃分爲兩個(2,224,224,3)。

當axis=1 的時候。依舊劃分爲兩份,代碼如下:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=1)
print(split_0.shape)
print(split_1.shape)

輸出結果
(4, 112, 224, 3)
(4, 112, 224, 3)

可以看到,這裏將第二維度拆分爲了兩個,相當於將每張圖片橫着切成兩半,一個存儲着所有圖片的上半部分,一個存放着所有圖片的下半部分。不信,打印出來給你看:(打印上半部分)

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=1)
plt.imshow(np.hstack(split_0))

在這裏插入圖片描述
打印下半部分:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=1)
plt.imshow(np.hstack(split_1))

在這裏插入圖片描述

ok,講到這裏,大家對 split 函數的axis 的作用就很清楚了,當axis =2 的時候,實際上就是將每一張圖片豎着切開,所有圖片的左半邊在一個集合,所有圖片的右半邊在另外一個集合,打印出來效果如下:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=2)
plt.imshow(np.hstack(split_0))

在這裏插入圖片描述
右半邊的圖不再放出,節省篇幅~~~

ok,axis 就講到這裏了,下面說說 indices_or_sections,上面的演示中,一直將它設置爲 2 ,就是平均分爲兩份,是幾就平均分成幾份,但是特別注意的是,這裏平均分成幾份數要是整份啊,必須除得開才行,否則會報錯,這就需要使用另外一個函數了:array_split(),這個函數唯一的區別就是如果出現除不開的情況,就做不均等劃分,例如 將4 劃分爲 3份數,那麼劃分之後爲 【2,1,1】,將5劃分爲3份,就是【2,2,1】,將 6劃分爲 4份就是 【2,2,1,1】 這個很好理解。

indices_or_sections除了是一個數字代表平均劃分爲幾份之外,還可以是一個數組,裏面的元素必須遞增,這個數組表示0爲起點,一共幾份爲終點,在包括起點終點的情況下,元素的間距(我上面說的就不是人話,我自己都聽不懂)
舉個例子:
例如 我想將 4 張圖片劃分爲 【1,2,1】 ,要怎麼搞?均分不可以,array_split() 也不行,此時就可以 設置 indices_or_sections 爲:[1,3],請看下面示例代碼:

a,b,c=np.split(image_set,[1,3],axis=0)
print(a.shape)
print(b.shape)
print(c.shape)

輸出結果:
(1, 224, 224, 3)
(2, 224, 224, 3)
(1, 224, 224, 3)

看到例子,應該就清楚了~

tensorflow 中相關維度變換的函數

tensorflow 中相關的函數和numpy 函數區別不大,下面是一些演示。
爲了方便,建立placeholder來進行演示,通過查看placeholder 的維度變換體會函數的用法。

import tensorflow as tf
x=tf.placeholder(tf.float32,shape=(4,224,224,3))
print(x.get_shape())
1. tf.split() 函數

拆分函數,用法請見代碼:

import tensorflow as tf
x=tf.placeholder(tf.float32,shape=(4,224,224,3))
one,two=tf.split(x,2,axis=0)
print("one: ",one.get_shape())
print("two: ",two.get_shape())


# 特別注意,這裏的第二個參數可以是一個數字,代表平均分成多少份數,這個和numpy 的split沒有什麼不同
# 但是如果是列表的情況,裏面元素的個數就是分成多份,每個元素的值代表對應子集所佔的的份額大小,看代碼,這不難理解。

three,four,five=tf.split(x,[1,2,1],axis=0)
print("three: ",three.get_shape())
print("four: ",four.get_shape())
print("five: ",five.get_shape())

# 對於 axis 的理解,和 numpy 的 axis 沒有什麼不同。
seven,eight=tf.split(x,2,axis=1)
print("seven: ",seven.get_shape())
print("egith: ",eight.get_shape())

輸出:
one:  (2, 224, 224, 3)
two:  (2, 224, 224, 3)
three:  (1, 224, 224, 3)
four:  (2, 224, 224, 3)
five:  (1, 224, 224, 3)
seven:  (4, 112, 224, 3)
egith:  (4, 112, 224, 3)

2.tf.concat() 函數

合併函數,用法請見代碼:

x1=tf.placeholder(tf.float32,shape=(4,224,224,1))
x2=tf.placeholder(tf.float32,shape=(4,224,224,1))
x3=tf.placeholder(tf.float32,shape=(4,224,224,1))

res_0=tf.concat([x1,x2,x3],axis=0)
res_1=tf.concat([x1,x2,x3],axis=1)
res_2=tf.concat([x1,x2,x3],axis=2)
res_3=tf.concat([x1,x2,x3],axis=3)

print("res0:" ,res_0.get_shape())
print("res1:" ,res_1.get_shape())
print("res2:" ,res_2.get_shape())
print("res3:" ,res_3.get_shape())

輸出:
res0: (12, 224, 224, 1)
res1: (4, 672, 224, 1)
res2: (4, 224, 672, 1)
res3: (4, 224, 224, 3)

不解釋了,大家都懂。


發佈了57 篇原創文章 · 獲贊 29 · 訪問量 5萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章