Scalable Diffusion Models with Transformers
公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
目录
我们探索一种基于 Transformer 架构的新型扩散模型。我们训练图像的潜在扩散模型,将通常使用的 U-Net 骨干替换为在潜在 patch 上操作的 Transformer。我们通过前向传播复杂度的角度分析我们的扩散 Transformer(Diffusion Transformers,DiT)的可扩展性,该复杂度由 Gflops 测量。我们发现,具有更高 Gflops 的 DiTs(通过增加 Transformer 的深度/宽度或增加输入标记的数量)始终具有较低的 FID。除了具有良好的可扩展性特性外,我们最大的 DiT-XL/2 模型在类条件ImageNet 512x512 和 256x256 基准上优于所有先前的扩散模型,后者实现了 2.27 的最先进 FID。
DDPM,无分类器引导,LDM
架构复杂性。在图像生成领域评估架构复杂性时,使用参数计数是一种相当常见的做法。总体上,参数计数可能是评估图像模型复杂性的不良代理,因为它们未考虑例如影响性能的图像分辨率等因素 [44, 45]。相反,本文中对模型复杂性的分析主要通过理论上的 Gflops 视角进行。这使我们与架构设计文献保持一致,其中 Gflops 被广泛用于衡量复杂性。在实践中,黄金复杂度度量(golden complexity metric)仍然存在争议,因为它经常取决于特定的应用场景。Nichol 和 Dhariwal 的开创性工作改进扩散模型 [9, 36] 与我们最相关——其中,他们分析了 U-Net 架构类的可扩展性和Gflop 特性。在本文中,我们专注于 Transformer 类。?
我们介绍了扩散 Transformer(Diffusion Transformers,简称DiTs),这是一种新的扩散模型架构。我们的目标是尽可能忠实于标准的 Transformer 架构,以保留其扩展性质。由于我们的重点是训练图像的 DDPM(Diffusion-Probabilistic Models,扩散概率模型),特别是图像的空间表示,DiT 基于 Vision Transformer(ViT)架构,该架构操作于图像的 patch 序列 [10]。DiT 保留了 ViTs 的许多最佳实践。图 3 显示了完整的 DiT 架构概述。在本节中,我们描述了 DiT 的前向传播,以及 DiT 类的设计空间组件。
Patchify。DiT 的输入是一个空间表示 z(对于 256x256x3 的图像,z 的形状为 32x32x4)。DiT 的第一层是 “patchify”,它通过线性嵌入输入中的每个 patch,将空间输入转换为 T 个维度为 d 的标记序列。在 patchify 之后,我们对所有输入标记应用标准 ViT 的基于频率的位置嵌入(正弦-余弦版本)。由 patchify 创建的标记数 T 由 patch 大小超参数 p 确定。如图 4 所示,减半 p 将使 T 成倍增加,从而至少使 transformer 的总 Gflops 成倍增加。尽管这对 Gflops 有重要影响,但请注意,更改 p 对下游参数计数没有实质性影响。
(我们将 p = 2, 4, 8 添加到 DiT 的设计空间。)
DiT 设计。在 patchify 之后,输入标记由一系列 transformer 块处理。除了带噪声的图像输入外,扩散模型有时还处理额外的条件信息,如噪声时间步 t、类标签 c、自然语言等。我们探索了四个变体的 transformer 块,这些块以不同方式处理条件输入。这些设计在标准 ViT 块设计中引入了小但重要的修改。所有块的设计都显示在图 3 中。
(我们在 DiT 设计空间中包括了上下文、交叉注意力、自适应层归一化和 adaLN-Zero 块。)
模型大小。我们应用 N 个 DiT 块的序列,每个块在隐藏维度大小为 d。遵循 ViT,我们使用共同缩放 N、d 和注意力头的标准 transformer配置 [10, 63]。具体而言,我们使用四个配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。它们涵盖了各种模型大小和 Gflops 分配,从 0.3 到 118.6 Gflops,使我们能够评估缩放性能。表 1 提供了配置的详细信息。
(我们在 DiT 设计空间中添加了 B、S、L 和 XL 配置。)
Transformer 解码器。在最后的 DiT 块之后,我们需要将图像标记的序列解码为输出噪声预测和输出对角协方差预测。这两个输出的形状与原始的空间输入相等。我们使用标准的线性解码器来完成这个任务;我们将最终的层归一化(如果使用 adaLN,则是自适应的)应用于每个标记,并将每个标记线性解码为一个 p x p x 2C 的张量,其中 C 是输入到 DiT 的空间输入中的通道数。最后,我们重新排列解码的标记以获得预测的噪声和协方差。
(我们探索的完整 DiT 设计空间包括补丁大小、transformer 块架构和模型大小。)