使用預訓練Embedding,finetune DSSM模型

Milvus 小編:本文轉載自公衆號 Python 科技園,作者王多魚。

1. 前言

DSSM模型是點擊預估領域的經典召回模型,是由 “用戶”端 和 “商品”端 兩個塔式結構組成。“用戶”端 和 “商品”端 兩個子塔分別生成最終的 “用戶” Embedding 和 “商品” Embedding。在線上應用時,實時生成 “用戶” 端的 Embedding(因爲用戶的行爲是動態的),在線從數據庫中(例如:HBase, Redis)獲取 “商品” 端的 Embedding(商品的Embedding生成後直接存儲到數據庫中,不需要實時生成)。然後通過NN的方式,檢索出用戶感興趣的top-N商品候選集。

在訓練模型時,如果某一場景的數據量較少,訓練出的模型效果大概率不理想,容易造成模型不收斂的情況。最佳的解決方案:即採用預訓練的方式,通過微調該場景下所構建的模型。例如:支付寶APP上的某個商品推薦位置,用戶產生的點擊或購買行爲較少;但是在淘寶APP上用戶的行爲是海量的。可以通過淘寶APP上的數據訓練出 “用戶ID” 的 Embedding 和 “商品ID” 的 Embedding,然後使用該 Embedding 在支付寶APP上的商品推薦場景下對模型進行微調。

 

2. 構建DSSM模型

 

(1)加載模塊

import sys
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Activation, Multiply, Dot
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard

from keras.utils import plot_model

 

(2)構建DSSM模型

def build_model():
    n_pin_vec = 128
    n_sku_vec = 128

    pin_vec = Input(shape=(n_pin_vec, ), dtype = 'float32')
    sku_vec = Input(shape=(n_sku_vec, ), dtype = 'float32')

    pin_part = Dense(64, activation='relu')(pin_vec)
    sku_part = Dense(64, activation='relu')(sku_vec)

    
    prod = Multiply()([pin_part, sku_part])
    prob = Dense(1, activation='sigmoid')(prod)

    model = Model(inputs = [pin_vec, sku_vec], outputs = prob)

    model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
    
    model.__setattr__("user_input", pin_vec)
    model.__setattr__("item_input", sku_vec)
    model.__setattr__("user_embedding", pin_part)
    model.__setattr__("item_embedding", sku_part)

    return model

 

其中:“用戶”端的 Embedding 和 “商品”端的 Embedding 向量維度均爲128維。(輸入的Embedding向量是已經預訓練完畢的Embedding。例如通過word2vec模型對用戶行爲建模,即可得到“商品”端的 Embedding;然後通過 avg(用戶產生行爲的商品的Embedding),即可得到“用戶”端的 Embedding)

 

查看一下模型的summary信息。

model = build_model()
print(model.summary())

 

所構造的DSSM模型結構如下所示。由於未對用戶和商品的ID進行Embedding操作,所以該模型的參數較少。

 

打印一下模型的結構。

plot_model(model, to_file='finetune_dssm_model.png')

 

(3)加載數據

考慮到數據量較大,所以採用 generator 模式對數據進行處理,防止加載全部數據,撐爆內存。

def file_generator(input_path, batch_size = None):

    while True:
        with open(input_path, 'r') as f:

            pin_vec_array, sku_vec_array, y_array = [], [], []

            cnt = 0 
            for line in f:
                buf = line[:-1].split(',')

                pin_vec = np.array(buf[1:129], dtype=np.float32)
                sku_vec = np.array(buf[129:], dtype=np.float32)
                y = int(buf[0])

                pin_vec_array.append(pin_vec)
                sku_vec_array.append(sku_vec)
                y_array.append(y)
    
                cnt += 1

                if cnt % batch_size == 0:
                    pin_vec_array = np.array(pin_vec_array)
                    sku_vec_array = np.array(sku_vec_array)
                    y_array = np.array(y_array)
                    
                    yield [pin_vec_array, sku_vec_array], y_array

                    cnt = 0
                    pin_vec_array, sku_vec_array, y_array = [], [], []

 

本文使用小數據量進行試驗,數據格式如下:

