使用预训练Embedding,finetune DSSM模型

Milvus 小编:本文转载自公众号 Python 科技园,作者王多鱼。

1. 前言

DSSM模型是点击预估领域的经典召回模型,是由 “用户”端 和 “商品”端 两个塔式结构组成。“用户”端 和 “商品”端 两个子塔分别生成最终的 “用户” Embedding 和 “商品” Embedding。在线上应用时,实时生成 “用户” 端的 Embedding(因为用户的行为是动态的),在线从数据库中(例如:HBase, Redis)获取 “商品” 端的 Embedding(商品的Embedding生成后直接存储到数据库中,不需要实时生成)。然后通过NN的方式,检索出用户感兴趣的top-N商品候选集。

在训练模型时,如果某一场景的数据量较少,训练出的模型效果大概率不理想,容易造成模型不收敛的情况。最佳的解决方案:即采用预训练的方式,通过微调该场景下所构建的模型。例如:支付宝APP上的某个商品推荐位置,用户产生的点击或购买行为较少;但是在淘宝APP上用户的行为是海量的。可以通过淘宝APP上的数据训练出 “用户ID” 的 Embedding 和 “商品ID” 的 Embedding,然后使用该 Embedding 在支付宝APP上的商品推荐场景下对模型进行微调。

 

2. 构建DSSM模型

 

(1)加载模块

import sys
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Activation, Multiply, Dot
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard

from keras.utils import plot_model

 

(2)构建DSSM模型

def build_model():
    n_pin_vec = 128
    n_sku_vec = 128

    pin_vec = Input(shape=(n_pin_vec, ), dtype = 'float32')
    sku_vec = Input(shape=(n_sku_vec, ), dtype = 'float32')

    pin_part = Dense(64, activation='relu')(pin_vec)
    sku_part = Dense(64, activation='relu')(sku_vec)

    
    prod = Multiply()([pin_part, sku_part])
    prob = Dense(1, activation='sigmoid')(prod)

    model = Model(inputs = [pin_vec, sku_vec], outputs = prob)

    model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
    
    model.__setattr__("user_input", pin_vec)
    model.__setattr__("item_input", sku_vec)
    model.__setattr__("user_embedding", pin_part)
    model.__setattr__("item_embedding", sku_part)

    return model

 

其中:“用户”端的 Embedding 和 “商品”端的 Embedding 向量维度均为128维。(输入的Embedding向量是已经预训练完毕的Embedding。例如通过word2vec模型对用户行为建模,即可得到“商品”端的 Embedding;然后通过 avg(用户产生行为的商品的Embedding),即可得到“用户”端的 Embedding)

 

查看一下模型的summary信息。

model = build_model()
print(model.summary())

 

所构造的DSSM模型结构如下所示。由于未对用户和商品的ID进行Embedding操作,所以该模型的参数较少。

 

打印一下模型的结构。

plot_model(model, to_file='finetune_dssm_model.png')

 

(3)加载数据

考虑到数据量较大,所以采用 generator 模式对数据进行处理,防止加载全部数据,撑爆内存。

def file_generator(input_path, batch_size = None):

    while True:
        with open(input_path, 'r') as f:

            pin_vec_array, sku_vec_array, y_array = [], [], []

            cnt = 0 
            for line in f:
                buf = line[:-1].split(',')

                pin_vec = np.array(buf[1:129], dtype=np.float32)
                sku_vec = np.array(buf[129:], dtype=np.float32)
                y = int(buf[0])

                pin_vec_array.append(pin_vec)
                sku_vec_array.append(sku_vec)
                y_array.append(y)
    
                cnt += 1

                if cnt % batch_size == 0:
                    pin_vec_array = np.array(pin_vec_array)
                    sku_vec_array = np.array(sku_vec_array)
                    y_array = np.array(y_array)
                    
                    yield [pin_vec_array, sku_vec_array], y_array

                    cnt = 0
                    pin_vec_array, sku_vec_array, y_array = [], [], []

 

本文使用小数据量进行试验,数据格式如下:

