主要藉助tf.diag_part
和tf.matrix_diag
兩個方法來將方陣對角線置0.
- tf.diag_part
函數返回tensor的對角線元素in : inputs = [[1,2,3,4], [2,3,4,5], [3,4,5,6], [4,5,6,7]] in : sess.run(tf.diag_part(inputs)) out: array([1, 3, 5, 7], dtype=int32)
- tf.matrix_diag
構造對角線矩陣# 對角線元素 in : x = tf.diag_part(inputs) in: matrix = tf.matrix_diag(x) # 原矩陣減去對角矩陣,即可實現對角線元素置0 in: sess.run(inputs- matrix) out: array([[0, 2, 3, 4], [2, 0, 4, 5], [3, 4, 0, 6], [4, 5, 6, 0]], dtype=int32)