1,0.111400,0.298000,0.520000,-2.107100,-0.658500,-0.060500,-0.755700,-0.317100,0.786800,-0.051100,-0.514300,-0.772700,0.947900,0.045500,-0.146600,0.670900,0.739700,0.715800,0.519000,1.733300,-0.567100,0.475800,0.392100,0.386000,0.038900,-0.267600,-0.597700,0.365000,-1.514600,0.362100,-0.316900,0.873700,-0.208400,-0.079500,-0.401500,-0.040200,-0.545500,0.001900,0.018300,0.836700,-0.154500,-0.114000,0.648800,-0.949100,-0.074600,0.075200,0.846000,-0.234500,0.590100,-1.521400,0.374400,-0.194700,-0.309800,1.297600,0.329300,-1.250700,0.958500,-0.247100,0.083100,-1.150500,-0.535000,0.112800,-1.356800,0.879200,-0.353400,0.034500,0.241300,-0.205700,0.670600,0.633200,-0.368100,-0.754100,-0.153500,-0.475300,0.347100,0.370000,-0.380000,-0.739700,0.471700,-0.177900,0.308500,-0.058100,1.279900,0.776900,-0.088300,-1.248500,-0.973700,-0.211500,-0.210300,0.631500,-0.652400,0.866200,0.464500,-0.682000,-0.627600,-0.598000,-0.119200,0.473700,0.381500,0.567900,0.003600,-0.514900,0.536100,-0.803500,-0.619500,-0.141500,0.010400,1.268600,0.406200,-0.632000,0.250500,-0.218300,-0.168800,0.015000,-1.186700,-0.683500,1.632600,0.430000,-0.098000,0.436500,-0.068900,0.601700,0.006100,0.540800,-0.227800,-1.126100,1.165200,-0.220900,-0.202962,-3.636311,-0.504060,-2.546363,-1.235034,-0.883959,-0.348022,-0.219954,0.907031,-1.482731,0.669218,-0.477431,4.881980,3.885695,0.578319,1.427294,2.173270,-2.765083,0.004624,1.796896,1.087227,0.389897,0.604141,-1.155123,1.274209,-2.239976,-1.858146,3.090227,-0.206842,2.549677,2.601414,-0.692583,0.388238,-0.117103,-2.207036,-3.230492,-3.375904,-1.553133,2.262967,-2.091266,-0.825930,-2.791187,2.190521,-0.433236,-0.217687,-2.277860,-0.432154,-1.141102,-0.850199,-3.686642,2.615366,0.076896,-1.115686,1.734991,-1.578039,1.183485,0.641641,-2.347620,1.625458,-1.123846,1.017014,2.852135,-0.979481,0.912863,0.727238,-0.418464,-0.958715,-0.861919,0.282138,1.843323,0.175354,-1.792245,-1.370620,1.089480,0.778957,-2.377766,0.829453,-2.713742,-3.567303,-1.208078,1.233118,1.125459,4.193498,-2.459454,0.897581,1.001604,0.674028,-1.428830,-0.025545,1.150639,-3.673055,-0.666604,0.064266,0.285329,-1.370663,-0.463825,-0.842921,0.618591,1.990929,0.457696,-2.935576,0.301109,3.309814,-2.633363,-1.209220,-0.564443,-0.663638,1.399326,1.430363,-1.934421,-2.455737,-1.447479,0.263726,-0.861657,0.584651,-2.341039,3.445074,1.608032,0.724370,-0.370727,-2.025292,-0.842234,0.977376,3.447604,2.289111,2.478286,0.241298,-1.674832
0,-0.804500,0.572300,-0.357900,0.472200,1.037200,0.266700,-0.023200,0.858800,-0.484500,-0.782800,0.480700,0.119000,-0.293300,-0.504600,0.374600,-0.039300,0.935600,-1.255600,-0.258700,-0.582000,-1.719200,0.307800,0.052900,0.381800,0.577100,-0.998900,0.060600,0.373900,-0.281600,0.024100,-0.332200,0.038900,0.136100,-0.002500,0.724800,0.038700,-0.148800,1.535200,-0.059800,0.322100,-0.811600,0.363400,-1.402800,0.158200,-0.507700,-0.108200,-0.051600,-0.286800,-0.345700,-0.152300,-0.201400,-0.494600,-0.716300,0.541900,-1.629700,-0.287000,-1.277400,1.244700,0.011400,0.549900,0.883000,-1.100400,-0.700300,-0.079900,-1.227600,0.047900,-0.769000,0.821900,0.783400,0.173500,0.697400,0.499200,0.602800,0.548200,-0.256100,-0.751800,1.143400,0.295100,-0.123700,-0.503200,-0.160300,-0.908800,-0.056600,0.107600,0.436000,0.679800,0.313100,-0.249200,0.779700,0.801200,-1.650800,0.089900,0.026200,-0.338600,-0.115900,0.495700,0.088600,0.526900,0.595000,0.156700,-0.736900,0.558100,-0.095900,0.072100,-0.209400,-0.999600,-0.567300,-0.017400,-0.232500,-0.538800,-0.041200,1.247400,-0.610300,0.085700,0.321900,0.478900,-0.274800,0.074000,-0.387400,-0.306000,0.204200,0.978300,-0.738800,0.267800,0.299300,0.989500,-0.597800,-0.211500,0.302525,0.926751,0.444355,2.095530,0.641599,0.585963,-0.007165,-0.225599,1.195284,0.743535,-0.283189,0.421811,-0.900632,-1.775821,0.194162,-0.131157,2.221316,-0.871263,0.611026,1.586028,0.208971,1.728807,-1.214678,-0.006417,-0.487578,-1.347446,1.257976,-1.105078,-0.641283,2.040870,-1.064334,1.848631,0.021456,1.044769,1.046561,-0.382474,0.511813,1.991464,1.541210,1.197348,-0.132546,-1.227524,-1.825696,0.637844,0.266854,0.627479,-1.939037,1.784560,-1.572687,1.319858,-0.297955,-0.648528,1.552862,-0.390313,-1.862317,-1.434988,1.003443,2.372627,0.048504,-1.178071,0.345171,-0.493632,0.708266,0.439852,1.367206,0.587270,-1.676261,1.519096,2.178505,0.398875,-0.987587,-1.099164,2.224100,-0.032785,-1.974257,-2.476301,1.279583,0.368386,0.118637,-0.390930,0.206159,-1.526931,-0.706359,-0.666684,1.660718,2.577286,2.185187,-0.082288,1.171966,-0.962591,-1.345657,3.024471,0.326179,-1.740565,0.338833,2.163889,-1.306316,0.962814,2.811996,0.795088,0.042636,-1.563679,0.169866,-0.691936,0.281116,-0.114342,-0.654810,-0.018624,-1.712857,-1.027673,0.120613,1.324406,-0.825408,0.978356,-0.286835,1.155605,-0.480432,-0.661304,0.434739,0.736817,-1.921379,1.111957,0.592577,-0.935139,-0.926583,2.585314,-0.798262,-0.515275

 