1,0.111400,0.298000,0.520000,-2.107100,-0.658500,-0.060500,-0.755700,-0.317100,0.786800,-0.051100,-0.514300,-0.772700,0.947900,0.045500,-0.146600,0.670900,0.739700,0.715800,0.519000,1.733300,-0.567100,0.475800,0.392100,0.386000,0.038900,-0.267600,-0.597700,0.365000,-1.514600,0.362100,-0.316900,0.873700,-0.208400,-0.079500,-0.401500,-0.040200,-0.545500,0.001900,0.018300,0.836700,-0.154500,-0.114000,0.648800,-0.949100,-0.074600,0.075200,0.846000,-0.234500,0.590100,-1.521400,0.374400,-0.194700,-0.309800,1.297600,0.329300,-1.250700,0.958500,-0.247100,0.083100,-1.150500,-0.535000,0.112800,-1.356800,0.879200,-0.353400,0.034500,0.241300,-0.205700,0.670600,0.633200,-0.368100,-0.754100,-0.153500,-0.475300,0.347100,0.370000,-0.380000,-0.739700,0.471700,-0.177900,0.308500,-0.058100,1.279900,0.776900,-0.088300,-1.248500,-0.973700,-0.211500,-0.210300,0.631500,-0.652400,0.866200,0.464500,-0.682000,-0.627600,-0.598000,-0.119200,0.473700,0.381500,0.567900,0.003600,-0.514900,0.536100,-0.803500,-0.619500,-0.141500,0.010400,1.268600,0.406200,-0.632000,0.250500,-0.218300,-0.168800,0.015000,-1.186700,-0.683500,1.632600,0.430000,-0.098000,0.436500,-0.068900,0.601700,0.006100,0.540800,-0.227800,-1.126100,1.165200,-0.220900,-0.202962,-3.636311,-0.504060,-2.546363,-1.235034,-0.883959,-0.348022,-0.219954,0.907031,-1.482731,0.669218,-0.477431,4.881980,3.885695,0.578319,1.427294,2.173270,-2.765083,0.004624,1.796896,1.087227,0.389897,0.604141,-1.155123,1.274209,-2.239976,-1.858146,3.090227,-0.206842,2.549677,2.601414,-0.692583,0.388238,-0.117103,-2.207036,-3.230492,-3.375904,-1.553133,2.262967,-2.091266,-0.825930,-2.791187,2.190521,-0.433236,-0.217687,-2.277860,-0.432154,-1.141102,-0.850199,-3.686642,2.615366,0.076896,-1.115686,1.734991,-1.578039,1.183485,0.641641,-2.347620,1.625458,-1.123846,1.017014,2.852135,-0.979481,0.912863,0.727238,-0.418464,-0.958715,-0.861919,0.282138,1.843323,0.175354,-1.792245,-1.370620,1.089480,0.778957,-2.377766,0.829453,-2.713742,-3.567303,-1.208078,1.233118,1.125459,4.193498,-2.459454,0.897581,1.001604,0.674028,-1.428830,-0.025545,1.150639,-3.673055,-0.666604,0.064266,0.285329,-1.370663,-0.463825,-0.842921,0.618591,1.990929,0.457696,-2.935576,0.301109,3.309814,-2.633363,-1.209220,-0.564443,-0.663638,1.399326,1.430363,-1.934421,-2.455737,-1.447479,0.263726,-0.861657,0.584651,-2.341039,3.445074,1.608032,0.724370,-0.370727,-2.025292,-0.842234,0.977376,3.447604,2.289111,2.478286,0.241298,-1.674832
0,-0.804500,0.572300,-0.357900,0.472200,1.037200,0.266700,-0.023200,0.858800,-0.484500,-0.782800,0.480700,0.119000,-0.293300,-0.504600,0.374600,-0.039300,0.935600,-1.255600,-0.258700,-0.582000,-1.719200,0.307800,0.052900,0.381800,0.577100,-0.998900,0.060600,0.373900,-0.281600,0.024100,-0.332200,0.038900,0.136100,-0.002500,0.724800,0.038700,-0.148800,1.535200,-0.059800,0.322100,-0.811600,0.363400,-1.402800,0.158200,-0.507700,-0.108200,-0.051600,-0.286800,-0.345700,-0.152300,-0.201400,-0.494600,-0.716300,0.541900,-1.629700,-0.287000,-1.277400,1.244700,0.011400,0.549900,0.883000,-1.100400,-0.700300,-0.079900,-1.227600,0.047900,-0.769000,0.821900,0.783400,0.173500,0.697400,0.499200,0.602800,0.548200,-0.256100,-0.751800,1.143400,0.295100,-0.123700,-0.503200,-0.160300,-0.908800,-0.056600,0.107600,0.436000,0.679800,0.313100,-0.249200,0.779700,0.801200,-1.650800,0.089900,0.026200,-0.338600,-0.115900,0.495700,0.088600,0.526900,0.595000,0.156700,-0.736900,0.558100,-0.095900,0.072100,-0.209400,-0.999600,-0.567300,-0.017400,-0.232500,-0.538800,-0.041200,1.247400,-0.610300,0.085700,0.321900,0.478900,-0.274800,0.074000,-0.387400,-0.306000,0.204200,0.978300,-0.738800,0.267800,0.299300,0.989500,-0.597800,-0.211500,0.302525,0.926751,0.444355,2.095530,0.641599,0.585963,-0.007165,-0.225599,1.195284,0.743535,-0.283189,0.421811,-0.900632,-1.775821,0.194162,-0.131157,2.221316,-0.871263,0.611026,1.586028,0.208971,1.728807,-1.214678,-0.006417,-0.487578,-1.347446,1.257976,-1.105078,-0.641283,2.040870,-1.064334,1.848631,0.021456,1.044769,1.046561,-0.382474,0.511813,1.991464,1.541210,1.197348,-0.132546,-1.227524,-1.825696,0.637844,0.266854,0.627479,-1.939037,1.784560,-1.572687,1.319858,-0.297955,-0.648528,1.552862,-0.390313,-1.862317,-1.434988,1.003443,2.372627,0.048504,-1.178071,0.345171,-0.493632,0.708266,0.439852,1.367206,0.587270,-1.676261,1.519096,2.178505,0.398875,-0.987587,-1.099164,2.224100,-0.032785,-1.974257,-2.476301,1.279583,0.368386,0.118637,-0.390930,0.206159,-1.526931,-0.706359,-0.666684,1.660718,2.577286,2.185187,-0.082288,1.171966,-0.962591,-1.345657,3.024471,0.326179,-1.740565,0.338833,2.163889,-1.306316,0.962814,2.811996,0.795088,0.042636,-1.563679,0.169866,-0.691936,0.281116,-0.114342,-0.654810,-0.018624,-1.712857,-1.027673,0.120613,1.324406,-0.825408,0.978356,-0.286835,1.155605,-0.480432,-0.661304,0.434739,0.736817,-1.921379,1.111957,0.592577,-0.935139,-0.926583,2.585314,-0.798262,-0.515275

 

