numpy的axis,看完就懂了

場景

  • 經常用到,但是每碰到一次,就得停頓半天,畢竟這東西有點反人類,如果只是死記硬背,很容易出現問題,而且還不容易排錯

二維矩陣的例子

X矩陣的內容 x.shape=(4,5)

array([[ 0.3200494 ,  0.01722924, -0.10703773,  0.76887001,  1.21976509],
       [ 0.33939746, -1.61387027, -0.32144178, -0.07413717, -1.52419233],
       [-0.77941338, -1.13721978, -0.06767945,  0.09952699,  0.24625685],
       [ 0.42816061, -0.50676614, -0.99743203, -0.32314114, -2.09250823]])

np.max(x, axis=0)

array([ 0.42816061,  0.01722924, -0.06767945,  0.76887001,  1.21976509])
  • x.shape=(4,5), axis=0對應的就是形狀中的4,也就是行
  • 也就是說比較的項目是以行爲單位的,是行與行之間的比較
  • 一共四行,四行分別是ABCD,那麼就將元素A[0], B[0], C[0], D[0]來比較大小,然後發現D[0]最大, 所以輸出結果的第一個元素是0.42816061
  • 依次計算出剩下的計算即可
  • 計算完之後的結果的形狀是(5,), 是將原來的形狀(4,5)中的4給去掉了(也就是0維度的數字)

np.max(x, axis=1)

array([-0.10703773, -1.61387027, -1.13721978, -2.09250823])
  • x.shape=(4,5), axis=0對應的就是5,現在針對的就是每行中的元素了,是行內元素之間的比較
    • 計算完之後的結果的形狀是(4,), 是將原來的形狀(4,5)中的4給去掉了(也就是原來維度1上的數字)

可能更多的人關注的是三維以及三維以上的例子

np.max(x, axis=0),shape=(2, 3, 4)

array([[[ 0.77533663, -1.10778776,  0.49358915, -0.90216646],
        [ 2.24042164,  0.48353235,  0.98318199, -0.14353624],
        [-0.04447856,  0.24223262, -1.36565458,  0.07751463]],

       [[-0.10827103, -1.21993926,  1.90490377,  0.32569059],
        [ 0.2609489 , -0.20085888, -1.30428121,  2.71533268],
        [-0.35308164, -0.64957618, -0.4589211 , -0.76238811]]])
array([[ 0.77533663, -1.10778776,  1.90490377,  0.32569059],
       [ 2.24042164,  0.48353235,  0.98318199,  2.71533268],
       [-0.04447856,  0.24223262, -0.4589211 ,  0.07751463]])
  • x.shape=(2, 3, 4),axis=0就是對應着形狀中的2,也就是說矩陣x裏面有兩個形狀是(3,4)的矩陣,這沒毛病吧?
  • 然後比較的對象就是形狀爲(3,4)的兩個矩陣之間的比較,假設是這兩個小矩陣分別是A 和 B,矩陣之間怎麼比較大小呢? 不慌,按照point-wise的方式比較不就好了,你就對比着看結果對不對吧!
    • A[0][0]和B[0][0]比較大小,
    • A[0][2]和B[0][2]比較大小
    • A[0][3]和B[0][3]比較大小
  • 最終輸出矩陣的形狀爲(3,4),巧不巧,又是把(2,3, 4)中的2給去掉了,也就是axis=0對應的位置
    • 如果爲axis=1呢
      • 輸出結果shape=(2,4)
    • 如果爲axis=2呢
      • 輸出結果shape=(2,3), 這都不是巧合,這是規律,實在不理解,可以先記住

np.max(x, axis=1), x.shape=(2, 3, 4)

array([[ 2.24042164,  0.48353235,  0.98318199,  0.07751463],
       [ 0.2609489 , -0.20085888,  1.90490377,  2.71533268]])
  • 現在axis=1了,對應的是原(2, 3, 4),中的3對吧,3又對應到了行,然後就是行與行的比較
  • 輸出結果的形狀一定是(2,4)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章