解釋:第一個數據爲label,1表示正樣本,0表示負樣本;第2列到第129列表示用戶的Embedding數據;第130列到第257列表示商品的Embedding數據;

 

3. 訓練DSSM模型

 

接下來開始訓練DSSM模型。

def train_finetune_dssm(train_path, val_path, model_path, \
    n_train = None, \
    n_val = None):

    model = build_model()

    print("train samples numbers: %s" % n_train)
    print("val samples numbers: %s" % n_val)
    batch_size = 128
    epochs = 2
    
    train_steps_per_epoch = int(n_train / batch_size)
    val_steps_per_epoch = int(n_val / batch_size)
    
    train_generator = file_generator(train_path, batch_size = batch_size)
    val_generator = file_generator(val_path, batch_size = batch_size)

    early_stopping_cb = EarlyStopping(monitor = 'val_loss', patience = 10, restore_best_weights = True) 
    tensorboard_cb = TensorBoard(\
        log_dir = './logs', \
        histogram_freq = 0, \
        write_graph = True, \
        write_grads = True, \
        write_images = True)
        
    
    callbacks = [early_stopping_cb, tensorboard_cb]
    start = time.time()

    history = model.fit_generator(\
        train_generator, \
        steps_per_epoch = train_steps_per_epoch, \
        epochs = epochs, \
        verbose = 1, \
        callbacks = callbacks, \
        validation_data = val_generator, \
        validation_steps = val_steps_per_epoch, \
        max_queue_size = 10, \
        workers = 1, \
        use_multiprocessing = False, \
        shuffle = True, \
        initial_epoch = 0)

    model.save_weights(model_path)

    last = time.time() - start
    print("Train model to %s done! Lasts %.2fs" % (model_path, last))

 

if __name__ == "__main__":
    train_path = "data/train_data"
    val_path = "data/val_data"
    model_path = "data/finetune_dssm.model"
    train_val_summary_path = "data/train_val_summary"

    n_train = 0
    n_val = 0
    fr = open(train_val_summary_path, 'r')
    for line in fr:
        buf = line[:-1].split(',')
        n_train = int(buf[0].split('=')[1])
        n_val = int(buf[1].split('=')[1])
        break
    fr.close()

    train_finetune_dssm(train_path, val_path, model_path, \
        n_train = n_train, \
        n_val = n_val)

 

其中:data/train_data 爲訓練集數據;data/val_data 爲驗證集數據;data/finetune_dssm.model 爲最後訓練完成後的模型;data/train_val_summary 爲訓練集和驗證集數據信息;

模型訓練過程如下圖所示:

 

4. 生成最終的用戶Embedding和商品Embedding

該模型產生的最終用戶Embedding和商品Embedding分別對應 “模型結構圖” 中的 dense_3 和 dense_4。

