一种基于原型对齐学习的个性化联邦学习方法

xiaoxiao2天前  4


本发明涉及云计算,具体涉及一种基于原型对齐学习的个性化联邦学习方法。


背景技术:

1、深度学习技术的广泛应用已经在各个领域取得了显著的成果,如图像识别、自然语言处理和推荐系统等。然而,随着深度学习模型的发展和数据规模的增长,面临的一个重要问题是如何处理分散在不同实体之间的大规模、敏感性高的数据。人们对于数据隐私和安全的担忧逐渐增加,对保护私有数据提出了更高的要求。

2、为了解决这一问题,联邦学习(federated learning)作为一种新颖的分布式学习框架被提出,它允许在分布式环境中训练模型,同时保护客户端的数据隐私。在联邦学习中,客户端共享模型更新的信息,而不是直接共享原始数据。这种方式有助于减少数据传输量和隐私泄露的风险。然而,当前存在的联邦学习算法通常假设客户端之间的数据是同质的,即数据具有相似的分布和特征。这个假设在现实中并不总是成立。实际应用场景中的数据通常具有异质性特征,例如不同地区、不同用户群体或不同设备所产生的数据可能具有不同的数据分布和特性。由于缺乏对数据异质性的考虑,现有的联邦学习算法在处理异质性数据时面临挑战。许多研究表明,忽视数据的异质性可能导致模型在某些客户端上的性能下降,从而影响整体模型的性能。

3、因此,需要一种解决数据异构性的有效的个性化联邦学习方法,具体是对客户端进行模型级别的个性化设置,从而缓解数据异构性的影响,为各客户端训练出效果更好的个性化模型。

4、现有的个性化联邦学习方法通常将神经网络模型划分为特征提取器(featureextractor)和个性分类器(classifier)。在这种划分中,特征提取器通常位于神经网络的前几层,负责从原始数据中提取共享的特征表示。这些特征表示不包含个体数据的敏感信息,因此可以在客户端之间共享。然后,每个客户端都有一个个性分类器,用于根据特征表示进行个性化的分类任务。在学习过程中,特征提取器在本地训练完成后,会发送到云端进行参数聚合,在下一轮训练中,覆盖本地的特征提取器,而个性化分类器则保留在本地训练,从本地私有数据中学习。

5、然而,当客户端之间的数据分布相似的情况下,由于上述方法没有充分考虑本地分类器与其他客户端分类器的协作,导致比传统联邦学习算法学习到的单个全局模型泛化效果差。在这种情况下,可以对个性分类器进行协作,通过共享和整合个性分类器的知识,来提升个性化模型的精度。


技术实现思路

1、本发明的目的在于针对现有技术的不足,提供一种基于原型对齐学习的个性化联邦学习方法。本发明旨在共享原型信息,鼓励具有相似原型信息的客户端进行协作,在不同数据异质性场景下均能为每个客户端训练出性能较好的个性化模型。

2、本发明的目的是通过以下技术方案来实现的:一种基于原型对齐学习的个性化联邦学习方法,包括以下步骤:

3、(1)中心服务器对全局模型参数{φ,ψ}、全局类原型集进行初始化;其中,全局模型参数{φ,ψ}包括全局特征提取器参数φ和分类器参数ψ;

4、(2)中心服务器选择部分客户端作为本轮训练的参与方,并将初始化后的全局模型参数{φ,ψ}以及全局类原型集下发给联邦学习的各个客户端,各个客户端根据接收到的全局模型参数{φ,ψ}对本地模型进行初始化;

5、(3)客户端i利用本地私有训练数据以及接收到的全局类原型集对本地模型进行迭代训练,以获取训练好的本地模型参数{φi,ψi}以及本地类原型集ci,并将其上传给中心服务器;

6、(4)中心服务器根据所有客户端上传的本地模型参数{φi,ψi}以及本地类原型集ci进行加权聚合,获得更新后的全局类原型集全局特征提取器参数以及各个客户端个性化分类器参数,并将其下发给各个客户端;