解释:第一个数据为label,1表示正样本,0表示负样本;第2列到第129列表示用户的Embedding数据;第130列到第257列表示商品的Embedding数据;

 

3. 训练DSSM模型

 

接下来开始训练DSSM模型。

def train_finetune_dssm(train_path, val_path, model_path, \
    n_train = None, \
    n_val = None):

    model = build_model()

    print("train samples numbers: %s" % n_train)
    print("val samples numbers: %s" % n_val)
    batch_size = 128
    epochs = 2
    
    train_steps_per_epoch = int(n_train / batch_size)
    val_steps_per_epoch = int(n_val / batch_size)
    
    train_generator = file_generator(train_path, batch_size = batch_size)
    val_generator = file_generator(val_path, batch_size = batch_size)

    early_stopping_cb = EarlyStopping(monitor = 'val_loss', patience = 10, restore_best_weights = True) 
    tensorboard_cb = TensorBoard(\
        log_dir = './logs', \
        histogram_freq = 0, \
        write_graph = True, \
        write_grads = True, \
        write_images = True)
        
    
    callbacks = [early_stopping_cb, tensorboard_cb]
    start = time.time()

    history = model.fit_generator(\
        train_generator, \
        steps_per_epoch = train_steps_per_epoch, \
        epochs = epochs, \
        verbose = 1, \
        callbacks = callbacks, \
        validation_data = val_generator, \
        validation_steps = val_steps_per_epoch, \
        max_queue_size = 10, \
        workers = 1, \
        use_multiprocessing = False, \
        shuffle = True, \
        initial_epoch = 0)

    model.save_weights(model_path)

    last = time.time() - start
    print("Train model to %s done! Lasts %.2fs" % (model_path, last))

 

if __name__ == "__main__":
    train_path = "data/train_data"
    val_path = "data/val_data"
    model_path = "data/finetune_dssm.model"
    train_val_summary_path = "data/train_val_summary"

    n_train = 0
    n_val = 0
    fr = open(train_val_summary_path, 'r')
    for line in fr:
        buf = line[:-1].split(',')
        n_train = int(buf[0].split('=')[1])
        n_val = int(buf[1].split('=')[1])
        break
    fr.close()

    train_finetune_dssm(train_path, val_path, model_path, \
        n_train = n_train, \
        n_val = n_val)

 

其中:data/train_data 为训练集数据;data/val_data 为验证集数据;data/finetune_dssm.model 为最后训练完成后的模型;data/train_val_summary 为训练集和验证集数据信息;

模型训练过程如下图所示:

 

4. 生成最终的用户Embedding和商品Embedding

该模型产生的最终用户Embedding和商品Embedding分别对应 “模型结构图” 中的 dense_3 和 dense_4。

