U-Net是生成式扩散模型的核心。它的输入有三个:(1)带噪声的图片 (2)时间标签 (3)其他条件变量。经过层层运算,得到一个噪声输出。该噪声输出可用于给图片去噪。
这里推荐一个diffusion实现手写数字的源代码https://github.com/TeaPearce/Conditional_Diffusion_MNIST,适合新手入门。本文主要讲解其中U-Net的工作过程。
扩散模型中的U-net结构如下图所示,1X28X28表示通道数为1,长宽为28的图片。在实际训练中不是一个三阶张量而是一个四阶张量128X1X28X28,其中128表示批处理数,即128张图片同时在GPU上完成一次训练迭代。
整个计算流程如下:输入图片(A)被提取出128张特征图(B),经过第一次下采样图像缩小一半(C),经过第二次下采样图像进一步缩小为一半(D),经过平均池化得到一个向量(E),这个向量包含了图片中的所有必要特征信息。至此,输入图片已被编码。除了图片以外,时间标签、其他条件变量也可使用全链接网络进行编码,得到两个向量(F和G),为了确保后续上采样顺利,E、F、G的长度应当相同。接下来,将E、F、G合并为一个更长的向量H。H经过上采用不断恢复出I、J、K直到L。L即为最终期望输出的噪声图。用这个噪声图即可实现对图片的去噪。
接下来以一个batch size=128来说明上述过程。