7、(5)各个客户端根据自身接收到的更新后的全局特征提取器参数以及客户端个性化分类器参数对其本地模型进行更新;

8、(6)重复步骤(3)-步骤(5),直至达到预设的通信轮次或所有客户端的本地模型的目标损失函数收敛为止。

9、进一步地,所述客户端i利用本地私有训练数据以及接收到的全局类原型集对本地模型进行迭代训练,以获取训练好的本地模型参数{φi,ψi},具体包括:

10、客户端i的本地模型包括特征提取器和分类器;

11、将本地私有训练数据输入至本地模型中,先经过特征提取器,提取输入数据的底层信息,得到特征嵌入向量;

12、再将特征嵌入向量输入至分类器,将其映射到对应的类别或标签的预测值;

13、根据类别或标签的预测值以及对应的真实类别计算交叉熵损失函数值,根据客户端i接收到的全局类原型集以及特征提取器输出的特征嵌入向量计算原型对齐损失函数值;

14、根据交叉熵损失函数和原型对齐损失函数构建目标损失函数,以最小化目标损失函数值为优化目标,采用随机梯度下降法调整本地模型参数{φi,ψi},直至达到预设的迭代训练次数或目标损失函数值收敛为止,获取训练好的本地模型参数{φi,ψi}。

15、进一步地,所述目标损失函数的表达式为:

16、

17、式中,表示目标损失函数,表示交叉熵损失函数,表示原型对齐损失函数,μ为用于调整本地模型的目标损失函数中交叉熵损失项和原型对齐损失项重要程度的超参数;

18、所述交叉熵损失函数的表达式为:

19、

20、式中,ni表示第i个客户端的本地私有训练数据样本总数,yl是样本数据xl的所属类别,表示客户端本地模型参数;

21、所述原型对齐损失函数的表达式为:

22、

23、式中,ni表示第i个客户端的本地私有训练数据样本总数,是给定输入样本数据xl特征提取器输出的特征嵌入向量,φi为第i个客户端的特征提取器参数,yl是样本数据xl的所属类别,是样本所属类别yl的全局类原型,表示两个特征嵌入向量之间的欧氏距离的平方,d是特征嵌入向量的维数。

24、进一步地,所述步骤(3)中,所述本地类原型集ci的获取方法具体包括:

25、将迭代训练过程中特征提取器输出的特征嵌入向量作为类原型信息,按照下式获取客户端i所有类别的原型信息;根据所有类别的原型信息获取客户端i的本地类原型集ci;

26、

27、式中,表示客户端i第j个类别的原型信息,也是类别j中所有样本的特征嵌入向量的平均值;di,j是第i个客户端的本地私有训练数据集di的子集,由属于类别j的本地私有训练数据组成;fi(φi;x)表示第i个客户端的特征提取器输出的特征嵌入向量,φi表示第i个客户端的特征提取器参数,x表示输入特征提取器的本地私有训练数据。

28、进一步地,所述步骤(4)中,更新后的全局类原型集的获取方法具体包括:

29、中心服务器接收到所有客户端上传的本地类原型集ci后,对于类别j,中心服务器选择一组具有类别j的客户端上传的原型信息进行加权聚合,按照下式计算类别j的全局原型信息;根据所有类别全局原型信息获取更新后的全局类原型集

30、

31、式中,表示类别j的全局原型信息,表示具有类别j的客户端集合,表示客户端i类别j的原型信息。

32、进一步地,所述步骤(4)中,更新后的全局特征提取器参数的获取方法具体包括:

33、中心服务器利用各个客户端的权重对其上传的特征提取器参数进行加权平均,得到更新后的全局特征提取器参数;其中,客户端的权重根据该客户端对应的本地私有训练数据计算获得。

34、进一步地,所述更新后的全局特征提取器参数的计算公式为:

35、

36、式中,φ(t+1)表示更新后的全局特征提取器参数,t表示当前通信轮次,表示客户端i在当前通信轮次中的特征提取器参数,m为参与联邦学习的客户端的总数,βi表示客户端i的权重,其计算公式为:

37、

