disentangleprompt

Paper Reading(Prompt to Disentangle)

Key Idea

  1. 借助神经网络提取特征
  2. 借助prompt机制将特征分解为domain-general和domain-specific两部分
    1. domain general prompt有利于提升在test domain上的繁华效果
    2. 不偏不倚不偏不倚不偏不倚不偏不倚不偏不倚不偏不倚*
  3. 本文将visual backbone输出的feature视为一个个token,将其与DS/DG prompt拼接得到的序列送给Transformer,编码出的Sequence送给transform输出的序列

Problem Formulation

给定若干domain用于训练

测试数据来自新的Domain $D_t$,给定部分包含标签的数据用于fine-tune(C-way K-shot,C个类,每个类包含K个sample)

Method

Train

每个training iteration,随即从若干训练域中选择一个域$D_n \in D$,并采样若干图片$\{x_n\}$,利用CNN visual backbone $f(\cdot)$编码得到feature map

展平为$H\times W$个token $F\in \R^{HW\times D}$,DG prompt(G)是一组可学习参数,DS prompt(S)根据Domain中若干图像计算得到(本质上用神经网络提取特征)

利用多层Transformer提取特征,第$l$层$l=1,2,\cdots,L$记作

将最后一层的prompt输出$G^L,S^L$送给global/local classification head,global classification head用于判断样本属于哪个domain,local classification head用于图片分类

Test

测试阶段的DG prompt $G$来自训练阶段,DS prompt $S$来自测试数据通过训练阶段得到的用于抽取DSprompt的网络推理得到

训练DG prompt

做平均池化,计算Domain Level的分类损失

为了避免在单个domain上训练导致DG prompt倾向于某个特定domain,本文提出Neutralize DG prompt,对于特定domain $D_n$中的B个样本$x_1,x_2,\cdots,x_B$,计算Domain Center(动量更新)

最终希望DG prompt和每个Domain Center尽量远,得到

训练DS prompt

每个class选择一个样本,平均池化后的特征拼接生成Domain Specific Prompt(为什么要从多个class中选择?)

image-20231123175109769

和图片feature组成的sequence经过transformer后,DS prompt应该具有和Feature共有的属性(原先不属于图片分类的token被掩藏),随后将池化后的DS prompt送给local classification head(对于每个Domain构建一个独立的分类器),希望预测出F对应的ground truth label

本站访客数人次