?
论文链接:https://arxiv.org/abs/2308.08339
代码链接:https://github.com/AAleka/retree
在视网膜图像合成方面表现优于传统的生成对抗网络(GAN)。
模型 分为两阶段。
两个DDPM模型和ta们在生成视网膜图像过程中的作用:
每个模型都包含以下部分:
在每个扩散步骤t,模型都会预测给定扩散步骤t-1的输出。
每个DDPM模型的输出都用作下一步的输入。
在第一个模型中,血管树图像与带噪声的输入进行了连接(Concatenation),在第二个模型中,彩色视网膜图像与带噪声的输入进行了连接。
整个流程展示了一个完整的数据生成和处理流程:
?
去噪扩散概率模型(DDPM)的架构,由四个主要部分组成:下采样块、ViT编码器块、上采样块和瓶颈块。
图中还展示了各个部分的输入和输出关系,以及其中使用的矩阵乘法(×)和加法(+)操作。
例如,在ViT编码器块中,q(查询矩阵)、k(键矩阵)和v(值矩阵)参与了矩阵乘法和Softmax操作以实现自注意力机制。
在整个模型中,t表示扩散时间步,它在模型的不同阶段被用来调整输入特征和生成的时间嵌入。
DDPM的损失函数:DDPM使用了两种损失函数的组合,即L1损失(或称平均绝对误差MAE)和均方误差(MSE)。
这两种损失函数用于计算预测噪声和实际应用噪声之间的差异
GenLoss ( ? θ ( x t , t ) , z ) = 1 N ∑ i = 1 N ( ∣ ? θ ? z ∣ + ( ? θ ? z ) 2 ) \text{GenLoss}(\epsilon_{\theta}(x_t, t), z) = \frac{1}{N}\sum_{i=1}^{N}(|\epsilon_{\theta} - z| + (\epsilon_{\theta} - z)^2) GenLoss(?θ?(xt?,t),z)=N1?∑i=1N?(∣?θ??z∣+(?θ??z)2)
N是图像像素的总数, ? θ \epsilon_{\theta} ?θ? 是预测的噪声,z是实际应用的噪声。
超分辨率模型的损失函数:超分辨率模型使用了五种损失函数的组合,包括L1损失、结构相似性指数(SSIM)、对抗损失、二元交叉熵(BCE)和感知损失。
SSIM用于计算两幅图像之间的结构相似性:
SSIM ( X , Y ) = ( 2 μ X μ Y + C 1 ) ( 2 σ X Y + C 2 ) ( μ X 2 + μ Y 2 + C 1 ) ( σ X 2 + σ Y 2 + C 2 ) \text{SSIM}(X, Y) = \frac{(2\mu_X\mu_Y + C1)(2\sigma_{XY} + C2)}{(\mu_X^2 + \mu_Y^2 + C1)(\sigma_X^2 + \sigma_Y^2 + C2)} SSIM(X,Y)=(μX2?+μY2?+C1)(σX2?+σY2?+C2)(2μX?μY?+C1)(2σXY?+C2)?
μ \mu μ 是均值, σ \sigma σ 是方差,C1和C2是稳定除法的变量。
SSIMLoss ( X , Y ) = 1 ? SSIM ( X , Y ) 2 \text{SSIMLoss}(X, Y) = 1 - \frac{\text{SSIM}(X, Y)}{2} SSIMLoss(X,Y)=1?2SSIM(X,Y)?
对抗损失:对抗损失使用鉴别器模型的输出计算,鉴别器模型负责将输入图像分类为真实或伪造。
鉴别器模型的损失函数定义如下:
AdvLoss ( Y , X ) = ? E Y [ log ? ( 1 ? D ( Y , X ) ) ] ? E X [ log ? ( D ( Y , X ) ) ] \text{AdvLoss}(Y, X) = -\mathbb{E}_Y[\log(1 - D(Y, X))] - \mathbb{E}_X[\log(D(Y, X))] AdvLoss(Y,X)=?EY?[log(1?D(Y,X))]?EX?[log(D(Y,X))]
D ( x r , x f ) D(x_r, x_f) D(xr?,xf?) 是鉴别器模型给出的预测。
感知损失:感知损失使用VGG19网络测量激活特征之间的距离,并使用MSE损失函数最小化该距离。
分割模型的BCE损失:分割模型使用二元交叉熵损失训练。
BCELoss = ? 1 N ∑ i = 1 N ( x i t ? 1 log ? ( x θ i t ? 1 ) + ( 1 ? x i t ? 1 ) log ? ( 1 ? x θ i t ? 1 ) ) \text{BCELoss} = -\frac{1}{N}\sum_{i=1}^{N}(x_{it-1}\log(x_{\theta_{it-1}}) + (1 - x_{it-1})\log(1 - x_{\theta_{it-1}})) BCELoss=?N1?∑i=1N?(xit?1?log(xθit?1??)+(1?xit?1?)log(1?xθit?1??))
N代表图像中的像素数。