遇到的問題:
面對矩陣的大小不同的兩個矩陣,其中一個矩陣如何根據另一個矩陣的要求實現相應的行或列縮放。目標效果如下所示:
x:(2,2,3)
[[[ 1. 2. 3.],
[ 4. 5. 6.]],
[[ 7. 8. 9.],
[10. 11. 12.]]]
w:(2,2)
[[0.5, 0.4],
[0.1, 0.2]]
x*w:(2,2,3)
[[[0.5 1. 1.5]
[1.6 2. 2.4]]
[[0.7 0.8 0.9]
[2. 2.2 2.4]]]
上面的效果,如果只利用點乘(w * x)
和乘法(tf.matmul(w, x))
操作是無法完成的,需要利用到矩陣的維度變換。具體處理流程爲:
- 對w進行維度擴張
w(2,2) --> w(2,1,2) - 將x的第二維和第三維變換
x(2,2,3) -->x(2,3,2) - 這時候再進行矩陣
點乘
操作,才能得到上面的效果。
具體代碼爲:
# tensorflow的點乘
def test2():
# a = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
# [2, 2, 3] [2, 2]
a1 = np.array([[[1.0,2.0,3.0],[4.0,5.0,6.0]],
[[7.0,8.0,9.0], [10.0,11.0,12.0]]])
w = np.array([[0.5, 0.4],
[0.1, 0.2]])
a1 = tf.convert_to_tensor(a1)
w = tf.convert_to_tensor(w)
#
# y = w * a1
a_trans = tf.transpose(a1, [0, 2, 1])
w = tf.expand_dims(w, 1)
y = tf.multiply(a_trans, w)
y = tf.transpose(y, [0, 2, 1])
# y = a_trans*w
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print("x:")
print(sess.run(a1))
print("w")
print(sess.run(w))
print("x*w:")
print(sess.run(y))
test2()
注意事項:
(1) 點乘,只有在w的列爲1或與x的列相等時,才能進行點乘運算;
(2) 乘法, 只有前一個矩陣的最後一維和後面一個矩陣的第一維相等時,才能進行乘法操作;
打個小廣告: 歡迎關注本人github: https://github.com/wuxiaoxiaoer
隨時會有新想法,或技術更新,尤其是假新聞方面的研究。