test_user_vec_embedding = np.array([0.1114, 0.298, 0.52, -2.1071, -0.6585, -0.0605, -0.7557, -0.3171, 0.7868, -0.0511, -0.5143, -0.7727, 0.9479, 0.0455, -0.1466, 0.6709, 0.7397, 0.7158, 0.519, 1.7333, -0.5671, 0.4758, 0.3921, 0.386, 0.0389, -0.2676, -0.5977, 0.365, -1.5146, 0.3621, -0.3169, 0.8737, -0.2084, -0.0795, -0.4015, -0.0402, -0.5455, 0.0019, 0.0183, 0.8367, -0.1545, -0.114, 0.6488, -0.9491, -0.0746, 0.0752, 0.846, -0.2345, 0.5901, -1.5214, 0.3744, -0.1947, -0.3098, 1.2976, 0.3293, -1.2507, 0.9585, -0.2471, 0.0831, -1.1505, -0.535, 0.1128, -1.3568, 0.8792, -0.3534, 0.0345, 0.2413, -0.2057, 0.6706, 0.6332, -0.3681, -0.7541, -0.1535, -0.4753, 0.3471, 0.37, -0.38, -0.7397, 0.4717, -0.1779, 0.3085, -0.0581, 1.2799, 0.7769, -0.0883, -1.2485, -0.9737, -0.2115, -0.2103, 0.6315, -0.6524, 0.8662, 0.4645, -0.682, -0.6276, -0.598, -0.1192, 0.4737, 0.3815, 0.5679, 0.0036, -0.5149, 0.5361, -0.8035, -0.6195, -0.1415, 0.0104, 1.2686, 0.4062, -0.632, 0.2505, -0.2183, -0.1688, 0.015, -1.1867, -0.6835, 1.6326, 0.43, -0.098, 0.4365, -0.0689, 0.6017, 0.0061, 0.5408, -0.2278, -1.1261, 1.1652, -0.2209]).reshape(1, -1)
test_item_vec_embedding = np.array([-0.202962, -3.636311, -0.50406, -2.546363, -1.235034, -0.883959, -0.348022, -0.219954, 0.907031, -1.482731, 0.669218, -0.477431, 4.88198, 3.885695, 0.578319, 1.427294, 2.17327, -2.765083, 0.004624, 1.796896, 1.087227, 0.389897, 0.604141, -1.155123, 1.274209, -2.239976, -1.858146, 3.090227, -0.206842, 2.549677, 2.601414, -0.692583, 0.388238, -0.117103, -2.207036, -3.230492, -3.375904, -1.553133, 2.262967, -2.091266, -0.82593, -2.791187, 2.190521, -0.433236, -0.217687, -2.27786, -0.432154, -1.141102, -0.850199, -3.686642, 2.615366, 0.076896, -1.115686, 1.734991, -1.578039, 1.183485, 0.641641, -2.34762, 1.625458, -1.123846, 1.017014, 2.852135, -0.979481, 0.912863, 0.727238, -0.418464, -0.958715, -0.861919, 0.282138, 1.843323, 0.175354, -1.792245, -1.37062, 1.08948, 0.778957, -2.377766, 0.829453, -2.713742, -3.567303, -1.208078, 1.233118, 1.125459, 4.193498, -2.459454, 0.897581, 1.001604, 0.674028, -1.42883, -0.025545, 1.150639, -3.673055, -0.666604, 0.064266, 0.285329, -1.370663, -0.463825, -0.842921, 0.618591, 1.990929, 0.457696, -2.935576, 0.301109, 3.309814, -2.633363, -1.20922, -0.564443, -0.663638, 1.399326, 1.430363, -1.934421, -2.455737, -1.447479, 0.263726, -0.861657, 0.584651, -2.341039, 3.445074, 1.608032, 0.72437, -0.370727, -2.025292, -0.842234, 0.977376, 3.447604, 2.289111, 2.478286, 0.241298, -1.674832]).reshape(1, -1)

user_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding)
item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)

user_emb = user_embedding_model.predict(test_user_vec_embedding, batch_size=1)
item_emb = item_embedding_model.predict(test_item_vec_embedding, batch_size=1)

print(user_emb)
print(item_emb)

可以看到新生成的用戶Embedding和商品Embedding,均爲64維。

根據某一用戶的Embedding和商品集合的Embedding數據,使用NN方式檢索用戶感興趣的商品集。可參考:https://github.com/milvus-iohttps://github.com/spotify/annoyhttps://github.com/facebookresearch/faiss

 

5. 結語

這裏強烈推薦 Milvus, Milvus 基於高度優化的 Approximate Nearest Neighbor Search (ANNS) 索引庫構建,包括 faiss、annoy、和 hnswlib 等。可以針對不同使用場景選擇不同的索引類型。還提供了 Python、Java、Go 和 C++ SDK 與 Restful API,簡單易用, 歡迎有需要的小夥伴請到 Milvus 官網與 GitHub 瞭解更多技術細節!

Milvus 官網:https://www.milvus.io/cn/

Milvus GitHub:https://github.com/milvus-io

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章