文生图模型演进

1. 背景

此文是为了更好的理解 stable diffusion以及DALL-E 3等最新的图像生成模型,回顾一下在它们之前更早的模型。stable diffusion的作者也是 VQ-GAN的作者,DALL-E3之前还有 DALL-E,DALL-E2。

2. AE

img

作用:可以用于学习 无标签数据的有效编码, 学习对高维数据的进行低维表示,常用于降维,数据去噪,特征学习。

基本思想:AE(auto-encoder)是很早之前的技术了,思路也非常简单:用一个编码器(Encoder)把输入编码为 latent vector;然后用 Decoder 将其解码为重建图像,希望重建后的图像与输入图像越接近越好。通常 latent vector 的维度比输入、输出的维度小,因此称之为 bottleneck。AE 是一个自重建的过程,所以叫做“自-编码器”。

特点:但是模型在 Latent Space 没有增加任何的约束或者正则化,意味着不知道 Latent Space 是如何构建的, 所以很难使用 latent space 来采样生成一个新的图像。

DAE

DAE(Denoising autoencoder)将原始输入图像进行一定程度的打乱,得到 corrupted input。然后把后者输入AE,目标仍然是希望重建后的图像与原始输入越接近越好。DAE 的效果很不错,原因之一就是图像的冗余度太高了,即使添加了噪声,模型依然能抓取它的特征。而这种方式增强了模型的鲁棒性,防止过拟合。

3. VAE

img

论文[1312.6114] Auto-Encoding Variational Bayes

作用:除了AE的作用之外,还广泛应用于生成新的、与训练数据相似但不完全相同的样本。

基本思想:VAE(Variational autoencoder)仍然由一个编码器和一个解码器构成,并且目标仍然是重建原始输入。但中间不再是学习 latent vector z ,而是学习它的后验分布 $p(z|x)$ ,并假设它遵循多维高斯分布。具体来说,编码器得到两个输出,并分别作为高斯分布的均值和协方差矩阵的对角元(假设协方差矩阵是对角矩阵)。然后在这个高斯分布中采样,送入解码器。实际工程实现中,会用到重参数化(reparameterization)的技巧。

重参数就是把带有随机性的z变成确定性的节点,同时把随机性转嫁给另一个输入节点 ϵ例如,这里用正态分布采样,原本从均值为x和标准差为ϕ的正态分布N(x,ϕ2)中采样得到 z,将其转化成从标准正态分布N(0,1)中采样得到 ϵ, 再通过重参数技巧计算得到 z=x+ϵ⋅ϕ。这样一来,采样的过程移出了梯度反向传播的路径,计算图里的参数(均值x和标准差ϕ)就可以用梯度更新了,而新加的 ϵ 的输入分支不做更新,只当成一个没有权重变化的输入。
用博客《The Gumbel-Softmax Distribution》的说法再复述一遍,重参数就是把原来完全随机的节点分成了确定的节点和随机的节点两部分:z∼N(0,1)→z=μ+σϵ where ϵ∼N(0,1)

特点:VAE训练目标除了重构误差,还包括最小化隐空间的KL散度,以确保隐空间与标准正态分布接近。但VAE最大的问题也是这个,使用了固定的先验分布。具体推导涉及 ELBO,如下图

截屏2024-03-09 23.08.49

截屏2024-03-09 23.10.07

4. VQ-VAE

VQ(vector quantization)是一种数据压缩和量化的技术,它可以将连续的向量映射到一组离散的具有代表性的向量中。在深度学习领域,VQ通常用来将连续的隐空间表示映射到一个有限的、离散的 codebook 中。

VAE 具有一个最大的问题就是使用了固定的先验(高斯分布),其次是使用了连续的中间表征,导致模型的可控性差。为了解决这个问题,VQ-VAE(Vector Quantized Variational Autoencoder)选择使用离散的中间表征,同时,通常会使用一个自回归模型来学习先验(例如 PixelCNN),在训练完成后,用来采样得到 $z_e$。

img

