本技术涉及深度学习领域,尤其涉及一种基于可解释多任务模型的预测与解释方法及装置。
背景技术:
1、尽管深度学习模型在众多图像分类任务中取得了较好的效果,然而在细粒度分类问题中的分类效果仍有待提高,这种情况往往是因为细粒度分类问题的类别之间差别较小,难以被模型准确地区分,如何引导深度学习模型关注目标本身并忽略复杂的背景信息也是目前亟待解决的问题。为了解决这个问题,研究者们正在探索新的方法,例如引入更多的领域知识、采用更加复杂的模型结构、进行数据增强等等。此外,尽管深度学习性能优越,但其复杂的网络结构导致其缺乏可解释性。
2、目前的细粒度目标识别方法有如下几种:①将在大规模数据集上预训练的模型(如imagenet)进行微调,以适应细粒度分类任务。②引入注意力机制来关注目标物体的细节和特征,以提高模型在细粒度分类任务中的性能。③结合属性信息,将目标的属性信息与图像特征相结合,提高模型的细粒度分类能力。
3、研究人员们也在不断探索如何提高深度神经网络的可解释性:①使用可视化工具将神经网络的内部运作可视化的方法。这种方法通过使用各种图形化技术,例如热图、特征可视化等,来展示神经网络中每个神经元的激活模式以及神经网络对不同输入特征的敏感性。②模型简化。这种方法通过从神经网络中删除贡献较小的节点和连接来减少神经网络的复杂度,从而提高神经网络的可解释性。③模型解释。这是一种将神经网络的输出结果解释为与特定输入特征有关的方法。例如,研究人员可以使用局部敏感性分析方法来识别与某个特定输入特征有关的输出结果,通过在神经网络中进行反向传播,来确定对于某个特定输出结果贡献最大的输入特征,这种方法的典型代表是类别激活映射图。
4、然而,现有的分析方法存在以下缺点:特征重要性分析法对于复杂的深度学习模型,特征之间的相互作用可能无法完全捕捉,从而导致对模型行为的不完全理解;模型解释方法在某些情况下可能不稳健,因为他们可能对输入数据的微小变化非常敏感,从而产生不同的解释结果;激活热图法通常只提供了定性的直观信息,而不提供关于模型决策的精确定量解释。
技术实现思路
1、本技术旨在至少在一定程度上解决相关技术中的技术问题之一。
2、为此,本技术的第一个目的在于提出一种基于可解释多任务模型的预测与解释方法,以通过结合类别预测和属性预测,引导深度学习模型将注意力集中在前景物体而非背景信息上,为细粒度分类提供关键视觉依据,并且生成的文本-视觉多模态解释可以帮助用户更好地理解网络预测结果。
3、本技术的第二个目的在于提出一种基于可解释多任务模型的预测与解释装置。
4、本技术的第三个目的在于提出一种电子设备。
5、本技术的第四个目的在于提出一种计算机可读存储介质。
6、为达上述目的,本技术第一方面实施例提出了一种基于可解释多任务模型的预测与解释方法,包括:
7、获取初始图像,抽取所述初始图像的图像特征,对所述图像特征进行类别预测与属性预测,得到初始类别预测结果和预测属性;
8、将所述初始类别预测结果和所述预测属性嵌入到同一维度向量进行特征维度对齐,拼接对齐后得到的属性嵌入和类别嵌入以生成集成嵌入,利用卷积神经网络将所述集成嵌入映射为图像类别标签,得到集成类别预测结果;
9、根据所述初始类别预测结果和所述集成类别预测结果,得到所述初始图像的最终预测结果;
10、基于嵌入注意力和梯度反传注意力,分别对所述最终预测结果进行语义解释和视觉解释。
11、可选的,所述抽取所述初始图像的图像特征,对所述图像特征进行类别预测与属性预测,得到初始类别预测结果和预测属性,包括:
12、将所述初始图像输入特征抽取网络,抽取所述初始图像的图像特征,公式化为:
13、vimg=fimg(x)
14、其中,vimg为所述图像特征,fimg(·)表示特征提取函数,x为所述初始图像;
15、将所述图像特征分别输入类别预测网络和属性预测网络,得到所述图像特征的初始类别预测结果和预测属性,公式化为:
16、cp=gc(vimg)
17、
18、其中,cp为所述初始类别预测结果,ai表示第i个预测属性,gc(·)表示类别预测函数,ga(·)表示属性预测函数。
19、可选的,所述将所述初始类别预测结果和所述预测属性嵌入到同一维度向量进行特征维度对齐,拼接对齐后得到的属性嵌入和类别嵌入以生成集成嵌入,包括:
20、将所述初始类别预测结果cp和所述预测属性ai的维度嵌入到de维向量中来进行特征维度对齐,公式化为:
21、ea=emba(α)
22、ep=embp(cp)
23、其中,ea为属性嵌入,ep为类别嵌入,emb(·)表示嵌入函数;
24、拼接所述属性嵌入和类别嵌入,得到所述集成嵌入,公式化为:
25、e=[ea;ep]
26、其中,e表示所述集成嵌入。
27、可选的,所述利用cnn分类器将所述集成嵌入映射为图像类别标签,得到集成类别预测结果,包括:
28、ci=gi(e)
29、其中,ci为所述集成类别预测结果,gi(·)表示所述cnn分类器。
30、可选的,所述根据所述初始类别预测结果和所述集成类别预测结果,得到初始图像的最终预测结果,包括:
31、计算所述初始类别预测结果cp与所述集成类别预测结果ci的加权和,将所述加权和作为所述最终预测结果,公式化为:
32、c=λ·cp+η·ci
33、其中,c为所述最终预测结果,λ与η为超参数。
34、可选的,所述基于嵌入注意力,对所述最终预测结果进行语义解释,包括:
35、将所述最终预测结果c反向传播到属性嵌入ea层,得到所述最终预测结果c在所述嵌入ea上的梯度,公式化为:
36、
37、其中,w为所述梯度,w的每一行wi代表一个对应属性的梯度值向量;
38、对wi每一行进行累加,得到第i个属性的贡献度分数,公式化为:
39、
40、其中,si表示第i个属性的贡献度分数;
41、将贡献度分数最高的前三个属性总结为语义解释。
42、可选的,所述基于梯度反传注意力,对所述最终预测结果视觉解释,包括:
43、将所述最终预测结果的梯度反传回最后一个卷积层,公式化为:
44、
45、其中,表示每一层特征图的权重,ak是特征图,yc表示最终预测结果c的梯度;
46、计算最后一个卷积层中所有特征图ak的加权和,得到注意力图,公式化为:
47、
48、其中,at为所述注意力图,relu(·)表示非线性激活函数。
49、为达上述目的,本技术第二方面实施例提出了一种基于可解释多任务模型的预测与解释装置,包括:
50、类别与属性预测模块,用于获取初始图像,抽取所述初始图像的图像特征,对所述图像特征进行类别预测与属性预测,得到初始类别预测结果和预测属性;
51、嵌入模块,用于将所述初始类别预测结果和所述预测属性嵌入到同一维度向量进行特征维度对齐,拼接对齐后得到的属性嵌入和类别嵌入以生成集成嵌入,利用卷积神经网络将所述集成嵌入映射为图像类别标签,得到集成类别预测结果;
52、最终预测结果生成模块,用于根据所述初始类别预测结果和所述集成类别预测结果,得到所述初始图像的最终预测结果;
53、解释模块,用于基于嵌入注意力和梯度反传注意力,分别对所述最终预测结果进行语义解释和视觉解释。
54、为达上述目的,本技术第三方面实施例提出了一种电子设备,包括:处理器,以及与所述处理器通信连接的存储器;
55、所述存储器存储计算机执行指令;
56、所述处理器执行所述存储器存储的计算机执行指令,以实现如上述第一方面中任一项所述的方法。
57、为达上述目的,本技术第四方面实施例提出了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机执行指令,所述计算机执行指令被处理器执行时用于实现如上述第一方面中任一项所述的方法。
58、为达上述目的,本技术第五方面实施例提出了一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现如上述第一方面中任一项所述的方法。
59、本技术的实施例提供的技术方案至少带来以下有益效果:
60、在充分研究深度神经网络的细粒度分类原理以及可解释性规律的基础上,建立可解释的基于属性学习预测的多任务模型,通过结合类别预测和属性预测有助于模型聚焦于显著属性,引导深度学习模型将注意力集中在前景物体而非背景信息上,为细粒度分类提供关键视觉依据,并且通过后续的嵌入注意力推理出哪些属性对分类结果的贡献更大,可以帮助用户更好地理解网络预测结果,利用这些属性设计出的自然语言描述将更有利于小样本分类任务的完成。
61、本技术附加的方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本技术的实践了解到。
1.一种基于可解释多任务模型的预测与解释方法,其特征在于,包括:
2.根据权利要求1所述的方法,其特征在于,所述抽取所述初始图像的图像特征,对所述图像特征进行类别预测与属性预测,得到初始类别预测结果和预测属性,包括:
3.根据权利要求2所述的方法,其特征在于,所述将所述初始类别预测结果和所述预测属性嵌入到同一维度向量进行特征维度对齐,拼接对齐后得到的属性嵌入和类别嵌入以生成集成嵌入,包括:
4.根据权利要求3所述的方法,其特征在于,所述利用cnn分类器将所述集成嵌入映射为图像类别标签,得到集成类别预测结果,包括:
5.根据权利要求4所述的方法,其特征在于,所述根据所述初始类别预测结果和所述集成类别预测结果,得到初始图像的最终预测结果,包括:
6.根据权利要求5所述的方法,其特征在于,所述基于嵌入注意力,对所述最终预测结果进行语义解释,包括:
7.根据权利要求5所述的方法,其特征在于,所述基于梯度反传注意力,对所述最终预测结果视觉解释,包括:
8.一种基于可解释多任务模型的预测与解释装置,其特征在于,包括:
9.一种电子设备,其特征在于,包括:处理器,以及与所述处理器通信连接的存储器;
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有计算机执行指令,所述计算机执行指令被处理器执行时用于实现如权利要求1-7中任一项所述的方法。