深度学习 tensorflow 三维矩阵乘法(batch 迭代必须搞懂的矩阵乘法,维度增加)

import tensorflow as tf

# 2 * 2 * 2 的embedding 矩阵 , 一个batch,每个样本有f个字段,每个字段有k维
# 那么矩阵的大小就是batch * f * k
embedding_index = tf.constant([[[0.1,0.2],
                                [0.3,0.2]],
                              
                               [[0.8,0.2],
                                [0.5,0.4]]
                              ])
#  2 * 2 的系数矩阵
X_sparse  = tf.constant([[1.0,2.0],
                         [3.0,4.0]])

# 为了和embedding_index 相乘,需要增加一维度, 增加一维有下面两种写法
# 增加维度的方法1:sparse_value = tf.reshape(X_sparse, shape=[-1, 2, 1])
# 下面是增加维度的方法2
sparse_value = tf.expand_dims(X_sparse,2)

embedding_matmul = tf.matmul(embedding_index, sparse_value)

embedding_multiply = tf.multiply(embedding_index, sparse_value)


print(embedding_index)

print(sparse_value)

with tf.Session() as sess:
    
    print(sess.run(X_sparse))
    
    print("-"*10)    
    print(sess.run(sparse_value))
    
    print("-"*10)
    print(sess.run(embedding_index))
    
    print("-"*10)
    print(sess.run(embedding_matmul))
    
    print("-"*10)
    print(sess.run(embedding_multiply))

 

 

 

Tensor("Const_4:0", shape=(2, 2, 2), dtype=float32)
Tensor("ExpandDims_2:0", shape=(2, 2, 1), dtype=float32)
[[1. 2.]
 [3. 4.]]
----------
[[[1.]
  [2.]]

 [[3.]
  [4.]]]
----------
[[[0.1 0.2]
  [0.3 0.2]]

 [[0.8 0.2]
  [0.5 0.4]]]
----------
[[[0.5       ]
  [0.70000005]]

 [[3.2       ]
  [3.1       ]]]
----------
[[[0.1 0.2]
  [0.6 0.4]]

 [[2.4 0.6]
  [2.  1.6]]]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章