disentangleprompt
Paper Reading(Prompt to Disentangle)
Key Idea
- 借助神经网络提取特征
- 借助prompt机制将特征分解为domain-general和domain-specific两部分
- domain general prompt有利于提升在test domain上的繁华效果
- 不偏不倚不偏不倚不偏不倚不偏不倚不偏不倚不偏不倚*
- 本文将visual backbone输出的feature视为一个个token,将其与DS/DG prompt拼接得到的序列送给Transformer,编码出的Sequence送给transform输出的序列
Problem Formulation
给定若干domain用于训练
测试数据来自新的Domain $D_t$,给定部分包含标签的数据用于fine-tune(C-way K-shot,C个类,每个类包含K个sample)
Method
Train
每个training iteration,随即从若干训练域中选择一个域$D_n \in D$,并采样若干图片$\{x_n\}$,利用CNN visual backbone $f(\cdot)$编码得到feature map
展平为$H\times W$个token $F\in \R^{HW\times D}$,DG prompt(G)是一组可学习参数,DS prompt(S)根据Domain中若干图像计算得到(本质上用神经网络提取特征)
利用多层Transformer提取特征,第$l$层$l=1,2,\cdots,L$记作
将最后一层的prompt输出$G^L,S^L$送给global/local classification head,global classification head用于判断样本属于哪个domain,local classification head用于图片分类
Test
测试阶段的DG prompt $G$来自训练阶段,DS prompt $S$来自测试数据通过训练阶段得到的用于抽取DSprompt的网络推理得到
训练DG prompt
做平均池化,计算Domain Level的分类损失
为了避免在单个domain上训练导致DG prompt倾向于某个特定domain,本文提出Neutralize DG prompt,对于特定domain $D_n$中的B个样本$x_1,x_2,\cdots,x_B$,计算Domain Center(动量更新)
最终希望DG prompt和每个Domain Center尽量远,得到
训练DS prompt
每个class选择一个样本,平均池化后的特征拼接生成Domain Specific Prompt(为什么要从多个class中选择?)
和图片feature组成的sequence经过transformer后,DS prompt应该具有和Feature共有的属性(原先不属于图片分类的token被掩藏),随后将池化后的DS prompt送给local classification head(对于每个Domain构建一个独立的分类器),希望预测出F对应的ground truth label