38、式中,ni表示客户端i的本地私有训练数据样本总数,为m个客户端的本地私有训练数据总量。

39、进一步地,所述步骤(4)中,各个客户端个性化分类器参数的获取方法具体包括:

40、中心服务器根据各个客户端上传的本地类原型集计算该客户端的本地类原型表征与其他客户端的本地类原型表征之间的余弦相似度,经过归一化处理后得到关系矩阵r,其计算公式表示为:

41、

42、式中,rib为余弦相似度;rib为关系矩阵权重,是关系矩阵r中的元素;

43、中心服务器将所有客户端在t+1轮的个性化分类器参数按照关系矩阵权重rib进行加权求和,其计算公式表示为:

44、

45、式中,为更新后的客户端i的个性化分类器参数,为当前通信轮次t时的其他客户端的个性化分类器参数。

46、与现有技术相比,本发明的有益效果为:

47、(1)本发明利用特征提取层的输出表征作为数据原型,与所有类别的全局原型计算对齐损失,从而充分利用全局知识,缓解数据异构性的影响,提高了模型的性能。

48、(2)本发明通过上传类原型,能有效避免隐私泄漏的问题;每个客户端只需上传其数据的抽象表征,而不是具体的数据本身,这大大降低了泄露用户敏感信息的风险。

49、(3)本发明通过接收到的客户端所有原型的相似度,加权组合来更新每个客户端下一轮的个性化分类器模型,自适应地促进了客户端之间潜在的成对协作,并显著提高了协作的有效性。

50、(4)本发明能够弥补个性化联邦学习方法没有考虑不同客户端之间知识协作等问题,不仅满足了数据隐私保护的需求,还确保了联邦学习过程中各客户端之间的信息共享与协作,能有效提高联邦学习方法训练所得模型的精度。


技术特征:

1.一种基于原型对齐学习的个性化联邦学习方法,其特征在于,包括以下步骤:

2.根据权利要求1所述的基于原型对齐学习的个性化联邦学习方法,其特征在于,所述客户端i利用本地私有训练数据以及接收到的全局类原型集对本地模型进行迭代训练,以获取训练好的本地模型参数{φi,ψi},具体包括:

3.根据权利要求2所述的基于原型对齐学习的个性化联邦学习方法,其特征在于,所述目标损失函数的表达式为:

4.根据权利要求1所述的基于原型对齐学习的个性化联邦学习方法,其特征在于,所述步骤(3)中,所述本地类原型集ci的获取方法具体包括:

5.根据权利要求1所述的基于原型对齐学习的个性化联邦学习方法,其特征在于,所述步骤(4)中,更新后的全局类原型集的获取方法具体包括:

6.根据权利要求1所述的基于原型对齐学习的个性化联邦学习方法,其特征在于,所述步骤(4)中,更新后的全局特征提取器参数的获取方法具体包括:

7.根据权利要求6所述的基于原型对齐学习的个性化联邦学习方法,其特征在于,所述更新后的全局特征提取器参数的计算公式为:

8.根据权利要求1所述的基于原型对齐学习的个性化联邦学习方法,其特征在于,所述步骤(4)中,各个客户端个性化分类器参数的获取方法具体包括:


技术总结
本发明公开了一种基于原型对齐学习的个性化联邦学习方法,该方法包括:服务器初始化全局模型参数和全局类原型集,并下发给各个客户端,以对客户端的本地模型进行初始化;各个客户端利用本地私有训练数据和全局类原型集对本地模型进行迭代训练,获取训练好的本地模型参数和本地类原型集,并上传给服务器;服务器根据所有客户端的本地模型参数和本地类原型集进行加权聚合,获得更新后的全局类原型集、全局特征提取器参数和各个客户端个性化分类器参数,并下发给各个客户端;各个客户端重复训练本地模型,直至达到预设的通信轮次或所有本地模型的目标损失函数收敛为止。本发明能够大大降低泄露用户敏感信息的风险,有利于提高模型的精度和性能。

技术研发人员:王永祥,才振功
受保护的技术使用者:浙江大学
技术研发日:
技术公布日:2024/9/23

最新回复(0)