常见的多模态对齐方案

常见的多模态对齐方案

1、背景

LLM 只具备处理文本的能力,需要弥补自然语言和图像模态之间的差距。通过端到端的方式训练 LMM 代价非常高,并且可能会带来灾难性遗忘的风险。目前,通常的做法是基于预训练的视觉编码器和 LLM 来构建 VLM。图像生成领域,也通过类似模态对齐的结构来引入额外的信息控制生成。常见的模态对齐方案,主要分为以下两类:

  • 基于 Learnable Query 的方案,包括:
    • Perceiver Resampler
    • Q-Former
    • Cross-attention
  • 基于 Projection 的方案:
    • 单层 Linear 投影
    • 两层 MLP

2、模态对齐的方案

2.1 Cross attention

很多常见的模块的使用了 Cross attention,大多数文生图模型,text embedding 就是通过 cross attention 引入到 Unet 结构里面。

假设输入的 Query embedding 维度 32 x 768,输入的 image embedding 维度 257 x 1024 为例,如下所示,可以看出 Cross Attention 的过程,K 和 V 的维度为 1024 x 768,Q 的维度为 768 x 768,所以对应的 Attention Score 的维度为 32 x 257,最终也可以保持 Query embedding 维度不变,依然为 32 x 768(红框):

image-20240303234925136

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class CrossAttention(nn.Module):
"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)

encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)

使用 Cross attention 作为模态间对齐结构的 VLM 有:

  • Qwen-VL
  • 各种文生图模型

2.2 Percevier Resampler

img

Resampler 的结构如上图所示,可以看出它是一个常规的 transformer block, 其中 attention 使用 Cross attention,拿 Flamingo 举例,具体来说:

  • 每个图像经 Vision Encoder 会生成一个 [S, d] 的视觉特征,T 个图像对应 x_f 的维度为 [T, S, d]
  • x_f 加上维度为 [T, 1, d] 的 time_embeddings
  • 将时间和空间维度拉平,x_f -> [T*S, d]
  • 将 x_f 作为 transformer block 的 Key 和 Value 输入
  • 自定义的 R 个可学习的 Query Token,对应维度为 [R, d]
  • 然后经过 num_layers 层 transformer block 得到对应的新的视觉特征 x,维度为 [R, d],和可学习的 Query 维度一致。

使用 Resampler 结构的常见 VLM 模型有:

  • Flamingo
  • mPLUG-Owl

2.3 Q-former

Q-former 结构由 BLIP2 模型提出来的,在 Q-Former 中,作者额外创建了一组可学习的 Query embedding 作为输入(这与 Flamingo 中R 个可学习的 Query Token 作用一样)。这些 Query embedding 在 Self Attention 层相互交叉,并通过 Cross attention 层(每隔一个 transformer block 有一个 Cross attention)与冻结的 image encoder 输出的 image embedding 进行交叉。

截屏2024-03-04 00.07.50

使用 Q-former 结构的 VLM 模型有:

  • BLIP2
  • MiniGPT- v1
  • InstructBILP