tensorflow筆記 ---tf.nn.conv3d使用

1. 視覺角度

我們首先先通過一張圖來直觀的看看2D與3D卷積的區別:

標題

從圖p0116中(只包含一個卷積核)我們可以看出,對於:

    2D convolution: 使用場景一般是單通道的數據(例如MNIST),輸出也是單通道,對整個通道同時執行卷積操作;

    2D convolution on multiple frames: 使用場景一般是多通道的數據(例如cifar-10),輸出也是單通道,對整個通道同時執行卷積操作;
        2D卷積在執行時是在各自的通道中共享卷積核;

    3D convolution: 使用場景一般是多幀(單/多通道)的frame-like數據(視頻幀),且輸出也是多幀,依次對連續k幀的整個通道同時執行卷積操作;

        3D卷積在執行時不僅在各自的通道中共享卷積核,而且在各幀(連續k幀)之間也共享卷積核;

標題

2. 計算角度

話說數無形時少直覺,形少數時難入微。在我們從視覺角度觀察之後,我們再來從計算的角度看看3D卷積到底是怎麼在工作。

假設現在有一個3幀的畫面,且每一幀有2個通道,在時間維度的跨度爲2幀,卷積核的寬度爲3

  • 我們首先再次從視覺的角度看看這個結果:
  • 標題

     

計算結果:

由於在時間維度的跨度爲2幀,且每幀有2個通道,所以從“矩陣”個數來看的話,我們的卷積核應該有4矩陣。

標題

 

3. tensorflow示例

3.1 接口介紹

在知道3D CNN的原理之後,我們現在來看怎麼用tensorflow提供的接口來實現上面的計算操作。首先根據上面的示例,我們有了下列參數:

輸入數據:

  • batch = 1;
  • in_depth = 3; 序列長度
  • in_channels = 2; 每一幀的通道數
  • in_height = 4;
  • in_width = 4;

卷積核:

  • filter_depth = 2; 時間維度的連續跨度
  • filter_height = 3;
  • filter_width = 3;
  • in_channels = 2; 輸入時每幀的通道,必須核輸入數據的通道一樣
  • out_channels = 1;卷積核的個數,對應的就是輸出之後每幀的通道數
def conv3d(input, filter, strides, padding):

這是conv3d的接口,其主要接收4個參數:

input : 輸入,其格式爲[batch, in_depth, in_height, in_width, in_channels].

filter: 卷積核,其格式爲[filter_depth, filter_height, filter_width, in_channels,out_channels]

 

strides: 移動步長[1,1,1,1,1]即可

padding: 是否padding

最後,其輸出結果的格式同輸入,也爲[batch, in_depth, in_height, in_width, in_channels].

3.2 生成數據和實現

在介紹完接口後,只需要給定數據即可了。爲了驗證第二節中實驗,我們下面先生成數據,然後再進行卷積。

生成數據:

image_in_man = np.linspace(1, 96, 96).reshape(1, 3, 2, 4, 4)  
# [batch, in_depth, in_channels, in_height, in_width]

image_in_tf = image_in_man.transpose(0, 1, 3, 4, 2)  
# [batch, in_depth, in_height, in_width, in_channels].

值得注意的是,爲了查看我們生成的數據,我們將inchannels這個維度放在了第2個(從0開始)維度,因爲這樣看才直觀(詳見tf.nn.conv2d 你真的會用麼),但在餵給conv3d是要轉成其接收的格式

卷積核:

weight_in_man = np.array(
    [1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
     1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0]).reshape(1, 2, 2, 3, 3) 
# [out_channels,filter_depth, in_channels,filter_height, filter_width, ]

weight_in_tf = weight_in_man.transpose(1, 3, 4, 2,0)  
# [filter_depth, filter_height, filter_width, in_channels,out_channels]

計算:

import tensorflow as tf
import numpy as np

image_in_man = np.linspace(1, 96, 96).reshape(1, 3, 2, 4, 4)
# [batch, in_depth, in_channels, in_height, in_width]
image_in_tf = image_in_man.transpose(0, 1, 3, 4, 2)
# [batch, in_depth, in_height, in_width, in_channels].
# shape:[1,2,4,4,2]
weight_in_man = np.array(
    [1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
     1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0]).reshape(1, 2, 2, 3, 3)  # 1,3,4,2,0
weight_in_tf = weight_in_man.transpose(1, 3, 4, 2, 0)
# [filter_depth, filter_height, filter_width, in_channels,out_channels]
# shape: [2,3,3,2,1]
print(image_in_man)
print(weight_in_man)

x = tf.placeholder(dtype=tf.float32, shape=[1, 3, 4, 4, 2], name='x')
w = tf.placeholder(dtype=tf.float32, shape=[2, 3, 3, 2, 1], name='w')
conv = tf.nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding='VALID')
with tf.Session() as sess:
    r_in_tf = sess.run(conv, feed_dict={x: image_in_tf, w: weight_in_tf})
    # [batch, in_depth, in_height, in_width, in_channels].
    print(r_in_tf.shape)
    r_in_man = r_in_tf.transpose(0, 1, 4, 2, 3)
    # [batch, in_depth,in_channels,in_height, in_width].
    print(r_in_man)

結果:

[[[[[ 1.  2.  3.  4.]
    [ 5.  6.  7.  8.]
    [ 9. 10. 11. 12.]
    [13. 14. 15. 16.]]

   [[17. 18. 19. 20.]
    [21. 22. 23. 24.]
    [25. 26. 27. 28.]
    [29. 30. 31. 32.]]]


  [[[33. 34. 35. 36.]
    [37. 38. 39. 40.]
    [41. 42. 43. 44.]
    [45. 46. 47. 48.]]

   [[49. 50. 51. 52.]
    [53. 54. 55. 56.]
    [57. 58. 59. 60.]
    [61. 62. 63. 64.]]]


  [[[65. 66. 67. 68.]
    [69. 70. 71. 72.]
    [73. 74. 75. 76.]
    [77. 78. 79. 80.]]

   [[81. 82. 83. 84.]
    [85. 86. 87. 88.]
    [89. 90. 91. 92.]
    [93. 94. 95. 96.]]]]]
[[[[[1 0 0]
    [0 1 0]
    [1 0 0]]

   [[0 1 0]
    [0 0 1]
    [0 0 0]]]


  [[[1 1 0]
    [0 0 0]
    [0 1 0]]

   [[0 1 1]
    [0 1 0]
    [0 0 0]]]]]
    
  
   (1, 2, 2, 2, 1)
[[[[[321. 332.]
    [365. 376.]]]


  [[[673. 684.]
    [717. 728.]]]]]

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章