CN116259057B - 基于联盟博弈解决联邦学习中数据异质性问题的方法 - Google Patents
基于联盟博弈解决联邦学习中数据异质性问题的方法Info
- Publication number
- CN116259057B CN116259057B CN202310167065.4A CN202310167065A CN116259057B CN 116259057 B CN116259057 B CN 116259057B CN 202310167065 A CN202310167065 A CN 202310167065A CN 116259057 B CN116259057 B CN 116259057B
- Authority
- CN
- China
- Prior art keywords
- client
- alliance
- neural network
- convolutional neural
- server
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V30/00—Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
- G06V30/10—Character recognition
- G06V30/18—Extraction of features or characteristics of the image
- G06V30/1801—Detecting partial patterns, e.g. edges or contours, or configurations, e.g. loops, corners, strokes or intersections
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
- G06N5/042—Backward inferencing
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/44—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/94—Hardware or software architectures specially adapted for image or video understanding
- G06V10/95—Hardware or software architectures specially adapted for image or video understanding structured as a network, e.g. client-server architectures
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V30/00—Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
- G06V30/10—Character recognition
- G06V30/19—Recognition using electronic means
- G06V30/191—Design or setup of recognition systems or techniques; Extraction of features in feature space; Clustering techniques; Blind source separation
- G06V30/19173—Classification techniques
-
- H—ELECTRICITY
- H04—ELECTRIC COMMUNICATION TECHNIQUE
- H04L—TRANSMISSION OF DIGITAL INFORMATION, e.g. TELEGRAPHIC COMMUNICATION
- H04L67/00—Network arrangements or protocols for supporting network services or applications
- H04L67/01—Protocols
- H04L67/10—Protocols in which an application is distributed across nodes in the network
- H04L67/104—Peer-to-peer [P2P] networks
- H04L67/1059—Inter-group management mechanisms, e.g. splitting, merging or interconnection of groups
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Data Mining & Analysis (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Computer Networks & Wireless Communication (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Signal Processing (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开一种基于联盟博弈解决联邦学习中数据异质性问题的方法,其步骤为:1)构建卷积神经网络;2)生成每个客户端的训练集;3)每个客户端将自己的标签类别集合和迭代时间通过基站通信发送给服务器;4)服务器得到最优联盟分区;5)服务器为最优联盟分区中的每个联盟分组;6)对卷积神经网络利用联邦学习进行协同训练。本发明能解决数据联邦学习中数据异质性导致的本地模型和全局模型的权重散度过大的问题,提升了联邦学习模型的性能,并利用客户端计算资源的差异化改进联邦学习算法,加快了联邦学习的收敛速度,可用于解决联邦学习中数据异质性问题。
Description
技术领域
本发明属于边缘计算技术领域,更进一步其涉及数据处理技术领域中的一种基于联盟博弈解决联邦学习中数据异质性问题的方法。
背景技术
越来越多的手机和平板电脑成为许多人使用的主要计算设备,这些设备上强大的传感器(包括摄像头、麦克风和GPS)可以获得前所未有的数据。这些数据被用于机器学习训练,但是将数据传输到服务器集中训练的传统机器学习模式会带来许多问题:通信开销过大,服务器计算资源有限以及隐私安全问题。联邦学习框架被提出以解决上述问题。
在传统机器学习中,数据分布在同一个机器上,并且假设数据是从同一个分布中独立地采样的,即数据独立同分布(Independently Identically Distribution,IID)。联邦学习是一种涉及多个设备的机器学习形式,每个客户端利用自己的数据集提供本地模型,服务器利用这些本地模型参数来创建一个混合模型,其目标是使混合模型的性能比单独的任何客户端都要好。由于设备归属于某个用户、企业、场景,因此其数据分布往往是差异极其大的,即数据是非独立同分布(Non-Independent and Identically Distributed,Non-IID)的。客户端个人数据的性质导致本地模型的变化,这种变化只适用于客户端自己的数据,但不适用于其他数据集,会损害混合模型的性能。非独立同分布的数据导致各个客户端上参数的更新方向差异很大,导致最后混合形成的全局模型参数与本地模型参数相差较多。非独立同分布的数据导致联邦学习模型精度明显下降的问题就是数据异质性问题。联邦学习聚合得到的全局模型参数和本地模型通过随机梯度下降方法更新得到的参数的差异被定义为权重散度。当权重散度越大,说明全局模型和本地模型差距越大,联邦学习性能越差。
H.Brendan McMahan等人在其发表的论文“Communication-Efficient LearningOf Deep Networks From Decentralized Data”(Artificial Intelligence andStatistics 54.(2017):1273-1282.)中提出了一种基于迭代模型平均的深度网络联合学习的方法。该方法的实现步骤是:第一步,选择一定比例的客户端。第二步,计算所选客户端持有的所有数据的损失梯度。第三步,所选客户端采用梯度下降法更新本地模型参数。第四步,服务器取所选客户端本地模型的加权平均值构成全局模型。该方法在处理独立同分布数据时可取得相对较优秀的结果,但是,该方法仍然存在的不足之处是,在面对非独立同分布数据时,无法降低由数据异质性导致的本地模型参数和全局模型权重差异。
杭州电子科技大学在其申请的专利文献“一种基于二阶导数解决联邦学习中数据不平衡问题的方法”(申请号:202110917450.7,申请公布号:CN 113691594 B)中提出了基于二阶导数解决联邦学习中数据异质性问题的方法,该方法的实现步骤是:第一步,云端服务器初始化全局模型和代理数据集。第二步,当第一轮全局迭代或者本轮迭代的测试精度相比于上轮迭代的测试精度小于某个阈值时,云端服务器通过计算损失函数关于全局模型参数的二阶导数来获得全局模型。第三步,云端服务器向边缘客户端下发全局模型、全局模型参数重要性权重和全局数据失衡信息。第四步,边缘客户端根据收到的全局模型、全局模型参数重要性权重和全局数据失衡信息构建正则项,将正则项添加到预先设置的优化目标上构成新的优化目标,从而减小本地模型与全局模型之间的差异以及降低大类对全局模型的贡献,然后利用本地数据在本地执行模型训练,并将训练好的本地模型上传给云端服务器。第五步,云端服务器利用接收到的本地模型更新全局模型。第六步,云端服务器判定全局模型精度是否达到预设值,未达到则返回第二步,达到则训练结束。该方法解决了非独立同分布数据带来的局部模型和全局模型的差异化对全局模型训练的影响。但是,该方法仍然存在的不足之处是:未充分利用客户端的计算资源,模型的收敛速度较慢。
发明内容
本发明的目的在于针对上述已有技术的不足,针对边缘计算场景中的数据异质性挑战,提出了一种基于联盟博弈解决联邦学习中数据异质性问题的方法,用于解决数据联邦学习中数据异质性导致的本地模型和全局模型的权重散度过大的问题,提升了联邦学习模型的性能,并利用客户端计算资源的差异化改进联邦学习算法,加快了联邦学习的收敛速度。
实现本发明目的的思路是:本发明采用平均推土距离(average earth mover’sdistance,EMD)来衡量客户端数据的Non-IID程度。EMD是归一化的从一个分布变为另一个分布的最小代价。Non-IID程度指的是客户端的数据集中每类数据的概率分布与理想数据集中每类数据的概率分布的差异程度。因此,将EMD定义为每类数据的概率分布距离。
Non-IID数据会使得局部模型与全局模型产生权重差异,因此联邦学习模型的准确性会受到数据分布偏度的影响。本发明根据EMD进行联盟博弈,将客户端聚集成几个联盟。通过迭代的方式计算出不同客户端结成同一联盟的Non-IID程度,Non-IID程度大的联盟会被淘汰,因此博弈结束时每个联盟的联盟数据集Non-IID程度很低,使得每个联盟的局部模型与全局模型的权重差异小,提高了联盟学习模型的精度。本发明根据客户端计算资源的差异化进行分组策略,将客户端分成几个小组,每个小组的客户端计算资源相似,每个小组在一轮训练中进行本地训练的轮次是不同的,计算资源多的小组的本地训练轮次是计算资源少的小组的本地训练轮次的整倍数。因此计算资源多的客户端可以充分利用计算资源,不会白白浪费时间等待计算资源少的客户端完成本地训练,加快了联盟学习模型的收敛速度。
为实现上述目的,本发明具体实现步骤包括如下:
步骤1,构建卷积神经网络:
步骤1.1,为每个客户端搭建一个结构相同的卷积神经网络,该网络的各层串联组成,其结构依次为:输入层,第一卷积层,第一池化层,第二卷积层,第二池化层,全连接层;
步骤1.2,设置卷积神经网络的超参数:将输入层神经元的个数设置为28×28;将第一、第二卷积层的卷积核分别设置为5×5,3×3,卷积核个数均设置为64,滑动步长均设置为1;将第一、第二池化层的池化窗口尺寸均设置为2×2,滑动步长均设置为2;激活函数均采用ReLU实现;将全连接层输出神经元的个数设置为10,激活函数采用Softmax实现;
步骤2,生成每个客户端的训练集:
步骤2.1,将每个客户端自己的手写数字图片组成该客户端的样本集;对每个客户端样本集的每张手写数字图片标注上标签;
步骤2.2,对每个样本集标注后的每张图片进行均值方差归一化处理,处理后的数据符合标准正太分布,将每个样本集归一化后的图片组成该客户端的训练集;
步骤3,每个客户端将自己的标签类别集合和迭代时间通过基站通信发送给服务器;
步骤4,服务器得到最优联盟分区:
步骤4.1,将每个客户端作为一个联盟,所有联盟形成一个联盟分区;
步骤4.2,按照下式,生成联盟分区中每个联盟的效益:
其中,Vj表示联盟分区中的第j个联盟的效益,log(.)表示以10为底的对数操作,|.|表示取绝对值操作,nm表示联盟分区中的第j个联盟内所有客户端训练集标签为“m”类的手写数字图片的张数,Dj表示联盟分区中的第j个联盟内所有客户端的训练集中手写数字图片的总数;
步骤4.3,利用联盟博弈形成算法,不断地让客户端离开原有联盟,加入其他联盟,当两个联盟的效益之和增大,所有联盟形成一个新的联盟分区;
步骤4.4,当任意一个客户端加入任何联盟都无法使原有联盟和新加入的联盟的效益之和变大,不再产生新的联盟分区时,联盟博弈形成算法停止迭代;删除联盟分区中空的联盟,将剩余联盟组成最优联盟分区;
步骤5,服务器为最优联盟分区中的每个联盟分组:
步骤5.1,服务器根据每个客户端上传的迭代时间找到每个联盟中迭代时间最小的客户端,作为每个联盟的联盟领导者;
步骤5.2,服务器计算每个客户端的本地训练轮次;
步骤5.3,每个联盟内本地训练轮次相同的客户端形成一个小组;
步骤6,对卷积神经网络利用联邦学习进行协同训练:
步骤6.1,服务器向最优联盟分区中的每个客户端下发一个相同的卷积神经网络参数矩阵;
步骤6.2,每个客户端用接收到的卷积神经网络参数矩阵更新自己的卷积神经网络;
步骤6.3,每个客户端将每个训练集输入到其对应的卷积神经网络中,使用SGD梯度下降算法,计算每个客户端的卷积神经网络迭代更新10次后的卷积神经网络参数矩阵,并将该卷积神经网络参数矩阵上传到联盟领导者;
步骤6.4,每个联盟的联盟领导者收到客户端的卷积神经网络参数矩阵,对其接收的具有不同特征的卷积神经网络的参数矩阵取平均值;
步骤6.5,每个联盟的联盟领导者判断是否接收到了所有小组的客户端的卷积神经网络的参数矩阵,若是,将卷积神经网络的参数矩阵平均值发送给服务器后执行步骤6.6;否则,将卷积神经网络的参数矩阵平均值发送给客户端后执行步骤6.2;
步骤6.6,服务器对其接收的所有联盟的具有不同特征的卷积神经网络的参数矩阵取平均值,再将该平均值下发给每个联盟领导者,每个联盟领导者再将参数矩阵平均值下发给每个联盟的客户端;
步骤6.7,判断服务器是否已经执行步骤6.6中取平均值操作500次,若是,则结束协同训练,用参数矩阵平均值更新服务器的卷积神经网络,执行步骤7.1;否则,执行步骤6.2;
步骤7,预测服务器的手写数字图片的类别:
步骤7.1,采用与步骤2相同的预处理方法,对服务器的手写数字图片进行处理,得到服务器的测试集。
步骤7.2,将服务器测试集的所有图像输入到服务器的卷积神经网络中,输出预测的手写数字识别结果。
本发明与现有技术相比具有如下优点:
第一,本发明利用EMD对Non-IID程度高的客户端进行联盟博弈,降低了各个联盟的Non-IID程度,使得每个联盟的局部模型与全局模型的权重差异小,克服了现有技术中数据异质性导致局部模型与全局模型的权重差异大的缺点,使得本发明具有联盟学习模型的精度高的优点。
第二,本发明利用计算资源差异化对客户端进行分组策略,根据该分组为客户端确定本地训练轮次,充分利用了客户端的计算资源,克服了现有技术中客户端计算资源浪费的缺点,使得本发明具有联邦学习模型收敛速度快的优点。
附图说明
图1为本发明的流程图;
图2为本发明的仿真图。
具体实施方式
以下结合附图1和实施例,对本发明的实现步骤做进一步的描述。
本发明的实施例有30个客户端,一个服务器,客户端用自己的手写数字图片生成的训练集训练一个可用于识别所有手写数字的卷积神经网络,手写数字图片上只含有一个数字,训练好的卷积神经网络不光可以识别出30个训练集中的手写数字,还可以识别出30个训练集之外的手写数字图片。
卷积神经网络特别适合处理像图片、视频、音频、语言文字等数据,目前卷积神经网络是图像识别领域优势最为显著的神经网络结构。通过在每个客户端处构建一个结构相同卷积神经网络,然后将训练集中每张图片输入到卷积神经网络中学习每张图片的特征,再更新卷积神经网络的参数,每个客户端的卷积神经网络的参数是不同的。通过学习完训练集中每张图片的特征,每个客户端卷积神经网络性能较好。在服务器处将30个客户端的学习特征进行汇总,最后在服务器处构造的卷积神经网络学习了所有客户端的训练集中的手写数字图片的特征,服务器处的卷积神经网络不仅可识别30个训练集中的手写数字,还可精确识别所有手写数字图片。
步骤1,构建卷积神经网络。
步骤1.1,在每个客户端搭建一个结构相同的卷积神经网络,该网络各层串联组成,其结构依次为:输入层,第一卷积层,第一池化层,第二卷积层,第二池化层,全连接层。
步骤1.2,设置卷积神经网络的超参数:将输入层神经元的个数设置为28×28。将第一、第二卷积层的卷积核分别设置为5×5,3×3,卷积核个数均设置为64,滑动步长均设置为1。将第一、第二池化层的池化窗口尺寸均设置为2×2,滑动步长均设置为2。激活函数均采用Relu实现。全连接层输出神经元的个数设置为10,激活函数采用Softmax实现。每个客户端的每层网络的超参数相同,超参数指的是每层网络的神经元个数或网络尺寸,每个客户端的每层网络的参数不同,参数指的是每层网络的权重矩阵。
步骤2,生成每个客户端的训练集。
步骤2.1,将每个客户端自己的手写数字图片组成该客户端的样本集,本发明实施例中共组成了30个样本集。由于每个客户端不光拥有手写数字图片,还拥有别的图片,例如花卉图片、人像图片,而本发明只选取每个客户端自己的手写数字图片组成其样本集,且每个手写数字图片中只包含一个手写数字。对每个客户端样本集的每张手写数字图片标注上标签,由于本发明的样本集是由手写数字图片组成,故其图片标签分别为:0,1,2,3,4,5,6,7,8,9。例如一个客户端包含3张图片,第一张图片上只包含一个手写的数据0,所以将该图片的标签标注为0。第二张图片上只包含一个手写的数据1,所以将该图片的标签标注为1。第三张图片上只包含一个手写的数据2,所以将该图片的标签标注为2。
步骤2.2,对每个样本集标注后的每张图片进行均值方差归一化处理,处理后的数据符合标准正太分布,将每个样本集归一化后的图片组成该客户端的训练集。
步骤3,每个客户端向服务器发送标签类别集合和迭代时间。
每个客户端将自己的标签类别集合和迭代时间通过基站通信发送给服务器。标签类别集合指的是每个训练集中每类手写数字图片有多少张,迭代时间指的是每个客户端使用训练集中每个训练数据执行一遍卷积神经网络得到该训练数据的预测标签的时间。第1个客户端有300张手写数字图片,其中,标签为“0”类的手写数字图片有100张,标签为“1”类的手写数字图片有100张,标签为“5”类的手写数字图片有100张。据此组成第1个客户端的标签类别集合为{100,100,0,0,0,100,0,0,0,0}。第2个客户端有300张手写数字图片,标签为“7”类的手写数字图片有100张,标签为“8”类的手写数字图片有100张,标签为“9”类的手写数字图片有100张,据此组成第2个客户端的标签类别集合为{0,0,0,0,0,0,0,100,100,100}。
步骤4,服务器得到最优联盟分区。
步骤4.1,将每个客户端作为一个联盟,所有联盟形成一个联盟分区。本发明实施例中的30个客户端形成30个联盟,第一个联盟包含第1个客户端,第二个联盟包含第2个客户端,第30个联盟包含第30个客户端,这30个联盟形成联盟分区。
步骤4.2,按照下式,生成联盟分区中每个联盟的效益:
其中,Vj表示联盟分区中的第j个联盟的效益,log(.)表示以10为底的对数操作,|.|表示取绝对值操作,nm表示联盟分区中的第j个联盟内所有客户端训练集标签为“m”类的手写数字图片的张数,Dj表示联盟分区中的第j个联盟内所有客户端的训练集中手写数字图片的总数。本发明实施例中的第一个联盟的效益V1≈0.35;第二个联盟的效益V2≈0.35。
步骤4.3,利用联盟博弈形成算法,不断地让客户端离开原有联盟,加入其他联盟,当两个联盟的效益之和增大,所有联盟形成一个新的联盟分区。本发明实施例中的第一个联盟的效益和第二个联盟的效益之和为0.7,当第1个客户端离开第一个联盟,加入第二个联盟,第一个联盟的效益变为V1≈0,第二个联盟的效益变为V2≈0.88,第一个联盟的效益和第二个联盟的效益之和变为0.88,两个联盟的效益之和增大,形成了一个新的联盟分区,此联盟分区包含30个联盟,第一个联盟为空集,第二个联盟包含第1个客户端和第2个客户端,第30个联盟包含第30个客户端。
所述联盟博弈形成算法是根据Hui Yilong等人在其发表的论文“A GameTheoretic Scheme for Optimal Access Control in Heterogeneous VehicularNetworks”(IEEE transactions on intelligent transportation systems 2019,20(12):4590-4603.)中提出的一种联盟博弈形成算法。
步骤4.4,当任意一个客户端加入任何联盟都无法使原有联盟和新加入的联盟的效益之和变大,不再产生新的联盟分区时,联盟博弈形成算法停止迭代。联盟博弈形成算法停止时,本发明实施例中的联盟分区包含30个联盟,第二个联盟包含第1个客户端、第2个客户端、第3个客户端、第4个客户端、第5个客户端,第三个联盟包含第6个客户端、第7个客户端、第8个客户端、第9个客户端、第10个客户端、第11个客户端,第九个联盟包含第12个客户端、第13个客户端、第14个客户端、第15个客户端、第16个客户端、第17个客户端、第18个客户端、第19个客户端、第20个客户端、第21个客户端,第十个联盟包含第22个客户端、第23个客户端、第24个客户端、第25个客户端、第26个客户端、第27个客户端、第28个客户端、第29个客户端、第30个客户端,其余联盟为空集。
步骤4.5,删除联盟分区中空的联盟,将剩余联盟组成最优联盟分区。本发明实施例中的最优联盟分区包含4个联盟:第二个联盟、第三个联盟、第九个联盟、第十个联盟。
步骤5,服务器为最优联盟分区中的每个联盟分组
步骤5.1,服务器根据每个客户端上传的迭代时间找到每个联盟中迭代时间最小的客户端,作为每个联盟的联盟领导者。本发明实施例中第二个联盟的第1个客户端的迭代时间为0.2s,第2个客户端的迭代时间为0.25s,第3个客户端的迭代时间为0.3s,第4个客户端的迭代时间为0.35s,第5个客户端的迭代时间为0.6s,第二个联盟的联盟领导者是第1个客户端。
步骤5.2,根据下式,服务器计算每个客户端的本地训练轮次:
其中,μi代表第i个客户端的本地训练轮次,代表向下取整符号,代表每个联盟内迭代时间最大的客户端,代表每个联盟内迭代时间最大的客户端的迭代时间,t i代表第i个客户端的迭代时间。本发明实施例中第二个联盟的第1个客户端的本地训练轮次是30次,第2个客户端的本地训练轮次是20次,第3个客户端的本地训练轮次是20次,第4个客户端的本地训练轮次是10次,第5个客户端的本地训练轮次是10次。
步骤5.3,每个联盟内本地训练轮次相同的客户端形成一个小组。本发明实施例中第二个联盟形成3个小组,第一个小组包含第1个客户端,第二个小组包含第2个客户端和第3个客户端,第三个小组包含第4个客户端和第5个客户端。
步骤6,对卷积神经网络利用联邦学习进行协同训练。
步骤6.1,服务器向最优联盟分区中的每个客户端下发一个相同的卷积神经网络参数矩阵。
步骤6.2,每个客户端用接收到的卷积神经网络参数矩阵更新自己的卷积神经网络。
步骤6.3,每个客户端将每个训练集输入到其对应的卷积神经网络中,使用SGD梯度下降算法,计算每个客户端的卷积神经网络迭代更新10次后的卷积神经网络参数矩阵,并将该卷积神经网络参数矩阵上传到联盟领导者。
步骤6.4,每个联盟的联盟领导者收到客户端的卷积神经网络参数矩阵,对其接收的具有不同特征的卷积神经网络的参数矩阵取平均值。
步骤6.5,每个联盟的联盟领导者判断是否接收到了所有小组的客户端的卷积神经网络的参数矩阵,若是,将卷积神经网络的参数矩阵平均值发送给服务器,执行步骤6.6;反之,将卷积神经网络的参数矩阵平均值发送给步骤6.4中向自己发送了参数矩阵的客户端,执行步骤6.2。本发明实施例中,第二个联盟的联盟领导者在第2s收到第1个小组的参数矩阵,将第1个小组的参数矩阵取平均值后发送给第1个客户端,在第3s收到第2个小组的参数矩阵,将第2个小组的参数矩阵取平均值后发送给第2个客户端和第3个客户端,在第4s收到第1个小组的参数矩阵,将第1个小组的参数矩阵取平均值后发送给第1个客户端。在第6s收到第1个小组、第2个小组和第3个小组的参数矩阵,将第1个小组、第2个小组和第3个小组的参数矩阵取平均值后发送给服务器。
步骤6.6,服务器对其接收的所有联盟的具有不同特征的卷积神经网络的参数矩阵取平均值,再将该平均值下发给每个联盟领导者,每个联盟领导者再将参数矩阵平均值下发给每个联盟的客户端。
步骤6.7,判断服务器是否已经执行步骤6.6中取平均值操作500次,若是,则结束协同训练,用参数矩阵平均值更新服务器的卷积神经网络,执行步骤7.1;否则,执行步骤6.2;
步骤7,预测服务器的手写数字图片的类别:
步骤7.1,采用与步骤2相同的预处理方法,对服务器的手写数字图片进行处理,得到服务器的测试集。
步骤7.2,将服务器测试集的所有图像输入到服务器的卷积神经网络中,输出预测的手写数字识别结果。
下面结合仿真实验对本发明的效果做进一步的说明:
1.仿真实验条件:
本发明仿真实验的平台为:Windows 11操作系统和PyCharm2021。
本发明仿真实验所使用的训练集和测试集为MNIST数据集种的训练集和测试集。
2.仿真内容及其结果分析:
本发明仿真实验是采用本发明和两个现有技术(联邦平均方法和二阶导数联邦学习方法方法)分别对手写数字图片进行识别。
在仿真实验中,采用的两个现有技术是指:
现有技术联邦平均方法是指:H.Brendan McMahan等人在其发表的论文“Communication-Efficient Learning Of Deep Networks From Decentralized Data”(Artificial Intelligence and Statistics 54.(2017):1273-1282.)中提出的基于迭代模型平均的深度网络联合学习的方法,简称联邦平均方法。
现有技术二阶导数联邦学习方法是指:杭州电子科技大学在其申请的专利文献“一种基于二阶导数解决联邦学习中数据不平衡问题的方法”(申请号:202110917450.7,申请公布号:CN 113691594 B)中提出了基于二阶导数解决联邦学习中数据异质性问题的方法,简称二阶导数联邦学习方法。
下面结合图2的仿真图对本发明的效果做进一步的描述。
图2为本发明仿真实验三种方法分别对测试集的手写数字图片的识别精度的对比图。所述的识别精度指的是测试集的10000张手写数字图片被正确识别的比例。
图2中的横坐标表示服务器和客户端利用联邦学习协同训练卷积神经网络的训练轮次,纵坐标表示用训练好的卷积神经网络识别测试集手写数字类别的精度。图2中的实线代表采用本发明的方法获得不同训练轮次下的精度曲线,图2中的点实线代表采用联邦平均方法获得不同训练轮次下的精度曲线,图2中的虚线代表采用二阶导数联邦学习方法获得不同训练轮次下的精度曲线。
由图2可以看出,对任一固定的训练轮次下,本发明获得的精度相比于其他两个方法获得精度总是可以为服务器的卷积神经网络带来最高的识别精度,主要是因为联邦平均方法在面对非独立同分布数据时,无法降低由数据异质性导致的本地模型参数和全局模型权重差异;而对于二阶导数联邦学习方法,未充分利用客户端的计算资源,在固定的训练轮次下,模型的收敛速度较慢。
以上仿真实验表明:本发明方法可以通过联盟博弈降低客户端训练集之间的数据异质性,解决本地模型和全局模型的权重散度过大的问题,提升了联邦学习模型的性能,并利用客户端计算资源的差异化改进联邦学习算法,加快了联联邦学习的收敛速度,是一种非常实用的解决联邦学习中数据异质性问题方法。
Claims (4)
1.一种基于联盟博弈解决联邦学习中数据异质性问题的方法,其特征在于,利用联邦学习在服务器和客户端协同训练卷积神经网络;该方法的具体步骤包括如下:
步骤1,构建卷积神经网络:
步骤1.1,为每个客户端搭建一个结构相同的卷积神经网络,该网络的各层串联组成,其结构依次为:输入层,第一卷积层,第一池化层,第二卷积层,第二池化层,全连接层;
步骤1.2,设置卷积神经网络的超参数:将输入层神经元的个数设置为28×28;将第一、第二卷积层的卷积核分别设置为5×5,3×3,卷积核个数均设置为64,滑动步长均设置为1;将第一、第二池化层的池化窗口尺寸均设置为2×2,滑动步长均设置为2;激活函数均采用ReLU实现;将全连接层输出神经元的个数设置为10,激活函数采用Softmax实现;
步骤2,生成每个客户端的训练集:
步骤2.1,将每个客户端自己的手写数字图片组成该客户端的样本集;对每个客户端样本集的每张手写数字图片标注上标签;
步骤2.2,对每个样本集标注后的每张图片进行均值方差归一化处理,处理后的数据符合标准正太分布,将每个样本集归一化后的图片组成该客户端的训练集;
步骤3,每个客户端将自己的标签类别集合和迭代时间通过基站通信发送给服务器;
步骤4,服务器得到最优联盟分区:
步骤4.1,将每个客户端作为一个联盟,所有联盟形成一个联盟分区;
步骤4.2,计算联盟分区中每个联盟的效益;
步骤4.3,利用联盟博弈形成算法,不断地让客户端离开原有联盟,加入其他联盟,当两个联盟的效益之和增大,所有联盟形成一个新的联盟分区;
步骤4.4,当任意一个客户端加入任何联盟都无法使原有联盟和新加入的联盟的效益之和变大,不再产生新的联盟分区时,联盟博弈形成算法停止迭代;删除联盟分区中空的联盟,将剩余联盟组成最优联盟分区;
步骤5,服务器为最优联盟分区中的每个联盟分组:
步骤5.1,服务器根据每个客户端上传的迭代时间,找到每个联盟中迭代时间最短的客户端,作为每个联盟的联盟领导者;
步骤5.2,服务器计算每个客户端的本地训练轮次;
步骤5.3,将每个联盟内本地训练轮次相同的客户端组成一个小组;
步骤6,对卷积神经网络利用联邦学习进行协同训练:
步骤6.1,服务器向最优联盟分区中的每个客户端下发一个相同的卷积神经网络参数矩阵;
步骤6.2,每个客户端用接收到的卷积神经网络参数矩阵,更新自己的卷积神经网络;
步骤6.3,每个客户端将每个训练集输入到其对应的卷积神经网络中,使用SGD梯度下降算法,计算每个客户端的卷积神经网络迭代更新10次后的卷积神经网络参数矩阵,并将该卷积神经网络参数矩阵上传到联盟领导者;
步骤6.4,每个联盟的联盟领导者收到客户端的卷积神经网络参数矩阵,对其接收的具有不同特征的卷积神经网络的参数矩阵取平均值;
步骤6.5,每个联盟的联盟领导者判断是否接收到了所有小组的客户端的卷积神经网络的参数矩阵,若是,将卷积神经网络的参数矩阵平均值发送给服务器后执行步骤6.6;否则,将卷积神经网络的参数矩阵平均值发送给客户端后执行步骤6.2;
步骤6.6,服务器对其接收的所有联盟的具有不同特征的卷积神经网络的参数矩阵取平均值,再将该平均值下发给每个联盟领导者,每个联盟领导者再将参数矩阵平均值下发给每个联盟的客户端;
步骤6.7,判断服务器是否已经执行步骤6.6中取平均值操作5000次,若是,则结束协同训练,用参数矩阵平均值更新服务器的卷积神经网络,执行步骤7.1;否则,执行步骤6.2;
步骤7,预测服务器的手写数字图片的类别:
步骤7.1,采用与步骤2相同的预处理方法,对服务器的手写数字图片进行处理,得到服务器的测试集;
步骤7.2,将服务器测试集的所有图像输入到服务器的卷积神经网络中,输出预测的手写数字识别结果。
2.根据权利要求1所述的基于联盟博弈解决联邦学习中数据异质性问题的方法,其特征在于,步骤2.1中所述的标签包括:0,1,2,3,4,5,6,7,8,9。
3.根据权利要求1所述的基于联盟博弈解决联邦学习中数据异质性问题的方法,其特征在于,步骤5.2中所述计算联盟分区中每个联盟的效益是由下式得到的:
其中,Vj表示联盟分区中的第j个联盟的效益,log(.)表示以10为底的对数操作,|.|表示取绝对值操作,nm表示联盟分区中的第j个联盟内所有客户端训练集标签为“m”类的手写数字图片的张数,Dj表示联盟分区中的第j个联盟内所有客户端的训练集中手写数字图片的总数。
4.根据权利要求1所述的基于联盟博弈解决联邦学习中数据异质性问题的方法,其特征在于,步骤5.2中所述计算每个客户端的本地训练轮次是由下式得到的:
其中,μi表示第i个客户端的本地训练轮次,表示向下取整操作,表示每个联盟内迭代时间最大的客户端,表示每个联盟内迭代时间最大的客户端的迭代时间,ti表示第i个客户端的迭代时间。
Priority Applications (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| CN202310167065.4A CN116259057B (zh) | 2023-02-27 | 2023-02-27 | 基于联盟博弈解决联邦学习中数据异质性问题的方法 |
Applications Claiming Priority (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| CN202310167065.4A CN116259057B (zh) | 2023-02-27 | 2023-02-27 | 基于联盟博弈解决联邦学习中数据异质性问题的方法 |
Publications (2)
| Publication Number | Publication Date |
|---|---|
| CN116259057A CN116259057A (zh) | 2023-06-13 |
| CN116259057B true CN116259057B (zh) | 2025-10-28 |
Family
ID=86680547
Family Applications (1)
| Application Number | Title | Priority Date | Filing Date |
|---|---|---|---|
| CN202310167065.4A Active CN116259057B (zh) | 2023-02-27 | 2023-02-27 | 基于联盟博弈解决联邦学习中数据异质性问题的方法 |
Country Status (1)
| Country | Link |
|---|---|
| CN (1) | CN116259057B (zh) |
Families Citing this family (2)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| CN116582447B (zh) * | 2023-06-20 | 2025-09-26 | 杭州电子科技大学 | 一种基于边缘计算网关的IoT网络协议识别方法 |
| CN116502709A (zh) * | 2023-06-26 | 2023-07-28 | 浙江大学滨江研究院 | 一种异质性联邦学习方法和装置 |
Family Cites Families (6)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| CN111369042B (zh) * | 2020-02-27 | 2021-09-24 | 山东大学 | 一种基于加权联邦学习的无线业务流量预测方法 |
| CN112001321B (zh) * | 2020-08-25 | 2024-06-14 | 商汤国际私人有限公司 | 网络训练、行人重识别方法及装置、电子设备和存储介质 |
| US12307364B2 (en) * | 2020-09-25 | 2025-05-20 | Qualcomm Incorporated | Federated learning with varying feedback |
| CN112101946B (zh) * | 2020-11-20 | 2021-02-19 | 支付宝(杭州)信息技术有限公司 | 联合训练业务模型的方法及装置 |
| CN113988314B (zh) * | 2021-11-09 | 2024-05-31 | 长春理工大学 | 一种选择客户端的分簇联邦学习方法及系统 |
| CN115577804A (zh) * | 2022-10-22 | 2023-01-06 | 北京工业大学 | 基于合作博弈和知识蒸馏的个性化联邦学习方法 |
-
2023
- 2023-02-27 CN CN202310167065.4A patent/CN116259057B/zh active Active
Non-Patent Citations (1)
| Title |
|---|
| RCFL: Redundancy-Aware Collaborative Federated Learning in Vehicular Networks;Yilong Hui;《IEEE TRANSACTIONS ON INTELLIGENT TRANSPORTATION SYSTEMS》;20240630;第25卷(第6期);第1-15页 * |
Also Published As
| Publication number | Publication date |
|---|---|
| CN116259057A (zh) | 2023-06-13 |
Similar Documents
| Publication | Publication Date | Title |
|---|---|---|
| CN114943345B (zh) | 基于主动学习和模型压缩的联邦学习全局模型训练方法 | |
| CN114091667B (zh) | 一种面向非独立同分布数据的联邦互学习模型训练方法 | |
| CN117523291A (zh) | 基于联邦知识蒸馏和集成学习的图像分类方法 | |
| CN109948029B (zh) | 基于神经网络自适应的深度哈希图像搜索方法 | |
| WO2022121289A1 (en) | Methods and systems for mining minority-class data samples for training neural network | |
| CN113988314B (zh) | 一种选择客户端的分簇联邦学习方法及系统 | |
| CN115032682B (zh) | 一种基于图论的多站台地震震源参数估计方法 | |
| US12518168B2 (en) | Training and application method apparatus system and storage medium of neural network model | |
| CN116259057B (zh) | 基于联盟博弈解决联邦学习中数据异质性问题的方法 | |
| CN117994635B (zh) | 一种噪声鲁棒性增强的联邦元学习图像识别方法及系统 | |
| CN110738309A (zh) | Ddnn的训练方法和基于ddnn的多视角目标识别方法和系统 | |
| CN113987236B (zh) | 基于图卷积网络的视觉检索模型的无监督训练方法和装置 | |
| CN112766603B (zh) | 一种交通流量预测方法、系统、计算机设备及存储介质 | |
| CN115577797B (zh) | 一种基于本地噪声感知的联邦学习优化方法及系统 | |
| CN113516163B (zh) | 基于网络剪枝的车辆分类模型压缩方法、装置及存储介质 | |
| CN111104831B (zh) | 一种视觉追踪方法、装置、计算机设备以及介质 | |
| CN110794965B (zh) | 一种基于深度强化学习的虚拟现实语言任务卸载方法 | |
| CN114925848A (zh) | 一种基于横向联邦学习框架的目标检测方法 | |
| US20250078442A1 (en) | Method for predicting channel based on image processing and machine learning | |
| CN113627333A (zh) | 一种基于个性化联邦学习的分心驾驶行为识别方法 | |
| CN117454413A (zh) | 一种基于加权蒸馏的异构联邦学习及恶意客户端防御方法 | |
| CN117710312A (zh) | 基于联邦学习和YOLOv5的输电网异物检测方法 | |
| CN115272774A (zh) | 基于改进自适应差分进化算法的对抗样本攻击方法及系统 | |
| CN120146219A (zh) | 一种面向资源异构的分层联邦学习方法 | |
| CN115311449A (zh) | 基于类重激活映射图的弱监督图像目标定位分析系统 |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| PB01 | Publication | ||
| PB01 | Publication | ||
| SE01 | Entry into force of request for substantive examination | ||
| SE01 | Entry into force of request for substantive examination | ||
| GR01 | Patent grant | ||
| GR01 | Patent grant |