图像首先经过encoder,得到$z_e$,它是$H\times W$个 $D$ 维向量。$e_1,e_2,…,e_k$是 $K$ 个 $D$ 维向量,称为codebook。 对于 $z_e$ 中的每个 $D$ 维向量,都可以在 codebook 中找到最接近的 $e_i$, 构成 $z_q$, 这就是decoder的输入。一般 $k=8192, D=512 or 768$。

从 $z_e(x)$ 到 $z_q(x)$ 这个变化可以看成一个聚类,$e_1,e_2,…,e_k$ 可以看作 K 个聚类中心。这样把 encoder 得到的 embedding 离散化了,只由聚类中心表示。

4.1 VQ-VAE的训练

在VQ中使用 Argmin来获取最小的距离,这一步是不可导的,因为无法将 Decoder 和 Encoder联合训练,针对这个问题,作者添加了一个trick,如上图红线所示:直接将 $z_q(x)$的梯度cooy给$Z_e(x)$, 而不是给 codebook里面的embedding。

推推截图_20240112000325

VQ-VAE的loss如上所示,分成三部分,第一项用来训练 encoder和decoder,第二项叫 codebook loss,只训练 codebook,让codebook中的embedding向各自最近的$Z_e(x)$靠近。第三项叫 commitment loss,只训练encoder, 目的是encourage the output of encoder to stay close to the chosen codebook vector to prevent it from flucturating too frequently from one code vector to another, 即防止encoder的输出频繁在各个codebook embedding之间跳

具体代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost):
super(VectorQuantizer, self).__init__()

self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings

self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
self._commitment_cost = commitment_cost

def forward(self, inputs):
# convert inputs from BCHW -> BHWC
inputs = inputs.permute(0, 2, 3, 1).contiguous()
input_shape = inputs.shape

# Flatten input
flat_input = inputs.view(-1, self._embedding_dim)

# Calculate distances
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))

# Encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)

# Quantize and unflatten
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
q_latent_loss = F.mse_loss(quantized, inputs.detach())
loss = q_latent_loss + self._commitment_cost * e_latent_loss

