np.tensordot 的理解和使用

Numpy是使用最廣的科學計算庫,對於多維數組的操作更是在實踐中用的最多,而且也是比較困惑的地方,但是用好了事半功倍,今天講一下numpy的 tensordot 的使用,這個函數在卷積神經網絡的卷積中用到。

數組的基本屬性

數組基本屬性:維度、形狀、strides(跨越數組各個維度所需要經過的字節數)、數組元素個數、元素佔用字節數、數組佔用空間,用以下例子說明:

>>> X = np.random.randint(0,9,(3,4,5))
>>> X
array([[[5, 1, 3, 6, 5],
        [5, 1, 8, 0, 5],
        [8, 5, 7, 8, 5],
        [8, 1, 5, 1, 4]],

       [[7, 7, 7, 7, 6],
        [0, 3, 4, 4, 6],
        [8, 4, 2, 1, 1],
        [6, 3, 4, 5, 4]],

       [[0, 2, 8, 0, 7],
        [6, 5, 8, 2, 2],
        [0, 1, 2, 3, 5],
        [7, 8, 7, 7, 6]]])
>>> X.ndim
3
>>> X.shape
(3, 4, 5)
>>> X.strides
(160, 40, 8)
>>> X.size
60
>>> X.itemsize
8
>>> X.nbytes
480

多維數組軸向取值

數組的取值看似簡單但是在高緯度下,還是需要注意一下取法.
最原始取法,如取第一個元素

>>> X[0][0][0]
5

按軸取值則不同,取出來的值可能是數組,仍以上述爲例,X.shape爲(3,4,5),說明是3維數組,或者說有三個軸0,1,2. 第0軸上3個元素,第1軸上4個元素,第2軸上5個元素,如果要取軸上元素如何寫?看以下例子。以下取第0軸第一個元素。

>>> X[0]
array([[5, 1, 3, 6, 5],
       [5, 1, 8, 0, 5],
       [8, 5, 7, 8, 5],
       [8, 1, 5, 1, 4]])
>>> X[1]
array([[7, 7, 7, 7, 6],
       [0, 3, 4, 4, 6],
       [8, 4, 2, 1, 1],
       [6, 3, 4, 5, 4]])
>>> X[2]
array([[0, 2, 8, 0, 7],
       [6, 5, 8, 2, 2],
       [0, 1, 2, 3, 5],
       [7, 8, 7, 7, 6]])
>>> X[4]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: index 4 is out of bounds for axis 0 with size 3

取1軸上的元素

>>> X[:,0,:]
array([[5, 1, 3, 6, 5],
      [7, 7, 7, 7, 6],
      [0, 2, 8, 0, 7]])
>>> X[:,5,:]
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
IndexError: index 5 is out of bounds for axis 1 with size 4

可以看到,按軸取出的元素實際上是一個子數組!

Tensordot的使用

進入正題, 運行如下代碼:

>>> np.random.seed(10)
>>> A = np.random.randint(0,9,(3,4,5))
>>> B = np.random.randint(0,9,(4,5,2))
>>> np.tensordot(A, B, [(1,2), (0,1)])
array([[233,  89],
       [250, 234],
       [199, 244]])

解釋:
(1,2) 是對A而言,不是取第1,2軸,而是除去1,2 軸,所以要取的是第0軸
(0,1) 是對B而言,不是取第0,1軸,而是除去0,1 軸,所以要取的是第2軸

以上兩句是精華

A的形狀是(3,4,5),第0軸上有3個元素,取法上面講了;B的形狀(4,5,2),第2軸上有2個元素,所以結果形狀是(3,2)

Tensordot 的作用就是把取出的子數組做點乘操作,即是 np.sum(a*b) 操作。
我們來驗證一下,上述的說法看結果形狀(3,2)的第一個元素:A第0軸上第一個元素與B第2軸上的第一個元素點乘。

>>> A[0]
array([[4, 0, 1, 0, 1],
       [8, 0, 8, 6, 4],
       [3, 0, 4, 6, 8],
       [1, 8, 4, 1, 3]])
>>> B[:,:,0]
array([[8, 2, 5, 2, 3],
       [4, 0, 3, 2, 0],
       [0, 0, 1, 0, 5],
       [4, 6, 2, 3, 6]])
>>> np.sum(A[0]*B[:,:,0])
233

結果完全正確!就是這麼簡單,多說都是廢話!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章