test_user_vec_embedding = np.array([0.1114, 0.298, 0.52, -2.1071, -0.6585, -0.0605, -0.7557, -0.3171, 0.7868, -0.0511, -0.5143, -0.7727, 0.9479, 0.0455, -0.1466, 0.6709, 0.7397, 0.7158, 0.519, 1.7333, -0.5671, 0.4758, 0.3921, 0.386, 0.0389, -0.2676, -0.5977, 0.365, -1.5146, 0.3621, -0.3169, 0.8737, -0.2084, -0.0795, -0.4015, -0.0402, -0.5455, 0.0019, 0.0183, 0.8367, -0.1545, -0.114, 0.6488, -0.9491, -0.0746, 0.0752, 0.846, -0.2345, 0.5901, -1.5214, 0.3744, -0.1947, -0.3098, 1.2976, 0.3293, -1.2507, 0.9585, -0.2471, 0.0831, -1.1505, -0.535, 0.1128, -1.3568, 0.8792, -0.3534, 0.0345, 0.2413, -0.2057, 0.6706, 0.6332, -0.3681, -0.7541, -0.1535, -0.4753, 0.3471, 0.37, -0.38, -0.7397, 0.4717, -0.1779, 0.3085, -0.0581, 1.2799, 0.7769, -0.0883, -1.2485, -0.9737, -0.2115, -0.2103, 0.6315, -0.6524, 0.8662, 0.4645, -0.682, -0.6276, -0.598, -0.1192, 0.4737, 0.3815, 0.5679, 0.0036, -0.5149, 0.5361, -0.8035, -0.6195, -0.1415, 0.0104, 1.2686, 0.4062, -0.632, 0.2505, -0.2183, -0.1688, 0.015, -1.1867, -0.6835, 1.6326, 0.43, -0.098, 0.4365, -0.0689, 0.6017, 0.0061, 0.5408, -0.2278, -1.1261, 1.1652, -0.2209]).reshape(1, -1)
test_item_vec_embedding = np.array([-0.202962, -3.636311, -0.50406, -2.546363, -1.235034, -0.883959, -0.348022, -0.219954, 0.907031, -1.482731, 0.669218, -0.477431, 4.88198, 3.885695, 0.578319, 1.427294, 2.17327, -2.765083, 0.004624, 1.796896, 1.087227, 0.389897, 0.604141, -1.155123, 1.274209, -2.239976, -1.858146, 3.090227, -0.206842, 2.549677, 2.601414, -0.692583, 0.388238, -0.117103, -2.207036, -3.230492, -3.375904, -1.553133, 2.262967, -2.091266, -0.82593, -2.791187, 2.190521, -0.433236, -0.217687, -2.27786, -0.432154, -1.141102, -0.850199, -3.686642, 2.615366, 0.076896, -1.115686, 1.734991, -1.578039, 1.183485, 0.641641, -2.34762, 1.625458, -1.123846, 1.017014, 2.852135, -0.979481, 0.912863, 0.727238, -0.418464, -0.958715, -0.861919, 0.282138, 1.843323, 0.175354, -1.792245, -1.37062, 1.08948, 0.778957, -2.377766, 0.829453, -2.713742, -3.567303, -1.208078, 1.233118, 1.125459, 4.193498, -2.459454, 0.897581, 1.001604, 0.674028, -1.42883, -0.025545, 1.150639, -3.673055, -0.666604, 0.064266, 0.285329, -1.370663, -0.463825, -0.842921, 0.618591, 1.990929, 0.457696, -2.935576, 0.301109, 3.309814, -2.633363, -1.20922, -0.564443, -0.663638, 1.399326, 1.430363, -1.934421, -2.455737, -1.447479, 0.263726, -0.861657, 0.584651, -2.341039, 3.445074, 1.608032, 0.72437, -0.370727, -2.025292, -0.842234, 0.977376, 3.447604, 2.289111, 2.478286, 0.241298, -1.674832]).reshape(1, -1)

user_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding)
item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)

user_emb = user_embedding_model.predict(test_user_vec_embedding, batch_size=1)
item_emb = item_embedding_model.predict(test_item_vec_embedding, batch_size=1)

print(user_emb)
print(item_emb)

可以看到新生成的用户Embedding和商品Embedding,均为64维。

根据某一用户的Embedding和商品集合的Embedding数据,使用NN方式检索用户感兴趣的商品集。可参考:https://github.com/milvus-iohttps://github.com/spotify/annoyhttps://github.com/facebookresearch/faiss

 

5. 结语

这里强烈推荐 Milvus, Milvus 基于高度优化的 Approximate Nearest Neighbor Search (ANNS) 索引库构建,包括 faiss、annoy、和 hnswlib 等。可以针对不同使用场景选择不同的索引类型。还提供了 Python、Java、Go 和 C++ SDK 与 Restful API,简单易用, 欢迎有需要的小伙伴请到 Milvus 官网与 GitHub 了解更多技术细节!

Milvus 官网:https://www.milvus.io/cn/

Milvus GitHub:https://github.com/milvus-io

展开阅读全文

没有更多推荐了,返回首页

应支付0元
点击重新获取
扫码支付

支付成功即可阅读