tf.newaxis和np.newaxis

# -*- coding: utf-8 -*-
"""
tf.newaxis 和 numpy newaxis
"""
import numpy as np
import tensorflow as tf


if __name__ == '__main__':
    feature = np.array([[1,2,3],
                        [2,4,6]])
    center = np.array([[1,1,1],
                       [0,0,0]])
    
    print("原始數組大小:")
    print(feature.shape)
    print(center.shape)
    
    
    np_feature = feature[:, np.newaxis]  
    np_center = center[np.newaxis, :]
    
    print("添加 np.newaxis 後數組大小:")
    print(np_feature.shape)
    print(np_center.shape)
    
    np_diff = np_feature - np_center
    
    print("矩陣相減,np_diff:")
    print(np_diff)
      
    print('\n*********************\n')
    
    tf_feature = tf.constant(feature)[:,tf.newaxis]
    tf_center = tf.constant(center)[tf.newaxis,:]
    
    print("添加 tf.newaxis 後數組大小:")
    print(tf_feature.shape)
    print(tf_center.shape)
   
    tf_diff = tf_feature - tf_center       
    mask = 1 - tf.eye(2, 2, dtype=tf_diff.dtype)
    diffs = tf_diff * mask[:, :, tf.newaxis]
    
    sess = tf.Session()
    print("矩陣相減,tf_diff:")
    print(sess.run(tf_diff))
    
    print("對角線元素置爲0:")
    print(sess.run(diffs))

結果:

原始數組大小:
(2, 3)
(2, 3)
添加 np.newaxis 後數組大小:
(2, 1, 3)
(1, 2, 3)
矩陣相減,np_diff:
[[[0 1 2]
  [1 2 3]]

 [[1 3 5]
  [2 4 6]]]

*********************

添加 tf.newaxis 後數組大小:
(2, 1, 3)
(1, 2, 3)
矩陣相減,tf_diff:
[[[0 1 2]
  [1 2 3]]

 [[1 3 5]
  [2 4 6]]]
對角線元素置爲0:
[[[0 0 0]
  [1 2 3]]

 [[1 3 5]
  [0 0 0]]]
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章