開篇
在學習過程中,我們經常會用到paddle.fluid.data
,這個方法定義的是網絡的輸入層,可以說是一個網絡的入口。fluid.data
有一個重要的參數shape,通常需要根據定義的reader中的數據shape,來確定該shape的值。在學習過程中,案例代碼中shape參數值通常都是確定好的,學習者可以很好的運行代碼。但是拓展一下,需要自己來確定shape值的時候,很多初學者就會有霧裏看花的感覺,無從下手。基於筆者踩過的坑,和大家一起學習。
官方文檔中的fluid.data
首先看一下在官方文檔中如何說明shape:
paddle.fluid.data(name, shape, dtype='float32', lod_level=0)
-
shape (list|tuple)- 聲明維度信息的list或tuple。在示例代碼中也給出瞭解釋:
# Creates a variable with fixed size [3, 2, 1]
# User can only feed data of the same shape to x
x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32')
# Creates a variable with changable batch size -1.
# Users can feed data of any batch size into y,
# but size of each data sample has to be [2, 1]
y = fluid.data(name='y', shape=[-1, 2, 1], dtype='float32')
大意就是:
-
x的shape中第一個維度是固定的,則feed的數據要和該shape完全一致 -
y的shape中第一個維度爲-1,則可以是任意值,feed的數據第一個維度是可變的,但是同樣要滿足數組運算的法則
深入解讀
文檔中給出的信息就這麼多,不知道您有沒有明白。不明白沒關係,下面再詳細說一下筆者的理解。
對於固定的shape很好理解,比如一張shape爲m * m * 3的彩色圖片,在fluid.data中shape定義爲[3, m, m],則表示每次輸入一個三通道,每個通道爲m * m的數組。
對於可變shape,通常應用於每次輸入的是一組batch數據。比如在fluid.data中shape定義爲[-1, 3, m, m],則表示爲每次輸入一個batch size大小的三通道,每個通道爲m * m的數組。
筆者在學習過程中經常思考一個問題,(m, m, 3)表示彩色圖片通道數在後面,然而在fluid.data中如果沒有batch,則用(3, m, m)表示輸入的是三通道的彩色圖片。爲什麼會這樣呢?
答案就在於,在卷積過程中,卷積計算是按照通道進行計算的。3通道的圖片會和3通道的卷積核進行卷積計算。
拓展延伸
對於初學者可能對圖片如何用數組表示還不是非常清楚,筆者在這裏詳細介紹一下。先看一下原圖片:
該圖片的shape爲:(32, 32, 3),長寬爲32像素,每個像素點由類似於[180 191 203]的數組表示,該像素點數組的三個值表示RGB三個通道,每個通道取值範圍爲0-255。
圖片輸出爲數組則爲:
[[[180 191 203]
[180 188 199]
[196 200 209]
...
[160 167 184]
[156 163 182]
[162 170 188]]
[[175 189 197]
[179 190 199]
[193 193 198]
...
單通道圖片爲:
單通道圖片shape:(32, 32),這裏每個像素由一個0-255的值表示。
圖片輸出爲數組則爲:
[[180 180 196 ... 160 156 162]
[175 179 193 ... 157 154 157]
[177 186 185 ... 158 156 159]
...
[180 175 185 ... 134 117 87]
[180 175 174 ... 140 154 102]
[162 177 167 ... 184 193 131]]
實例代碼
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
img = Image.open('work/dog.png')
image = np.array(img)
print('圖片shape:')
print(image.shape)
print('輸出爲:')
print(image)
圖片shape:
(32, 32, 3)
輸出爲:
[[[180 191 203]
[180 188 199]
[196 200 209]
...
[160 167 184]
[156 163 182]
[162 170 188]]
[[175 189 197]
[179 190 199]
[193 193 198]
...
[157 164 180]
[154 161 179]
[157 164 182]]
[[177 195 202]
[186 199 208]
[185 181 182]
...
[158 165 182]
[156 163 181]
[159 166 184]]
...
[[180 193 206]
[175 187 199]
[185 191 199]
...
[134 120 110]
[117 102 90]
[ 87 74 61]]
[[180 193 206]
[175 184 200]
[174 175 182]
...
[140 124 112]
[154 139 127]
[102 89 77]]
[[162 172 187]
[177 186 198]
[167 169 171]
...
[184 173 169]
[193 181 175]
[131 115 106]]]
# 原圖
plt.imshow(image)
plt.axis('off')
plt.show()
# 單通道圖片
s = image[:, :, 0]
print('單通道圖片shape:')
print(s.shape)
print('輸出爲:')
print(s)
# 結果
單通道圖片shape:
(32, 32)
輸出爲:
[[180 180 196 ... 160 156 162]
[175 179 193 ... 157 154 157]
[177 186 185 ... 158 156 159]
...
[180 175 185 ... 134 117 87]
[180 175 174 ... 140 154 102]
[162 177 167 ... 184 193 131]]
# 查看各個通道,通道0
plt.imshow(image[:, :, 0], cmap='gray')
plt.axis('off')
plt.show()
# 查看各個通道,通道1
plt.imshow(image[:, :, 1], cmap='gray')
plt.axis('off')
plt.show()
# 查看各個通道,通道2
plt.imshow(image[:, :, 2], cmap='gray')
plt.axis('off')
plt.show()
本文分享自微信公衆號 - 飛槳PPDB(Tsing_Liu)。
如有侵權,請聯繫 [email protected] 刪除。
本文參與“OSC源創計劃”,歡迎正在閱讀的你也加入,一起分享。