quantized = inputs + (quantized - inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

# convert quantized from BHWC -> BCHW
return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

4.2 VQ-VAE + PixelCNN

有了上述的 VQ-VAE 模型,可以很容易实现图像压缩、重建的目的,但是无法生成新的图像数据。当然可以随机生成 Index,然后对应生成量化后的 latent code,进而使用 Decoder 来生成输出图像。但是这样的 latent code 完全没有全局信息甚至局部信息,因为每个位置都是随机生成的。因此,作者引入了 PixelCNN 来自回归的生成考虑了全局信息的 latent code,进而可以生成更真实的图像,如下图所示:

image-20240112001504547

PixelCNN 和 VQ-VAE 的一作是同一个人,来自 Google DeepMind,对应的论文为:Conditional Image Generation with PixelCNN Decoders。此处我们不再对 PixelCNN 展开,只需要知道它是一个自回归生成模型,可以逐个像素的生成,因为其是自回归模型,所以每个位置都能看到之前位置的信息,这样生成的 latent code 能够更全面的考虑到空间信息,有助于提高模型生成图像的质量和多样性。

4.3 VQ-VAE-2

VQ-VAE-2 的模型结构如下图所示,以 256x256 的图像压缩重建为例:

  • 训练阶段:其首先使用 Encoder 将图像压缩到 Bottom Level,对应大小为 64x64,然后进一步使用 Encoder 压缩到 Top Level,大小为 32x32。重建时,首先将 32x32 的表征经过 VQ 量化为 latent code,然后经过 Decoder 重建 64x64 的压缩图像,再经过 VQ 和 Decoder 重建 256x256 的图像。
  • 推理阶段(图像生成):使用 PixelCNN 首先生成 Top Level 的离散 latent code,然后作为条件输入 PixelCNN 以生成 Bottom Level 的更高分辨率的离散 latent code。之后使用两个 Level 的离散 latent code 生成最终的图像。

图片

当然,基于这个思想作者也进一步验证了使用 3 个 Level 来生成 1024x1024 分辨率的图像,相应的压缩分辨率分别为 128x128、64x64、32x32。

5 VQ-GAN

5.1 概述

paper:https://openaccess.thecvf.com/content/CVPR2021/html/Esser_Taming_Transformers_for_High-Resolution_Image_Synthesis_CVPR_2021_paper.html

code:https://github.com/CompVis/taming-transformers

image-20230512112944268

VQ-GAN 相比较 VQ-VAE 的主要改变有以下几点:

  • 引入 GAN 的思想,将 VQ-VAE 当做生成器(Generator),并加入判别器(Discriminator),以对生成图像的质量进行判断、监督,以及加入感知重建损失(不只是约束像素的差异,还约束 feature map 的差异),以此来重建更具有保真度的图片,也就学习了更丰富的 codebook。
  • 将 PixelCNN 替换为性能更强大的自回归 GPT2 模型(针对不同的任务可以选择不同的规格)
  • 引入滑动窗口自注意力机制,以降低计算负载,生成更大分辨率的图像。

VQ-GAN 也是 stable diffusion的作者。

5.2 方法

模型结构如上图所示,实际训练的时候是分为两阶段训练的

如下图所示,第一阶段训练,相比 VQ-VAE 主要是增加 Discriminator, 以及将重建损失替换成 LPIPS损失:

  • Discriminator:对生成的图像块进行判别,每一块都会返回 True 和 False,然后将对应的损失加入整体损失中。
  • LPIPS:除了像素级误差外,也会使用 VGG 提取 input 图像和 reconstruction 图像的多尺度 feature map,以监督对应的误差(具体可参考 lpips.py - CompVis/taming-transformers · GitHub)。

image-20240113215830827

6. DALL-E (dVAE, DALL-E)

6.1 概述

DALL-E 最主要的贡献是提供了不错的文本引导图片生成的能力,其不是在 VQ-VAE 基础上修改,而是首先引入 VAE 的变种 dVAE,然后在此基础上进一步训练 DALL-E。可惜的是,OpenAI 并不 Open,只开源了 dVAE 部分模型,文本引导生成部分并没有开源,不过 Huggingface 和 Google Cloud 团队进行了复现,并发布对应的 DALL-E mini 模型。

DALL-E 对应的论文为:[2102.12092] Zero-Shot Text-to-Image Generation。对应的代码库为:GitHub - openai/DALL-E: PyTorch package for the discrete VAE used for DALL·E.。

DALL-E mini 对应的文档为:DALL-E Mini Explained,对应的代码库为:GitHub - borisdayma/dalle-mini: DALL·E Mini - Generate images from a text prompt。

6.2 dVAE

dVAE(discrete VAE)与VQ-VAE的区别在于引入 Gumbel Softmax 来训练, 避免 VQ-VAE 训练中 ArgMin 不可导的问题。

image-20240113222402539

6.3 模型训练

有了 dVAE 模型之后,第二阶段就是就是训练 Transformer(此阶段会固定 dVAE),使其具备文本引导生成的能力。DALL-E 使用大规模的图像-文本对数据集进行训练,训练过程中使用 dVAE 的 Encoder 将图像编码为离散的 latent code。然后将文本输入 Transformer,并使用生成的 latent code 来作为 target 输出。以此就可以完成有监督的自回归训练。推理时只需输入文本,然后逐个生成图像对应的 Token,直到生成 1024 个,然后将其作为离散的 latent code 进一步生成最终图像。

image-20240113222517785

最终作者在 1024 个 16G 的 V100 GPU 上完成训练,batch size 为 1024,总共更新了 430,000 次模型,也就相当于训练了 4.3 亿图像-文本对(训练集包含 250M 图像-文本对,主要是 Conceptual Captions 和 YFFCC100M)。

6.4 DALL-E mini 模型概述

如下图所示,DALL-E mini 中作者使用 VQ-GAN 替代 dVAE,使用 Encoder + Decoder 的 BART 替代 DALL-E 中 Decoder only 的 Transformer。

dalle-e mini infer

在推理过程中,不是生成单一的图像,而是会经过采样机制生成多个 latent code,并使用 VQ-GAN 的 Decoder 生成多个候选图像,之后再使用 CLIP 提取这些图像的 embedding 和文本 embedding,之后进行比对排序,挑选出最匹配的生成结果。

image-20240113222818173

6.5 DALL-E mini 和 DALL-E 对比

DALL-E mini 和 DALL-E 在模型、训练上都有比较大的差异,具体体现在:

  • DALL-E 使用 12B 的 GPT-3 作为 Transformer,而 mini 使用的是 0.4B 的 BART,小 27 倍。
  • mini 中使用预训练的 VQ-GAN、BART 的 Encoder 以及 CLIP,而 DALL-E 从头开始训练,mini 训练代价更小。
  • DALL-E 使用 1024 个图像 Token,词表更小为 8192,而 mini 使用 256 个图像 Token,词表大小为 16384。
  • DALL-E 支持最多 256 个文本 Token,对应词表为 16,384,mini 支持最多 1024 文本 Token,词表大小为 50,264。
  • mini 使用的 BART 是 Encoder + Decoder 的,因此文本是使用双向编码,也就是每个文本 Token 都能看到所有文本 Token,而 DALL-E 是 Decoder only 的 GPT-3,文本 Token 只能看到之前的 Token。
  • DALL-E 使用 250M 图像-文本对训练,而 mini 只使用了 15M。

7 CLIP+VQ-GAN(VQGAN-CLIP)

Katherine 等人将 VQ-GAN 和 OpenAI 发布的 CLIP 模型结合起来,利用 CLIP 的图文对齐能力来赋予 VQ-GAN 文本引导生成的能力。其最大的优势是不需要额外的预训练,也不需要对 CLIP 和 VQ-GAN 进行微调,只需在推理阶段执行少量的迭代即可实现。

image-20240113223222191

如上图所示:使用初始图像通过 VQ-GAN 生成一个图像,然后使用 CLIP 对生成图像和 Target Text 提取 embedding,然后计算相似性,并将其误差作为反馈对隐空间的 Z-vector 进行迭代更新,直到生成图像和 Target Text 对应的 embedding 很相似为止。

参考文献

  1. https://arxiv.org/abs/1312.6114
  2. https://arxiv.org/abs/1711.00937
  3. https://arxiv.org/abs/1606.05328
  4. https://arxiv.org/abs/1906.00446
  5. https://arxiv.org/abs/2012.09841
  6. https://github.com/CompVis/taming-transformers
  7. https://arxiv.org/abs/2102.12092
  8. https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained--Vmlldzo4NjIxODA
  9. https://github.com/borisdayma/dalle-mini
  10. https://arxiv.org/abs/2204.08583
  11. https://python.plainenglish.io/variational-autoencoder-1eb543f5f055
  12. https://ljvmiranda921.github.io/notebook/2021/08/08/clip-vqgan/

补充

Gumbel-softmax

Gumbel softmax允许模型中有从离散的分布中采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数。 Gumbel max trick与softmax的组合。

Gumbel-Max Trick也是使用了重参数技巧把采样过程分成了确定性的部分和随机性的部分,我们会计算所有类别的log分布概率(确定性的部分),然后加上一些噪音(随机性的部分),上面的例子中,噪音是标准高斯分布,而这里噪音是标准gumbel分布。在我们把采样过程的确定性部分和随机性部分结合起来之后,我们在此基础上再用一个argmax来找到具有最大概率的类别。自此可见,Gumbel-Max Trick由使用gumbel分布的Re-parameterization Trick和argmax组成而成,正如它的名字一样。
用公式表示的话就是:$z=argmax_i(log(\pi_i)+g_i)$ 其中 $g_i=-log(-log(u_i)),u_i\in U(0,1)$

这一项就是从 gumbel 分布采样得到的噪声,目的是使得 $z$ 的返回结果不固定,它是标准gumbel分布的CDF的逆函数

那为什么随机部分要用gumbel分布而不是常见的高斯分布呢?这是因为gumbel分布是专门用来建模从其他分布(比如高斯分布)采样出来的极值形成的分布,而我们这里“使用argmax挑出概率最大的那个类别索引 $Z$就属于取极值的操作,所以它属于极值分布