神經網絡的入口——fluid.data

前面筆者已經提到了,paddlepaddle官方文檔的一大特點——文檔寫的好,就是看不懂。 在這裏沒有“詆譭”的意思,畢竟官方文檔寫的嚴謹是應該的。


開篇

在學習過程中,我們經常會用到paddle.fluid.data,這個方法定義的是網絡的輸入層,可以說是一個網絡的入口。fluid.data有一個重要的參數shape,通常需要根據定義的reader中的數據shape,來確定該shape的值。在學習過程中,案例代碼中shape參數值通常都是確定好的,學習者可以很好的運行代碼。但是拓展一下,需要自己來確定shape值的時候,很多初學者就會有霧裏看花的感覺,無從下手。基於筆者踩過的坑,和大家一起學習。

官方文檔中的fluid.data

首先看一下在官方文檔中如何說明shape:

paddle.fluid.data(name, shape, dtype='float32', lod_level=0)
  • shape (list|tuple)- 聲明維度信息的list或tuple。在示例代碼中也給出瞭解釋:
# Creates a variable with fixed size [3, 2, 1]
# User can only feed data of the same shape to x
x = fluid.data(name='x', shape=[321], dtype='float32')

# Creates a variable with changable batch size -1.
# Users can feed data of any batch size into y,
# but size of each data sample has to be [2, 1]
y = fluid.data(name='y', shape=[-121], dtype='float32')

大意就是:

  • x的shape中第一個維度是固定的,則feed的數據要和該shape完全一致
  • y的shape中第一個維度爲-1,則可以是任意值,feed的數據第一個維度是可變的,但是同樣要滿足數組運算的法則

深入解讀

文檔中給出的信息就這麼多,不知道您有沒有明白。不明白沒關係,下面再詳細說一下筆者的理解。

對於固定的shape很好理解,比如一張shape爲m * m * 3的彩色圖片,在fluid.data中shape定義爲[3, m, m],則表示每次輸入一個三通道,每個通道爲m * m的數組。

對於可變shape,通常應用於每次輸入的是一組batch數據。比如在fluid.data中shape定義爲[-1, 3, m, m],則表示爲每次輸入一個batch size大小的三通道,每個通道爲m * m的數組。

筆者在學習過程中經常思考一個問題,(m, m, 3)表示彩色圖片通道數在後面,然而在fluid.data中如果沒有batch,則用(3, m, m)表示輸入的是三通道的彩色圖片。爲什麼會這樣呢?

答案就在於,在卷積過程中,卷積計算是按照通道進行計算的。3通道的圖片會和3通道的卷積核進行卷積計算。

拓展延伸

對於初學者可能對圖片如何用數組表示還不是非常清楚,筆者在這裏詳細介紹一下。先看一下原圖片:

該圖片的shape爲:(32, 32, 3),長寬爲32像素,每個像素點由類似於[180 191 203]的數組表示,該像素點數組的三個值表示RGB三個通道,每個通道取值範圍爲0-255。

圖片輸出爲數組則爲:

[[[180 191 203]
  [180 188 199]
  [196 200 209]
  ...
  [160 167 184]
  [156 163 182]
  [162 170 188]]

 [[175 189 197]
  [179 190 199]
  [193 193 198]
  ...

單通道圖片爲:

單通道圖片shape:(32, 32),這裏每個像素由一個0-255的值表示。

圖片輸出爲數組則爲:

[[180 180 196 ... 160 156 162]
[175 179 193 ... 157 154 157]
[177 186 185 ... 158 156 159]
...
[180 175 185 ... 134 117  87]
[180 175 174 ... 140 154 102]
[162 177 167 ... 184 193 131]]

實例代碼

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

img = Image.open('work/dog.png')
image = np.array(img)
print('圖片shape:')
print(image.shape)
print('輸出爲:')
print(image)
圖片shape:
(32323)
輸出爲:
[[[180 191 203]
  [180 188 199]
  [196 200 209]
  ...
  [160 167 184]
  [156 163 182]
  [162 170 188]]

 [[175 189 197]
  [179 190 199]
  [193 193 198]
  ...
  [157 164 180]
  [154 161 179]
  [157 164 182]]

 [[177 195 202]
  [186 199 208]
  [185 181 182]
  ...
  [158 165 182]
  [156 163 181]
  [159 166 184]]

 ...

 [[180 193 206]
  [175 187 199]
  [185 191 199]
  ...
  [134 120 110]
  [117 102  90]
  [ 87  74  61]]

 [[180 193 206]
  [175 184 200]
  [174 175 182]
  ...
  [140 124 112]
  [154 139 127]
  [102  89  77]]

 [[162 172 187]
  [177 186 198]
  [167 169 171]
  ...
  [184 173 169]
  [193 181 175]
  [131 115 106]]]
# 原圖
plt.imshow(image)
plt.axis('off')
plt.show()
# 單通道圖片
s = image[:, :, 0]
print('單通道圖片shape:')
print(s.shape)
print('輸出爲:')
print(s)
# 結果
單通道圖片shape:
(3232)
輸出爲:
[[180 180 196 ... 160 156 162]
 [175 179 193 ... 157 154 157]
 [177 186 185 ... 158 156 159]
 ...
 [180 175 185 ... 134 117  87]
 [180 175 174 ... 140 154 102]
 [162 177 167 ... 184 193 131]]
# 查看各個通道,通道0
plt.imshow(image[:, :, 0], cmap='gray')
plt.axis('off')
plt.show()
# 查看各個通道,通道1
plt.imshow(image[:, :, 1], cmap='gray')
plt.axis('off')
plt.show()
# 查看各個通道,通道2
plt.imshow(image[:, :, 2], cmap='gray')
plt.axis('off')
plt.show()

本文分享自微信公衆號 - 飛槳PPDB(Tsing_Liu)。
如有侵權,請聯繫 [email protected] 刪除。
本文參與“OSC源創計劃”,歡迎正在閱讀的你也加入,一起分享。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章