本发明涉及联邦学习,尤其涉及一种基于知识蒸馏和锐度感知最小化的个性化联邦学习方法。
背景技术:
1、联邦学习是一种分布式机器学习方法,旨在应对数据隐私和安全性方面的挑战。传统的集中式机器学习模型通常集中存储数据,但这种做法可能会导致隐私泄露和数据安全性问题。联邦学习通过在本地设备上训练模型,并将更新的模型参数上传至服务器进行聚合,从而避免了数据离开本地设备,有效地保护了用户的隐私。目前,联邦学习已在金融、医疗和工业等领域已得到广泛应用。
2、然而,联邦学习面临着数据异构和模型异构的挑战,即不同设备上的数据分布和模型架构可能存在显著差异。尽管联邦学习通常致力于训练一个全局模型以适应所有客户端,但服务器聚合各个差异较大的客户端模型可能导致聚合后的全局模型出现客户端漂移的情况,从而影响客户端模型的性能和收敛性。
3、为了解决这一问题,pfl(personalized federated learning,个性化联邦学习)应运而生。pfl旨在保持数据隐私的同时,为每个客户端定制个性化的模型,以提高模型在特定数据分布上的表现。
4、目前,现有技术中针对数据异构提出的pfl方法大致可以分为两类:全局模型个性化更新和客户端本地构建个性化模型。在全局模型个性化更新策略中通过本地适应性步骤对全局模型进行调整,以满足每个客户端的特定需求。例如在本地训练引入正则化项,减少局部模型和本地模型之间的差异,或通过元学习使全局共享模型在不暴露私有数据的情况下适应局部私有数据。在客户端本地构建个性化模型策略中,为每个客户端训练独立的个性化模型。例如通过基于客户端相似性的聚类方法,将客户端分成多个组,并为每个组单独训练更适应该组数据分布的模型,或通过知识蒸馏方法,每个客户端通过全局共享的公共数据集获取软值预测,利用服务器聚合得到的软值预测改进客户端模型。
5、上述现有技术中针对数据异构提出的pfl方法的缺点包括:
6、(1)现有的全局模型个性化方法通常假设客户端和服务器采用相同的模型结构。然而,当面对模型异构情况时,这种假设可能不适用。此外,该方案中的方法都需要传输模型参数,这会导致通信开销与模型大小成正比,进而影响通信效率,而且存在梯度被攻击的风险。
7、(2)现有的客户端学习个性化模型的方案,基于聚类的方法可能受到聚类误差的影响,而且会产生较大的计算和通信成本。基于知识蒸馏的方法在聚合全局知识时采用简单加权平均方法,会降低聚合后的知识质量,进而影响模型性能。
8、(3)现有的技术方案普遍使用经验风险最小化作为局部优化器,而在客户端训练模型的过程中,很容易陷入陡峭的局部极小值点,出现过拟合现象。这会导致模型的收敛速度减慢且泛化能力下降。
技术实现思路
1、本发明提供了一种基于知识蒸馏和锐度感知最小化的个性化联邦学习方法,在保证避免暴露敏感信息的前提下,提升模型性能和收敛效果。
2、为了实现上述目的,本发明采取了如下技术方案。
3、一种基于知识蒸馏和锐度感知最小化的个性化联邦学习方法,包括:
4、服务器生成全局共享无标签伪数据集,将无标签伪数据集下发给各个客户端;
5、客户端利用本地数据训练本地模型,采用锐度感知最小化方法对本地模型进行优化,得到优化后的本地模型wc;
6、客户端利用本地数据集和所述无标签伪数据集生成本地知识,所述本地知识包括客户端本地各类原型和客户端本地模型对无标签伪数据样本的预测值
7、服务器对各个客户端的本地知识中的本地原型进行加权平均操作,得到全局原型对各个客户端的本地知识中的logits依据权重系数进行加权运算得到全局将由全局原型和全局构成的全局知识下发给各个客户端;
8、客户端根据本地logits和全局知识中的全局对本地模型wc进行更新。
9、优选地,所述的服务器生成全局共享无标签伪数据集,将无标签伪数据集下发给各个客户端,包括:
10、服务器对本地数据的随机子集进行加权计算,得到全局共享的伪数据样本其中xt是从大小为n的子集中随机选取的一个样本,向伪数据样本xp添加从均值为0、标准差为σ的高斯分布中随机采样得到的噪声n(0,σ2),得到处理后的伪数据样本为
11、服务器将由伪数据样本构成的全局共享数据集dp下发给各个客户端。
12、优选地,所述的客户端利用本地数据训练本地模型,采用锐度感知最小化方法对本地模型进行优化,得到优化后的本地模型wc,包括:
13、在客户端利用本地数据进行本地模型ω训练的过程中,锐度感知最小化方法的优化目标函数为:
14、
15、其中[fsam(ω)―f(ω)]代表锐度,公式(1)中其中f为损失函数,fsam为在扰动半径为ρ的空间内的损失最大值,δ为扰动向量,计算公式为为损失函数关于模型参数ω的梯度,‖·‖2为l2范数;
16、本地模型的更新计算公式为:其中η为学习率,ωk为第k轮本地模型参数;
17、对所述优化目标函数进行权重调整,调整后的目标函数为:
18、
19、其中γ为用于控制目标函数寻找局部极小值与寻找平坦区域的权重系数。模型更新的计算公式调整为:
20、在初始轮次未生成全局知识时,参与方客户端本地模型训练的目标函数为:
21、和yi为本地数据样本i的预测值和真实值,在后续的通信轮次中参与方模型训练目标函数为:其中lmse为均方误差损失函数,和分别为客户端c的本地样本i标签为j的特征表示与全局原型,为超参数,pi为客户端本地数据样本i通过特征提取器获取的特征表示;
22、训练结束后,得到优化后的本地模型wc。
23、优选地,所述的客户端利用本地数据集和所述无标签伪数据集生成本地知识,所述本地知识包括客户端本地各类原型和客户端本地模型对无标签伪数据样本的预测值包括:
24、客户端通过模型特征提取器获取本地训练数据的特征表示其中xi为客户端c的训练数据,j为xi样本的标签,为模型参数为ωc的特征提取器,通过公式q计算出本地各个类的平均表示即本地类原型,其中为客户端c中数据集标签为j的数量;
25、客户端通过模型分类器在公共数据集输出其中(·)为客户端c参数为的分类器,logits为公共数据集样本经本地模型分类器输出得到样本属于各个类别的预测值,为公共数据集样本的特征表示;
26、客户端根据公共数据集中的样本与全局原型的近似程度,分配权重其中为公共数据集样本的权重,led为计算欧式距离。为公共数据集样本生成的伪标签,伪标签通过本地模型分类器获得为公共数据集样本的分类预测概率值,通过argmax将概率最高的类别作为公共数据集样本的类标签,为公共数据集中伪标签为的样本的特征表示,为全局知识中标签为的原型;
27、客户端本地各个类的平均表示和本地构成客户端的本地知识,客户端将本地知识及对应的权重一同发送给服务器;
28、优选地,所述的服务器对各个客户端的本地知识中的本地原型进行加权平均操作,得到全局原型对各个客户端的本地知识中的logits依据权重系数进行加权运算得到全局将由全局原型和全局构成的全局知识下发给各个客户端,包括:
29、服务器接收所有参与方客户端上传的本地知识中的本地类原型后,对本地原型进行加权平均操作得到全局原型其中cj为拥有样本标签为j的客户端,为客户端c中类别为j的原型,n为类标签数量;
30、服务器对本地模型输出的软值预测logits依据权重系数进行加权运算得到全局其中权重根据欧式距离计算得到,距离越小权重越高,为遍历客户端u的权重进行再次计算得到权重,和分别为dp公共数据集样本在客户端c的权重系数和logits,m为公共数据集大小;
31、服务器聚合得到的全局原型和全局构成全局知识,将全局知识下发给各个客户端。
32、优选地,所述的客户端根据本地logits和全局知识中的全局对本地模型wc进行更新,包括:
33、各个客户端利用公共数据集、本地和全局logits对本地模型wc进行训练和更新,目标函数为:其中lkl为kullback-leibler散度损失函数,和分别为公共数据集样本的本地logits和全局logits,lce为交叉熵损失函数,其中和分别为公共数据集样本的预测值和伪标签;
34、在本地模型wc的训练和更新过程中,对参数τ进行动态控制,τ的计算公式为其中r为当前轮次,r为总轮次,τ随着训练轮数的增加而增加。
35、由上述本发明的实施例提供的技术方案可以看出,本发明实施例利用无标签伪数据集作为公共数据集进行知识蒸馏,这些无标签数据提供了额外的辅助信息。通过仅使用少量的数据样本,就能够实现良好的性能。本发明根据各个客户端拥有的样本,对样本质量进行评估,并据此生成相应的权重。服务器利用这些权重对本地知识进行聚合,以获得高质量的全局知识,用于提升本地模型性能。
36、本发明附加的方面和优点将在下面的描述中部分给出,这些将从下面的描述中变得明显,或通过本发明的实践了解到。
1.一种基于知识蒸馏和锐度感知最小化的个性化联邦学习方法,其特征在于,包括:
2.根据权利要求1所述的方法,其特征在于,所述的服务器生成全局共享无标签伪数据集,将无标签伪数据集下发给各个客户端,包括:
3.根据权利要求1所述的方法,其特征在于,所述的客户端利用本地数据训练本地模型,采用锐度感知最小化方法对本地模型进行优化,得到优化后的本地模型wc,包括:
4.根据权利要求3所述的方法,其特征在于,所述的客户端利用本地数据集和所述无标签伪数据集生成本地知识,所述本地知识包括客户端本地各类原型和客户端本地模型对无标签伪数据样本的预测值包括:
5.根据权利要求4所述的方法,其特征在于,所述的服务器对各个客户端的本地知识中的本地原型进行加权平均操作,得到全局原型对各个客户端的本地知识中的logits依据权重系数进行加权运算得到全局将由全局原型和全局构成的全局知识下发给各个客户端,包括:
6.根据权利要求5所述的方法,其特征在于,所述的客户端根据本地logits和全局知识中的全局对本地模型wc进行更新,包括:
