CN116911956A - 基于知识蒸馏的推荐模型训练方法、装置及存储介质 - Google Patents
基于知识蒸馏的推荐模型训练方法、装置及存储介质 Download PDFInfo
- Publication number
- CN116911956A CN116911956A CN202311168646.6A CN202311168646A CN116911956A CN 116911956 A CN116911956 A CN 116911956A CN 202311168646 A CN202311168646 A CN 202311168646A CN 116911956 A CN116911956 A CN 116911956A
- Authority
- CN
- China
- Prior art keywords
- model
- loss function
- training
- sample set
- popularity
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q30/00—Commerce
- G06Q30/06—Buying, selling or leasing transactions
- G06Q30/0601—Electronic shopping [e-shopping]
- G06Q30/0631—Recommending goods or services
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/042—Knowledge-based neural networks; Logical representations of neural networks
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/09—Supervised learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Business, Economics & Management (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- Accounting & Taxation (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Finance (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Marketing (AREA)
- General Business, Economics & Management (AREA)
- Evolutionary Biology (AREA)
- Strategic Management (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Economics (AREA)
- Development Economics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请提供一种基于知识蒸馏的推荐模型训练方法、装置及存储介质。该方法包括:获取原始训练样本集,将原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于用户点击物品的概率将每个物品流行度分组内的物品进行排序;对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用采样后的训练样本集以及模型训练损失函数,对学生模型进行知识蒸馏训练,其中学生模型采用推荐模型。本申请能够有效降低推荐偏差,增强推荐模型的泛化能力,提升推荐模型的性能和训练效果。
Description
技术领域
本申请涉及计算机技术领域,尤其涉及一种基于知识蒸馏的推荐模型训练方法、装置及存储介质。
背景技术
推荐系统在当今生活扮演着不可或缺的作用,无论是网络购物,新闻阅读,还是视频观看,都有其身影。为了让推荐系统推的更准,首先要对物品和用户进行充分建模,通过复杂的手段将用户最有可能点击的物品优先推送给用户,以提升用户的满意度和整个系统的效率。
目前,推荐模型的整体结构一般都是输入给模型一批用户特征和商品特征,对某一个特定商品进行判别用户是否会点击,购买该商品。该判别结果被作为模型的输出结果与真实的用户点击,购买结果进行损失函数计算,从而指导模型进行优化。在工业界中,为了响应海量的在线请求,提升推荐模型在线性能,知识蒸馏的方法往往会被采用到模型训练中,兼顾了模型的精准度和线上高并发的需求。
然而,现有推荐模型采用的知识蒸馏方法,没有考虑到模型的流行度偏差问题,会导致越流行的物体得到的曝光越多,从而引起更严重的马太效应,削弱了用户个性化的体验,这对在线推荐系统的长期影响是极其负面的。因此,现有的推荐模型的知识蒸馏训练方法导致推荐模型的训练效果降低。
发明内容
有鉴于此,本申请实施例提供了一种基于知识蒸馏的推荐模型训练方法、装置及存储介质,以解决现有技术存在的推荐模型的知识蒸馏训练方法削弱了用户个性化的体验,降低推荐模型的训练效果的问题。
本申请实施例的第一方面,提供了一种基于知识蒸馏的推荐模型训练方法,包括:获取原始训练样本集,将原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于用户点击物品的概率将每个物品流行度分组内的物品进行排序;对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用采样后的训练样本集以及模型训练损失函数,对学生模型进行知识蒸馏训练,其中学生模型采用推荐模型。
本申请实施例的第二方面,提供了一种基于知识蒸馏的推荐模型训练装置,包括:分组模块,被配置为获取原始训练样本集,将原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;排序模块,被配置为获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于用户点击物品的概率将每个物品流行度分组内的物品进行排序;采样模块,被配置为对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;训练模块,被配置为基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用采样后的训练样本集以及模型训练损失函数,对学生模型进行知识蒸馏训练,其中学生模型采用推荐模型。
本申请实施例的第三方面,提供了一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述方法的步骤。
本申请实施例的第四方面,提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机程序,该计算机程序被处理器执行时实现上述方法的步骤。
本申请实施例采用的上述至少一个技术方案能够达到以下有益效果:
通过获取原始训练样本集,将原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于用户点击物品的概率将每个物品流行度分组内的物品进行排序;对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用采样后的训练样本集以及模型训练损失函数,对学生模型进行知识蒸馏训练,其中学生模型采用推荐模型。本申请有效地去除了蒸馏后学生模型中的流行度偏差,有效降低推荐偏差,提升推荐模型的性能,增强推荐模型的泛化能力,提升推荐模型的训练效果。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1是本申请实施例提供的基于知识蒸馏的推荐模型训练方法的流程示意图;
图2是本申请实施例提供的基于知识蒸馏的推荐模型训练装置的结构示意图;
图3是本申请实施例提供的电子设备的结构示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
本申请技术方案的目的是解决推荐系统中由于知识蒸馏引入过多偏差的问题。在知识蒸馏过程中,在从教师模型蒸馏给学生模型的过程中,流行度偏差会被继承甚至放大,导致线上的推荐模型效果严重偏离离线训练的结果,大打折扣,不如预期。
针对现有技术中存在的问题,本申请实施例提供了一种新的推荐模型知识蒸馏训练方法。本申请通过在训练蒸馏模型的过程中将物品按照流行度分层,并在每个层内进行无偏采样,有效地去除了蒸馏后学生模型中的流行度偏差,从而显著提升了在线推荐的性能。在传统的知识蒸馏过程中,流行度偏差会被继承并可能被放大,这导致线上推荐模型的效果与离线训练的结果有严重偏差。本申请通过将物品按照流行度进行分层和无偏采样,有效地降低了这种偏差,使得学生模型的效果更接近教师模型,从而提高了推荐的准确性和用户满意度。本申请的方法通过分层和无偏采样,使得所有流行度的物品都能得到充分的训练,这避免了模型过度关注流行物品而忽视长尾物品的问题,从而增强了模型的泛化能力。本申请通过使用流行度分层和无偏采样,使得训练样本更加均衡,从而提高了训练效率。此外,基于蒸馏损失函数和监督损失函数生成的模型训练损失函数,也有助于提高训练的效率和模型的性能。
图1是本申请实施例提供的基于知识蒸馏的推荐模型训练方法的流程示意图。图1的基于知识蒸馏的推荐模型训练方法可以由服务器执行。如图1所示,该基于知识蒸馏的推荐模型训练方法具体可以包括:
S101,获取原始训练样本集,将原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;
S102,获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于用户点击物品的概率将每个物品流行度分组内的物品进行排序;
S103,对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;
S104,基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用采样后的训练样本集以及模型训练损失函数,对学生模型进行知识蒸馏训练,其中学生模型采用推荐模型。
首先,结合本申请实施例的实际应用场景,对知识蒸馏的实现过程及原理进行说明。知识蒸馏是一种机器学习方法,该方法的主要思想是将一个大型、复杂的模型(称为教师模型)的“知识”迁移到一个更小、更简单的模型(称为学生模型)。这种方法可以在保持预测性能的同时,得到一个计算效率更高的模型。知识蒸馏的实现过程及原理主要包括以下步骤:
1. 训练教师模型:首先,需要有一个已经训练好的、性能优秀的大模型或者模型集合,这就是“教师模型”。这个模型通常会是一个大型的深度神经网络,它在训练数据集上已经达到了很高的准确率。
2. 生成软标签:在训练学生模型时,不仅会使用原始的硬标签(即原始的分类标签),还会使用教师模型的预测结果作为“软标签”。这些软标签能够提供比硬标签更丰富的信息,因为它们包含了数据点可能属于各个类别的概率。这些概率分布信息被看作是教师模型的“知识”,并被用来指导学生模型的学习。
3. 训练学生模型:最后,通过训练一个小型的模型(即学生模型),并使用教师模型的软标签以及原始硬标签作为目标。具体来说,训练的目标函数通常是硬标签的损失函数和软标签的损失函数的线性组合。通过这种方式,学生模型可以在模仿教师模型的同时,也考虑到原始的目标。
知识蒸馏的主要优点是,它可以帮助本申请实施例得到一个更小、更快、更节能的模型,同时保持良好的预测性能。在许多实际应用中,如移动设备和嵌入式系统中是非常重要的。
在一些实施例中,将原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组,包括:根据用户对原始训练样本集中每个物品的历史点击记录,确定每个物品对应的流行度,利用物品的流行度,对原始训练样本集中的物品进行分组,其中,每个物品流行度分组内的物品流行度的总和相同。
具体地,本申请首先对原始训练样本集进行处理。原始训练样本集包含了大量的用户与物品的交互记录,例如用户的点击行为、购买行为等。为了将物品按照流行度进行分组,本申请首先需要确定每个物品对应的流行度。在实际应用中,物品的流行度可以通过用户对每个物品的历史点击记录来确定。具体地,可以统计在所有用户的历史点击记录中,每个物品被点击的次数,这个次数就可以作为物品的流行度。也就是说,被点击次数更多的物品具有更高的流行度。
进一步地,在获取每个物品的流行度之后,本申请实施例就可以对原始训练样本集中的物品进行分组。具体地,可以将物品按照流行度分为K组,同时保证每组的物品流行度之和相同。这样就可以保证每组内的物品流行度大致相同,从而降低了流行度偏差带来的影响。为了更直观地展示这个分组结果,本申请还可以采用不同的颜色代表不同的流行度水平,将混杂在一起的物品按照流行度进行归类。
在一些实施例中,获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率,包括:将原始训练样本集输入到教师模型中,利用教师模型对原始训练样本集中的物品对应的用户点击物品的概率进行预测,得到每个物品对应的用户点击物品的概率。
具体地,本申请需要获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率。例如,将原始训练样本集输入到教师模型中,教师模型可以是一个深度学习模型(比如深度神经网络)。教师模型通过学习用户历史的行为数据,预测用户对每个物品的点击概率(soft label)。这些soft label不仅反映了用户可能对物品的兴趣程度,也隐含了物品的流行度信息。
进一步地,基于从教师模型获得的用户点击物品的概率,将每个物品流行度分组内的物品进行排序。这意味着,在每个物品流行度分组中,物品会根据它们的soft label进行排序。这样做的目的是为了后续的无偏采样做准备,使得每个分组中不同流行度的物品都有机会被采样到。
在实际应用中,对于教师模型的训练,可以采用常见的深度学习训练方法,例如随机梯度下降(SGD)或者Adam优化器等。此外,可以选择合适的损失函数,例如交叉熵损失函数,来衡量教师模型预测的用户点击物品概率与实际点击情况的差距,并通过反向传播来优化教师模型的参数。在实际操作中,还可以选择添加正则化项,例如L2正则化,来防止模型过拟合。
在一些实施例中,对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集,包括:选取每个排序后的物品流行度分组内用户点击物品的概率最高以及用户点击物品的概率最低的物品,将用户点击物品的概率最高的物品作为高点击概率物品,将用户点击物品的概率最低的物品作为低点击概率物品,利用高点击概率物品和低点击概率物品生成采样后的训练样本集。
具体地,本申请对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集。具体来说,本申请实施例选取每个排序后的物品流行度分组内用户点击物品的概率最高以及用户点击物品的概率最低的物品。
进一步地,这两种物品分别被称为高点击概率物品和低点击概率物品。它们分别代表了一个分组中最受欢迎和最不受欢迎的物品,反映了用户的兴趣分布和物品的流行度分布。这种采样方法可以确保本申请实施例考虑了一个分组中各种流行度的物品,这对于提高学生模型的泛化能力和推荐效果是非常重要的。在获取了高点击概率物品和低点击概率物品后,本申请实施例将它们作为输入生成采样后的训练样本集。每个训练样本包含了用户的特征、物品的特征、以及用户对物品的点击概率(soft label)。这些训练样本将被用于训练学生模型。
需要注意的是,上述采样方式这只是本申请实施例的一种可选的方式,本申请并不限于这种实现方式。在其他实施例中,可以选择使用不同的采样策略,例如均匀采样、比例采样等,只要能够保证无偏性,都在本申请的保护范围之内。
在一些实施例中,在基于蒸馏损失函数以及监督损失函数生成模型训练损失函数之前,该方法还包括:将采样后的训练样本集中的用户和物品转换为表征向量,利用采样后的训练样本集中用户对应的表征向量以及物品对应的表征向量,构造蒸馏损失函数。
具体地,本发明在基于蒸馏损失函数以及监督损失函数生成模型训练损失函数之前,还会将采样后的训练样本集中的用户和物品转换为表征向量。用户和物品的表征向量通常是由深度学习模型,如Embedding层,生成的低维度的实数向量,它们捕获了用户和物品的主要特性。
因此,利用采样后的训练样本集中用户对应的表征向量以及物品对应的表征向量,本申请实施例可以构造蒸馏损失函数。具体地,蒸馏损失函数如下所示:
;
其中,表示用户对应的表征向量,表示物品对应的表征向量。
在一些实施例中,基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,包括:将蒸馏损失函数与预设的权重相乘,得到加权后的蒸馏损失函数,将加权后的蒸馏损失函数与监督损失函数求和,得到模型训练损失函数;其中,监督损失函数采用教师模型训练时的损失函数。
具体地,本申请实施例将蒸馏损失函数和监督损失函数结合起来,生成最终的模型训练损失函数,模型训练损失函数的形式如下:
;
其中,表示模型训练损失函数,表示监督损失函数,表示蒸馏损失函数,表示超参数(即权重)。
需要说明的是,是一个超参数,用于调节蒸馏损失函数和监督损失函数的相对重要性。在实际应用中,监督损失函数采用教师模型训练时的损失函数,即将用于训练教师模型的损失函数作为监督损失函数。本申请实施例通过这种方式,将教师模型的知识和原始训练目标同时考虑进来,从而在保证模型性能的同时,也尽可能地减小了流行度偏差。
在一些实施例中,利用采样后的训练样本集以及模型训练损失函数,对学生模型进行知识蒸馏训练,包括:将采样后的训练样本集作为对学生模型进行知识蒸馏训练的输入,通过最小化模型训练损失函数,对学生模型的参数进行更新,得到训练后的学生模型。
具体地,将采样后的训练样本集作为对学生模型进行知识蒸馏训练的输入,每一个训练样本都包含了一个用户和一个物品的信息,以及对应的用户点击物品的概率。本申请实施例使用模型训练损失函数来衡量学生模型的预测结果与实际结果之间的差距。这个损失函数是由蒸馏损失函数和监督损失函数组合而成的,其中蒸馏损失函数用于度量学生模型的预测结果与教师模型的预测结果之间的差距,监督损失函数则用于度量学生模型的预测结果与实际结果之间的差距。
进一步地,本申请实施例使用一种优化算法,例如随机梯度下降(SGD)或者Adam优化器,来最小化模型训练损失函数,从而对学生模型的参数进行更新。通过多次迭代这个过程,本申请实施例就可以得到训练后的学生模型。这个学生模型继承了教师模型的知识,同时也考虑了原始训练目标,从而在保持模型性能的同时,也尽可能地减小了流行度偏差。这样本申请实施例就得到了一个既能够保持良好推荐性能,又能够较好地抵抗流行度偏差的推荐模型。
需要说明的是,本申请提出的去除偏差的方法具有很好的通用性,可以适用于各种推荐模型框架。例如,DeepFM(Deep FactorizationMachines)是一种将深度神经网络与因子分解机结合的模型,而WDL(Wide&Deep Learning)是一种将宽模型(线性模型)和深模型(深度神经网络)结合的模型。这些模型在推荐系统中都有广泛的应用。
对于模型的损失函数,本申请同样不做特定的限制。可以选择平方损失、BPR损失、交叉熵损失等任意一种适合的损失函数。例如,平方损失函数通常用于回归问题,交叉熵损失函数通常用于分类问题,而BPR损失(Bayesian PersonalizedRanking)则专门用于个性化排序问题。根据具体的推荐任务和模型,可以选择最适合的损失函数。因此,本申请的方法在设计时充分考虑了通用性和灵活性,使得它可以应用于各种不同的推荐模型框架和任务中,具有很高的实用价值。
根据本申请实施例提供的技术方案,本申请实施例提出了一个无偏的、教师无关的知识蒸馏模型,该模型能从教师模型中提取流行度感知的排序知识,从而指导学生模型的学习。这种方法不仅将教师模型的知识成功地传递给了学生模型,而且还有效地避免了流行度偏差,从而保证了模型的泛化能力和推荐效果。本申请通过使用分组的方法将流行度偏差控制在组内,这种方法平衡了流行度较少的样本训练不够充分和头部较流行的物体被过度训练的问题,从而保证了所有物品都能得到充分的训练,提高了推荐模型的泛化能力。由于本申请充分地考虑了流行度信息,并且控制了流行度偏差,因此可以显著提升推荐效果,提高用户满意度。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
图2是本申请实施例提供的基于知识蒸馏的推荐模型训练装置的结构示意图。如图2所示,该基于知识蒸馏的推荐模型训练装置包括:
分组模块201,被配置为获取原始训练样本集,将原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;
排序模块202,被配置为获取利用原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于用户点击物品的概率将每个物品流行度分组内的物品进行排序;
采样模块203,被配置为对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;
训练模块204,被配置为基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用采样后的训练样本集以及模型训练损失函数,对学生模型进行知识蒸馏训练,其中学生模型采用推荐模型。
在一些实施例中,图2的分组模块201根据用户对原始训练样本集中每个物品的历史点击记录,确定每个物品对应的流行度,利用物品的流行度,对原始训练样本集中的物品进行分组,其中,每个物品流行度分组内的物品流行度的总和相同。
在一些实施例中,图2的排序模块202将原始训练样本集输入到教师模型中,利用教师模型对原始训练样本集中的物品对应的用户点击物品的概率进行预测,得到每个物品对应的用户点击物品的概率。
在一些实施例中,图2的采样模块203选取每个排序后的物品流行度分组内用户点击物品的概率最高以及用户点击物品的概率最低的物品,将用户点击物品的概率最高的物品作为高点击概率物品,将用户点击物品的概率最低的物品作为低点击概率物品,利用高点击概率物品和低点击概率物品生成采样后的训练样本集。
在一些实施例中,图2的训练模块204在基于蒸馏损失函数以及监督损失函数生成模型训练损失函数之前,将采样后的训练样本集中的用户和物品转换为表征向量,利用采样后的训练样本集中用户对应的表征向量以及物品对应的表征向量,构造蒸馏损失函数。
在一些实施例中,图2的训练模块204将蒸馏损失函数与预设的权重相乘,得到加权后的蒸馏损失函数,将加权后的蒸馏损失函数与监督损失函数求和,得到模型训练损失函数;其中,监督损失函数采用教师模型训练时的损失函数。
在一些实施例中,图2的训练模块204将采样后的训练样本集作为对学生模型进行知识蒸馏训练的输入,通过最小化模型训练损失函数,对学生模型的参数进行更新,得到训练后的学生模型。
理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本申请实施例的实施过程构成任何限定。
图3是本申请实施例提供的电子设备3的结构示意图。如图3所示,该实施例的电子设备3包括:处理器301、存储器302以及存储在该存储器302中并且可以在处理器301上运行的计算机程序303。处理器301执行计算机程序303时实现上述各个方法实施例中的步骤。或者,处理器301执行计算机程序303时实现上述各装置实施例中各模块/单元的功能。
示例性地,计算机程序303可以被分割成一个或多个模块/单元,一个或多个模块/单元被存储在存储器302中,并由处理器301执行,以完成本申请。一个或多个模块/单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述计算机程序303在电子设备3中的执行过程。
电子设备3可以是桌上型计算机、笔记本、掌上电脑及云端服务器等电子设备。电子设备3可以包括但不仅限于处理器301和存储器302。本领域技术人员可以理解,图3仅仅是电子设备3的示例,并不构成对电子设备3的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如,电子设备还可以包括输入输出设备、网络接入设备、总线等。
处理器301可以是中央处理单元(Central Processing Unit,CPU),也可以是其它通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application SpecificIntegrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其它可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器302可以是电子设备3的内部存储单元,例如,电子设备3的硬盘或内存。存储器302也可以是电子设备3的外部存储设备,例如,电子设备3上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。进一步地,存储器302还可以既包括电子设备3的内部存储单元也包括外部存储设备。存储器302用于存储计算机程序以及电子设备所需的其它程序和数据。存储器302还可以用于暂时地存储已经输出或者将要输出的数据。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的实施例中,应该理解到,所揭露的装置/计算机设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/计算机设备实施例仅仅是示意性的,例如,模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,计算机程序可以存储在计算机可读存储介质中,该计算机程序在被处理器执行时,可以实现上述各个方法实施例的步骤。计算机程序可以包括计算机程序代码,计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。计算机可读介质可以包括:能够携带计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、电载波信号、电信信号以及软件分发介质等。需要说明的是,计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如,在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。
以上实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。
Claims (10)
1.一种基于知识蒸馏的推荐模型训练方法,其特征在于,包括:
获取原始训练样本集,将所述原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;
获取利用所述原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于所述用户点击物品的概率将每个所述物品流行度分组内的物品进行排序;
对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;
基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用所述采样后的训练样本集以及所述模型训练损失函数,对学生模型进行知识蒸馏训练,其中所述学生模型采用推荐模型。
2.根据权利要求1所述的方法,其特征在于,所述将所述原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组,包括:
根据用户对所述原始训练样本集中每个所述物品的历史点击记录,确定每个所述物品对应的流行度,利用所述物品的流行度,对所述原始训练样本集中的物品进行分组,其中,每个所述物品流行度分组内的物品流行度的总和相同。
3.根据权利要求1所述的方法,其特征在于,所述获取利用所述原始训练样本集对教师模型进行训练得到的用户点击物品的概率,包括:
将所述原始训练样本集输入到所述教师模型中,利用所述教师模型对所述原始训练样本集中的所述物品对应的用户点击物品的概率进行预测,得到每个所述物品对应的用户点击物品的概率。
4.根据权利要求1所述的方法,其特征在于,所述对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集,包括:
选取每个所述排序后的物品流行度分组内所述用户点击物品的概率最高以及所述用户点击物品的概率最低的物品,将所述用户点击物品的概率最高的物品作为高点击概率物品,将所述用户点击物品的概率最低的物品作为低点击概率物品,利用所述高点击概率物品和所述低点击概率物品生成所述采样后的训练样本集。
5.根据权利要求1所述的方法,其特征在于,在所述基于蒸馏损失函数以及监督损失函数生成模型训练损失函数之前,所述方法还包括:
将所述采样后的训练样本集中的用户和物品转换为表征向量,利用所述采样后的训练样本集中用户对应的表征向量以及物品对应的表征向量,构造所述蒸馏损失函数。
6.根据权利要求1所述的方法,其特征在于,所述基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,包括:
将所述蒸馏损失函数与预设的权重相乘,得到加权后的蒸馏损失函数,将所述加权后的蒸馏损失函数与所述监督损失函数求和,得到所述模型训练损失函数;其中,所述监督损失函数采用所述教师模型训练时的损失函数。
7.根据权利要求1所述的方法,其特征在于,所述利用所述采样后的训练样本集以及所述模型训练损失函数,对学生模型进行知识蒸馏训练,包括:
将所述采样后的训练样本集作为对所述学生模型进行知识蒸馏训练的输入,通过最小化所述模型训练损失函数,对所述学生模型的参数进行更新,得到训练后的学生模型。
8.一种基于知识蒸馏的推荐模型训练装置,其特征在于,包括:
分组模块,被配置为获取原始训练样本集,将所述原始训练样本集中的物品按照流行度进行分组,得到若干个物品流行度分组;
排序模块,被配置为获取利用所述原始训练样本集对教师模型进行训练得到的用户点击物品的概率,基于所述用户点击物品的概率将每个所述物品流行度分组内的物品进行排序;
采样模块,被配置为对排序后的物品流行度分组内的物品进行无偏采样,得到采样后的训练样本集;
训练模块,被配置为基于蒸馏损失函数以及监督损失函数生成模型训练损失函数,利用所述采样后的训练样本集以及所述模型训练损失函数,对学生模型进行知识蒸馏训练,其中所述学生模型采用推荐模型。
9.一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7中任一项所述的方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7中任一项所述的方法。
Priority Applications (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| CN202311168646.6A CN116911956A (zh) | 2023-09-12 | 2023-09-12 | 基于知识蒸馏的推荐模型训练方法、装置及存储介质 |
Applications Claiming Priority (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| CN202311168646.6A CN116911956A (zh) | 2023-09-12 | 2023-09-12 | 基于知识蒸馏的推荐模型训练方法、装置及存储介质 |
Publications (1)
| Publication Number | Publication Date |
|---|---|
| CN116911956A true CN116911956A (zh) | 2023-10-20 |
Family
ID=88368117
Family Applications (1)
| Application Number | Title | Priority Date | Filing Date |
|---|---|---|---|
| CN202311168646.6A Pending CN116911956A (zh) | 2023-09-12 | 2023-09-12 | 基于知识蒸馏的推荐模型训练方法、装置及存储介质 |
Country Status (1)
| Country | Link |
|---|---|
| CN (1) | CN116911956A (zh) |
Cited By (1)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| CN120146907A (zh) * | 2025-02-26 | 2025-06-13 | 杭州电子科技大学 | 一种消费者需求预测方法、设备、介质及产品 |
Citations (3)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| CN112967088A (zh) * | 2021-03-03 | 2021-06-15 | 上海数鸣人工智能科技有限公司 | 基于知识蒸馏的营销活动预测模型结构和预测方法 |
| CN115687794A (zh) * | 2022-12-29 | 2023-02-03 | 中国科学技术大学 | 用于推荐物品的学生模型训练方法、装置、设备及介质 |
| US20230162005A1 (en) * | 2020-07-24 | 2023-05-25 | Huawei Technologies Co., Ltd. | Neural network distillation method and apparatus |
-
2023
- 2023-09-12 CN CN202311168646.6A patent/CN116911956A/zh active Pending
Patent Citations (4)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| US20230162005A1 (en) * | 2020-07-24 | 2023-05-25 | Huawei Technologies Co., Ltd. | Neural network distillation method and apparatus |
| CN116249991A (zh) * | 2020-07-24 | 2023-06-09 | 华为技术有限公司 | 一种神经网络蒸馏方法以及装置 |
| CN112967088A (zh) * | 2021-03-03 | 2021-06-15 | 上海数鸣人工智能科技有限公司 | 基于知识蒸馏的营销活动预测模型结构和预测方法 |
| CN115687794A (zh) * | 2022-12-29 | 2023-02-03 | 中国科学技术大学 | 用于推荐物品的学生模型训练方法、装置、设备及介质 |
Non-Patent Citations (1)
| Title |
|---|
| CHEN GANG等: ""unbiased knowledge distillation for recommendation"", WSDM, no. 3, pages 976 - 984, XP058991649, DOI: 10.1145/3539597.3570477 * |
Cited By (1)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| CN120146907A (zh) * | 2025-02-26 | 2025-06-13 | 杭州电子科技大学 | 一种消费者需求预测方法、设备、介质及产品 |
Similar Documents
| Publication | Publication Date | Title |
|---|---|---|
| US11586880B2 (en) | System and method for multi-horizon time series forecasting with dynamic temporal context learning | |
| Chopra et al. | Introduction to machine learning with Python | |
| CN111291266A (zh) | 基于人工智能的推荐方法、装置、电子设备及存储介质 | |
| US20210303970A1 (en) | Processing data using multiple neural networks | |
| CN107220217A (zh) | 基于逻辑回归的特征系数训练方法和装置 | |
| CN107423442A (zh) | 基于用户画像行为分析的应用推荐方法及系统,储存介质及计算机设备 | |
| CN110866199A (zh) | 位置确定方法、装置、电子设备和计算机可读介质 | |
| Subramanian et al. | Ensemble-based deep learning techniques for customer churn prediction model | |
| CN114912030A (zh) | 权益模型训练方法、推荐方法及电子终端和计算机介质 | |
| Saleh | Machine Learning Fundamentals: Use Python and scikit-learn to get up and running with the hottest developments in machine learning | |
| US20230237386A1 (en) | Forecasting time-series data using ensemble learning | |
| CN112927050A (zh) | 待推荐金融产品确定方法、装置、电子设备及存储介质 | |
| CN116050516A (zh) | 基于知识蒸馏的文本处理方法及装置、设备和介质 | |
| Yuan et al. | TRiP: a transfer learning based rice disease phenotype recognition platform using SENet and microservices | |
| Rahman et al. | Deep learning modeling for potato breed recognition | |
| CN118014693A (zh) | 会员商品推送方法、装置、系统、电子设备及存储介质 | |
| Saleh | The The Machine Learning Workshop: Get ready to develop your own high-performance machine learning algorithms with scikit-learn | |
| Kanwal et al. | An attribute weight estimation using particle swarm optimization and machine learning approaches for customer churn prediction | |
| CN113138977A (zh) | 交易转化分析方法、装置、设备及存储介质 | |
| CN116911956A (zh) | 基于知识蒸馏的推荐模型训练方法、装置及存储介质 | |
| Saranya et al. | RETRACTED: FBCNN-TSA: An optimal deep learning model for banana ripening stages classification | |
| CN119379389B (zh) | 一种基于多用户商城的erp后台管理方法及系统 | |
| Ghosh et al. | Understanding machine learning | |
| US20220012151A1 (en) | Automated data linkages across datasets | |
| CN111768218B (zh) | 用于处理用户交互信息的方法和装置 |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| PB01 | Publication | ||
| PB01 | Publication | ||
| SE01 | Entry into force of request for substantive examination | ||
| SE01 | Entry into force of request for substantive examination | ||
| RJ01 | Rejection of invention patent application after publication |
Application publication date: 20231020 |
|
| RJ01 | Rejection of invention patent application after publication |