面對不同維度大小矩陣乘法操作的處理(Tensorflow)

遇到的問題:
面對矩陣的大小不同的兩個矩陣,其中一個矩陣如何根據另一個矩陣的要求實現相應的行或列縮放。目標效果如下所示:

x:(2,2,3)
[[[ 1.  2.  3.],
  [ 4.  5.  6.]],
 [[ 7.  8.  9.],
  [10. 11. 12.]]]

w:(2,2)
[[0.5, 0.4],
   [0.1, 0.2]]

x*w:(2,2,3)
[[[0.5 1.  1.5]
  [1.6 2.  2.4]]
 [[0.7 0.8 0.9]
  [2.  2.2 2.4]]]

上面的效果,如果只利用點乘(w * x)乘法(tf.matmul(w, x))操作是無法完成的,需要利用到矩陣的維度變換。具體處理流程爲:

  1. 對w進行維度擴張
    w(2,2) --> w(2,1,2)
  2. 將x的第二維和第三維變換
    x(2,2,3) -->x(2,3,2)
  3. 這時候再進行矩陣點乘操作,才能得到上面的效果。
    具體代碼爲:
# tensorflow的點乘
def test2():
    # a = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
    # [2, 2, 3] [2, 2]
    a1 = np.array([[[1.0,2.0,3.0],[4.0,5.0,6.0]],
                       [[7.0,8.0,9.0], [10.0,11.0,12.0]]])
    w = np.array([[0.5, 0.4],
                  [0.1, 0.2]])
    a1 = tf.convert_to_tensor(a1)
    w = tf.convert_to_tensor(w)
    # 
    # y = w * a1
    a_trans = tf.transpose(a1, [0, 2, 1])
    w = tf.expand_dims(w, 1)
    y = tf.multiply(a_trans, w)
    y = tf.transpose(y, [0, 2, 1])
    # y = a_trans*w
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print("x:")
        print(sess.run(a1))
        print("w")
        print(sess.run(w))
        print("x*w:")
        print(sess.run(y))
test2()

注意事項:
(1) 點乘,只有在w的列爲1或與x的列相等時,才能進行點乘運算;
(2) 乘法, 只有前一個矩陣的最後一維和後面一個矩陣的第一維相等時,才能進行乘法操作;

打個小廣告: 歡迎關注本人github: https://github.com/wuxiaoxiaoer
隨時會有新想法,或技術更新,尤其是假新聞方面的研究。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章