Notes on feature_column.embedding_column() in TensorFlow

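I have long wondered how embedding_column() maps its input when a single element (one sample row) contains several items. Take the color feature below: when each row holds two colors, how does embedding_column() look up the table to produce one embedding per row?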

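A. When each element contains only one item, the item's vocabulary id is used directly as an index into the embedding table created by embedding_column(), and the row found there is the item's embedding.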

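import tensorflow as tf  # TensorFlow 1.x API (tf.Session, input_layer)
from tensorflow import feature_column
# Private TF 1.x helper, used here only to inspect the sparse ids
from tensorflow.python.feature_column.feature_column import _LazyBuilder


def test_embedding():
    color_data = {'color': [['G'], ['B'], ['B'], ['R']]}  # 4 sample rows

    # Map each string to its vocabulary index: R -> 0, G -> 1, B -> 2, unknown -> -1
    color_column = feature_column.categorical_column_with_vocabulary_list(
        'color', ['R', 'G', 'B'], dtype=tf.string, default_value=-1
    )

    # 7-dimensional embedding; the table is randomly initialized
    color_embeding = feature_column.embedding_column(color_column, 7)
    color_embeding_dense_tensor = feature_column.input_layer(color_data, [color_embeding])
    builder = _LazyBuilder(color_data)
    color_column_tensor = color_column._get_sparse_tensors(builder)

    # Print the sparse ids produced by the categorical column
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print(session.run([color_column_tensor.id_tensor]))

    # Print the dense embedding output, one 7-dimensional row per sample
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print('embeding' + '_' * 40)
        print(session.run([color_embeding_dense_tensor]))


test_embedding()

Output: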
[SparseTensorValue(indices=array([[0, 0],
       [1, 0],
       [2, 0],
       [3, 0]]), values=array([1, 2, 2, 0]), dense_shape=array([4, 1]))]
embeding________________________________________
[array([[-0.5110804 ,  0.40523612,  0.11027244, -0.18846236,  0.24071017,
         0.06515816,  0.09236987],
       [-0.39017957,  0.14889447, -0.34367365,  0.32619542,  0.46648583,
         0.3640195 , -0.12630698],
       [-0.39017957,  0.14889447, -0.34367365,  0.32619542,  0.46648583,
         0.3640195 , -0.12630698],
       [ 0.26945257, -0.22851205, -0.02635379, -0.01459604,  0.32915694,
        -0.08775295,  0.24930897]], dtype=float32)]
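
The lookup in case A can be mimicked in plain Python without TensorFlow. This is only a minimal sketch: VOCAB, TABLE and lookup_single are hypothetical names, and the 3 x 2 table holds made-up numbers rather than the randomly initialized values TensorFlow would produce.

```python
# Hypothetical 3 x 2 embedding table; row i is the vector for vocabulary id i.
VOCAB = ['R', 'G', 'B']
TABLE = [
    [0.1, 0.2],  # id 0 -> 'R'
    [0.3, 0.4],  # id 1 -> 'G'
    [0.5, 0.6],  # id 2 -> 'B'
]


def lookup_single(rows):
    """Return the ids and one embedding per row; each row holds exactly one token."""
    ids = [VOCAB.index(row[0]) for row in rows]
    return ids, [TABLE[i] for i in ids]


ids, vectors = lookup_single([['G'], ['B'], ['B'], ['R']])
print(ids)      # [1, 2, 2, 0], the same ids as in the SparseTensorValue above
print(vectors)  # rows 1 and 2 are identical because both samples are 'B'
```

As in the real output above, the second and third rows come out identical because both samples map to the same id.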

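B. When an element contains two or more items, embedding_column() outputs the average of the items' embedding vectors from the lookup table. Averaging is the default behavior of its combiner argument (combiner='mean').

The following example verifies this: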

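import tensorflow as tf  # TensorFlow 1.x API (tf.Session, input_layer)
from tensorflow import feature_column
# Private TF 1.x helper, used here only to inspect the sparse ids
from tensorflow.python.feature_column.feature_column import _LazyBuilder


def test_embedding():
    # 4 sample rows with two items each; 'A' is not in the vocabulary
    color_data = {'color': [['G', 'G'], ['G', 'B'], ['B', 'B'], ['A', 'R']]}

    # Map each string to its vocabulary index: R -> 0, G -> 1, B -> 2, unknown -> -1
    color_column = feature_column.categorical_column_with_vocabulary_list(
        'color', ['R', 'G', 'B'], dtype=tf.string, default_value=-1
    )

    # 7-dimensional embedding; the table is randomly initialized
    color_embeding = feature_column.embedding_column(color_column, 7)
    color_embeding_dense_tensor = feature_column.input_layer(color_data, [color_embeding])
    builder = _LazyBuilder(color_data)
    color_column_tensor = color_column._get_sparse_tensors(builder)

    # Print the sparse ids produced by the categorical column
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print(session.run([color_column_tensor.id_tensor]))

    # Print the dense embedding output, one 7-dimensional row per sample
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print('embeding' + '_' * 40)
        print(session.run([color_embeding_dense_tensor]))


test_embedding()

Output: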
[SparseTensorValue(indices=array([[0, 0],
       [0, 1],
       [1, 0],
       [1, 1],
       [2, 0],
       [2, 1],
       [3, 0],
       [3, 1]]), values=array([ 1,  1,  1,  2,  2,  2, -1,  0]), dense_shape=array([4, 2]))]
embeding________________________________________
[array([[ 0.73096615,  0.10957518, -0.1657246 ,  0.17001966, -0.22539927,
        -0.50863737, -0.37135717],
       [ 0.6246785 ,  0.02085713, -0.04949204, -0.14722404,  0.09994595,
        -0.03458649, -0.04306053],
       [ 0.51839083, -0.06786092,  0.06674052, -0.46446773,  0.42529118,
         0.4394644 ,  0.28523612],
       [ 0.34275472, -0.08808891,  0.08895188,  0.24801058, -0.12121174,
         0.26907632,  0.3819868 ]], dtype=float32)]

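Now check the averaging by hand. Row ['G','G'] is the mean of two copies of G's vector, so it equals G's embedding; likewise row ['B','B'] equals B's embedding. If embedding_column() averages, row ['G','B'] must therefore equal (GG+BB)/2:

import numpy as np

GG = np.array([ 0.73096615,  0.10957518, -0.1657246 ,  0.17001966, -0.22539927,
        -0.50863737, -0.37135717])  # row ['G','G']
GB = np.array([ 0.6246785 ,  0.02085713, -0.04949204, -0.14722404,  0.09994595,
        -0.03458649, -0.04306053])  # row ['G','B']
BB = np.array([ 0.51839083, -0.06786092,  0.06674052, -0.46446773,  0.42529118,
         0.4394644 ,  0.28523612])  # row ['B','B']

(GG+BB)/2
array([ 0.62467849,  0.02085713, -0.04949204, -0.14722404,  0.09994595,
       -0.03458649, -0.04306053])
# The result matches GB (up to rounding in the last printed digit)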
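
The averaging in case B can also be reproduced in plain Python. This is a rough sketch under stated assumptions, not TensorFlow's implementation: VOCAB, TABLE, to_id and mean_embedding are hypothetical names, the table values are made up, and the code assumes that items mapped to -1 (such as 'A') are dropped before averaging. That last point is my understanding of how the lookup treats out-of-vocabulary ids; the fourth output row above cannot confirm it alone, because R's own vector is never printed in that run.

```python
# Hypothetical 3 x 2 embedding table; row i is the vector for vocabulary id i.
VOCAB = ['R', 'G', 'B']
TABLE = [
    [1.0, 0.0],  # id 0 -> 'R'
    [0.0, 2.0],  # id 1 -> 'G'
    [4.0, 6.0],  # id 2 -> 'B'
]


def to_id(token):
    """Vocabulary index of token, or -1 when it is out of vocabulary."""
    return VOCAB.index(token) if token in VOCAB else -1


def mean_embedding(row):
    """Average the vectors of all in-vocabulary tokens in one row."""
    ids = [i for i in map(to_id, row) if i >= 0]  # assumption: -1 ids are dropped
    dim = len(TABLE[0])
    if not ids:
        return [0.0] * dim  # assumption: a row with no valid id yields zeros
    return [sum(TABLE[i][d] for i in ids) / len(ids) for d in range(dim)]


rows = [['G', 'G'], ['G', 'B'], ['B', 'B'], ['A', 'R']]
GG, GB, BB, AR = [mean_embedding(r) for r in rows]
print(GG, GB, BB, AR)
print([(g + b) / 2 for g, b in zip(GG, BB)] == GB)  # True
```

With these made-up numbers, ['G','B'] comes out as the midpoint of ['G','G'] and ['B','B'], which is the same relationship checked with NumPy above.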

 
