classCrossAttention(nn.Module): """ A cross attention layer. Parameters: query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. """ query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states isnotNoneelse hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value)
inpaint 和 outpaint 模型的训练方式基本相同,都是用文生图模型的权重进行初始化,然后改变 Unet 输入的 channels(5 additional input channels ,4 for the encoded masked-image and 1 for the mask itself),新增的channels zero-initialized。
# coding=utf-8 """Memory-efficient MMD implementation in JAX."""
import jax import jax.numpy as jnp
Array = jnp.ndarray
# The bandwidth parameter for the Gaussian RBF kernel. See the paper for more # details. _SIGMA = 10 # The following is used to make the metric more human readable. See the paper # for more details. _SCALE = 1000
@jax.jit defmmd(x, y): """Memory-efficient MMD implementation in JAX. Args: x: The first set of embeddings of shape (n, embedding_dim). y: The second set of embeddings of shape (n, embedding_dim). Returns: The MMD distance between x and y embedding sets. """ x = jnp.asarray(x) y = jnp.asarray(y)
# jnp.matmul(x, x.T) etc. are not cached to avoid OOM when x has many rows. x_sqnorms = jnp.diag(jnp.matmul(x, x.T)) y_sqnorms = jnp.diag(jnp.matmul(y, y.T))
基于latent的图像生成模型都需要一个 vae 模型,将图片压缩,然后再压缩空间内进行扩散训练。autoencoder是一个基于encoder-decoder架构的图像压缩模型,对于一个大小为 $H\times W \times 3H$的输入图像,encoder模块将其编码为一个大小为 $h\times w \times c$ 的latent,其中$f=H/h=W/h$为下采样率(downsampling factor)。在训练autoencoder过程中,除了采用L1重建损失外,还增加了感知损失(perceptual loss,即LPIPS,具体见论文The Unreasonable Effectiveness of Deep Features as a Perceptual Metric)以及基于patch的对抗训练。辅助loss主要是为了确保重建的图像局部真实性以及避免模糊,具体损失函数见latent diffusion的loss部分。同时为了防止得到的latent的标准差过大,采用了两种正则化方法:第一种是KL-reg,类似VAE增加一个latent和标准正态分布的KL loss,不过这里为了保证重建效果,采用比较小的权重(~10e-6);第二种是VQ-reg,引入一个VQ (vector quantization)layer,此时的模型可以看成是一个VQ-GAN,不过VQ层是在decoder模块中,这里VQ的codebook采样较高的维度(8192)来降低正则化对重建效果的影响。
classEfficientNetEncoder(nn.Module): def__init__(self, c_latent=16): super().__init__() self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() self.mapper = nn.Sequential( nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 )
模型结构如上图所示,通过将image encoder提取到 image embedding 与 text embedding融合再替换text embedding中的对应的token,来达到引入 ID 图像信息。训练的时候,原本Unet是一起训练的,无法做到即插即用。效果看起来跟IP-adapter接近,输入的参考图像的数量越多生成效果越好。
模型结构如上图,也可以看作是 IP-Adapter 和 ControlNet的结合。其中,IP-Adapter的输入是通过 Face Encoder 模型来提取脸部特征。ControlNet的输入是面部关键点图,条件输入去除了文本特征,只使用 Face embedding。训练的时候,Unet完全冻结。
VQ-VAE的loss如上所示,分成三部分,第一项用来训练 encoder和decoder,第二项叫 codebook loss,只训练 codebook,让codebook中的embedding向各自最近的$Z_e(x)$靠近。第三项叫 commitment loss,只训练encoder, 目的是encourage the output of encoder to stay close to the chosen codebook vector to prevent it from flucturating too frequently from one code vector to another, 即防止encoder的输出频繁在各个codebook embedding之间跳
CoCa构建在encoder-decoder基础上,不过这里将text decoder均分成两个部分:unimodal text decoder和multimodal text decoder。然后增加一个cls token在文本的最后,unimodal text decoder不参与对图像特征的cross-attention,这样cls token经过unimodal text decoder之后就能够得到整个句子的全局特征。同时采用attention pooling对image encoder得到特征提取图像的全局特征,两个全局特征就可以实现图像-文本的对比学习,这里的attention pooling其实就是一个multi-head attention,只不过key和value是image encoder得到的特征,而query是预先定义的一个可训练的embedding,由于我们只需要提取一个全局特征,所以只需要定义一个query就好了。
multimodal text decoder将用来执行生成任务,这里也通过一个attention pooling对image encoder得到的特征进行提取,不过这里query数量定义为256,这样attention pooling可以得到256个特征,它作为multimodal text decoder的cross-attention的输入。
CoCa训练使用的数据比较大,最后得到的效果也刷新的当时很多项多模态任务的榜单。
5 BEiT V3
5.1 背景
无论是 NLP,CV 还是 多模态领域,模型大一统是大势所趋,也就是在超大的数据集上做大规模预训练,一旦模型训练好了之后,就可以直接应用到下游任务中,成为一个通用的 Foundation Model。Beit V3正是朝着这个目标,对之前的工作进行总结和改进之后实现的。
主要亮点是 Beit v3 直接将图像和文本以相同的方式处理,并通过一个预训练任务进行训练,也就是 mask data modeling。
5.2 方法
模型结构采用和 VLMO 相同的 MOME,训练目标是 mask data modeling,可能是遮住了图像,可能是遮住了文本,模型训练学习如何去恢复它就可以。
ITC 的目的是使图像表征和文本表征对齐,以便最大化它们之间的交互信息。作者使用对比损失,通过对比正对和负对的相似性来实现这一点(正对的距离尽量近,负对的距离尽量远)。具体来说,作者将 image transformer 的输出 Query 表征 Z 与 text transformer 输出的文本表征 t 对齐,其中 t 对应 [CLS] Token 的输出 embedding。由于 Z 包含多个输出 embedding(32 个 Query,对应 32 个 embedding 向量),因此作者首先计算每个 Query 表征与 t 之间的成对相似度(32 个),然后选择最高的一个作为图像-文本相似度。为了避免信息泄露,作者采用了单模态的 Self-Attention Mask,也就是如下图红框所示 Mask,其不允许 Query 和 Text 相互看到。其中的负例都从 batch 数据中选择(图像-文本是成对存在的,每个图像都有 1 个正对,其余图像对应的文本都可以作为负对,也就是每个图像有 batch size - 1 个负样本),而不是 BLIP 中的动量队列.
6.2.2.2 ITG 基于图像文本生成
ITG 的目的是以给定输入图像作为条件来训练 Q-Former 生成文本。由于 Q-Former 的架构不允许 Text Token 与 image encoder 之间直接交互,因此必须先由 Query 和 image encoder 交互提取生成文本所需的信息,然后通过 Self-Attention 传递给 Text Token。也就是说,Query 被强制提取有关文本的所有信息的视觉特征。作者采用多模态因果自注意力掩码(Multi-modal Causal Self-Attention Mask)来控制 Query-Text 之间的交互。如下图红框内所示,类似于 UniLM 中使用的 Mask,Query 可以相互关注到,但不能关注到 Text。每个 Text Token 都可以关注到所有 Query Token,以及之前的 Text Token。此外作者还将 [CLS] Token 换成了 [DEC] Token,作为发出解码任务信息的第一个 Text Token。
6.2.2.3 ITM 图像文本匹配
ITM 的目的是学习图像和文本之间的细粒度对齐。这是一个二元分类任务,要求模型预测图像-文本对是正(匹配)还是负(不匹配)。此时作者使用双向自注意力掩码(Bi-directional Self-Attention Mask),如下图红框内所示,也就是 Query 和 Text 都可以相互看到。因此 Query 表征 Z 可以捕获到多模态信息。之后,作者将 Z 中的每个 embedding(32)都输入到二元分类 Linear 层以获得 logit,并将所有 Query 的 logit 平均输出为匹配分数。作者采用了 [2107.07651] Align before Fuse: Vision and Language Representation Learning with Momentum Distillation 中的 hard 负样本挖掘策略来创建信息丰富的负对。