This blog post is about StyleGAN2, from the paper Analyzing and Improving the Image Quality of StyleGAN. We will build a clean, simple, and readable implementation of it in PyTorch, staying as close to the original paper as possible.
If you haven't read the StyleGAN2 paper, or don't know how it works and want to understand it, I highly recommend at least skimming the original paper to grasp its main ideas.
The dataset we use in this blog comes from Kaggle and contains 16,240 women's tops at a resolution of 256×192.
As always, let's start by loading all the dependencies we need.
We first import torch, since we will use PyTorch, and from it we import nn, which will help us create and train the networks. We also import optim, a package implementing various optimization algorithms (e.g. SGD, Adam). From torchvision we import datasets and transforms to prepare the data and apply some transformations.
We import the functional module of torch.nn as F, DataLoader from torch.utils.data to create mini-batches, save_image from torchvision.utils to save some fake samples, log2 and sqrt from math, NumPy for linear algebra, os to interact with the operating system, tqdm to show progress bars, and finally matplotlib.pyplot to plot some images.
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2, sqrt
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
DATASET = "Women clothes"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 300
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
LOG_RESOLUTION = 7 #for 128*128
Z_DIM = 256
W_DIM = 256
LAMBDA_GP = 10
Now let's create a get_loader function that resizes the images to the target resolution, applies a random horizontal flip, normalizes them to [-1, 1], and returns a DataLoader that serves shuffled mini-batches:
def get_loader():
    transform = transforms.Compose(
        [
            transforms.Resize((2 ** LOG_RESOLUTION, 2 ** LOG_RESOLUTION)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5, 0.5, 0.5],
                [0.5, 0.5, 0.5],
            ),
        ]
    )
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    return loader
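As a quick sanity check (a minimal sketch, assuming the "Women clothes" folder exists and contains at least one subfolder of images, as ImageFolder requires), we can pull one batch and verify its shape and value range:

loader = get_loader()
images, _ = next(iter(loader))
print(images.shape)                              # torch.Size([32, 3, 128, 128])
print(images.min().item(), images.max().item())  # roughly -1.0 to 1.0 after Normalize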
Now let's implement the StyleGAN2 networks with the key attributes from the paper. We will try to make the implementation compact while keeping it readable and understandable. We start with the mapping network.
Let's create the MappingNetwork class, which inherits from nn.Module.
In the init part, we take z_dim and w_dim and define the mapping network as 8 EqualizedLinear layers, a class we will implement later that equalizes the learning rate, with ReLU as the activation function.
In the forward part, we first normalize z with pixel norm and then run it through the mapping network.
class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        # note: the hidden layers take z_dim as input, which works because
        # z_dim == w_dim (both 256 here)
        self.mapping = nn.Sequential(
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim)
        )

    def forward(self, x):
        x = x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)  # PixelNorm
        return self.mapping(x)
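A minimal shape check (hypothetical values; it assumes EqualizedLinear, defined later in this post, is already in scope):

mapping = MappingNetwork(Z_DIM, W_DIM).to(DEVICE)
z = torch.randn(4, Z_DIM, device=DEVICE)  # a batch of 4 latent vectors
print(mapping(z).shape)                   # torch.Size([4, 256])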
In the figure below, you can see the generator architecture. It starts with a learned initial constant. Then it has a series of blocks, where the feature-map resolution is doubled at each block. Each block outputs an RGB image, and these are upsampled and summed to get the final RGB image.
toRGB also has a style modulation, which is not shown in the figure to keep it simple.
To keep the code as concise as possible, the generator implementation uses three classes that we define later (StyleBlock, ToRGB, and GeneratorBlock).
class Generator(nn.Module):
    def __init__(self, log_resolution, W_DIM, n_features=32, max_features=256):
        super().__init__()
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]
        self.n_blocks = len(features)
        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))
        self.style_block = StyleBlock(W_DIM, features[0], features[0])
        self.to_rgb = ToRGB(W_DIM, features[0])
        blocks = [GeneratorBlock(W_DIM, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, w, input_noise):
        batch_size = w.shape[1]
        x = self.initial_constant.expand(batch_size, -1, -1, -1)
        x = self.style_block(x, w[0], input_noise[0][1])
        rgb = self.to_rgb(x, w[0])
        for i in range(1, self.n_blocks):
            x = F.interpolate(x, scale_factor=2, mode="bilinear")
            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])
            rgb = F.interpolate(rgb, scale_factor=2, mode="bilinear") + rgb_new
        return torch.tanh(rgb)
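Here is a sketch of a forward pass through the generator (it assumes the helper functions get_w and get_noise, defined later, and the remaining classes are in scope):

gen = Generator(LOG_RESOLUTION, W_DIM).to(DEVICE)
w = get_w(2)          # shape (LOG_RESOLUTION, 2, W_DIM): one style vector per block
noise = get_noise(2)  # one (n1, n2) noise pair per resolution
fake = gen(w, noise)
print(fake.shape)     # torch.Size([2, 3, 128, 128]) for LOG_RESOLUTION = 7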
In the figure below, you can see the generator block architecture, which consists of two style blocks (3×3 convolutions with style modulation) and an RGB output.
class GeneratorBlock(nn.Module):
    def __init__(self, W_DIM, in_features, out_features):
        super().__init__()
        self.style_block1 = StyleBlock(W_DIM, in_features, out_features)
        self.style_block2 = StyleBlock(W_DIM, out_features, out_features)
        self.to_rgb = ToRGB(W_DIM, out_features)

    def forward(self, x, w, noise):
        x = self.style_block1(x, w, noise[0])
        x = self.style_block2(x, w, noise[1])
        rgb = self.to_rgb(x, w)
        return x, rgb
class StyleBlock(nn.Module):
    def __init__(self, W_DIM, in_features, out_features):
        super().__init__()
        self.to_style = EqualizedLinear(W_DIM, in_features, bias=1.0)
        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)
        self.scale_noise = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x, w, noise):
        s = self.to_style(w)
        x = self.conv(x, s)
        if noise is not None:
            x = x + self.scale_noise[None, :, None, None] * noise
        return self.activation(x + self.bias[None, :, None, None])
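A hypothetical shape check for a single style block (it assumes EqualizedLinear and Conv2dWeightModulate, defined below, are in scope); the spatial size is preserved while the channel count changes:

sb = StyleBlock(W_DIM, in_features=64, out_features=32)
x = torch.randn(2, 64, 8, 8)
w = torch.randn(2, W_DIM)
noise = torch.randn(2, 1, 8, 8)
print(sb(x, w, noise).shape)  # torch.Size([2, 32, 8, 8])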
class ToRGB(nn.Module):
    def __init__(self, W_DIM, features):
        super().__init__()
        self.to_style = EqualizedLinear(W_DIM, features, bias=1.0)
        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(3))
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x, w):
        style = self.to_style(w)
        x = self.conv(x, style)
        return self.activation(x + self.bias[None, :, None, None])
This class scales the convolution weights by the style vector and, when demodulation is enabled, normalizes them back.
class Conv2dWeightModulate(nn.Module):
    def __init__(self, in_features, out_features, kernel_size,
                 demodulate=True, eps=1e-8):
        super().__init__()
        self.out_features = out_features
        self.demodulate = demodulate
        self.padding = (kernel_size - 1) // 2
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        self.eps = eps

    def forward(self, x, s):
        b, _, h, w = x.shape
        s = s[:, None, :, None, None]
        weights = self.weight()[None, :, :, :, :]
        weights = weights * s  # modulate: scale each input channel by its style
        if self.demodulate:
            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * sigma_inv  # demodulate: normalize back to unit norm
        # use a grouped convolution to apply a different kernel to each sample
        x = x.reshape(1, -1, h, w)
        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.out_features, *ws)
        x = F.conv2d(x, weights, padding=self.padding, groups=b)
        return x.reshape(-1, self.out_features, h, w)
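A small sketch of the modulated convolution on its own (hypothetical sizes; it needs EqualizedWeight from further down): one style scale per input channel, per sample, with the spatial size preserved by the padding:

mod_conv = Conv2dWeightModulate(in_features=8, out_features=16, kernel_size=3)
x = torch.randn(2, 8, 32, 32)  # (batch, channels, height, width)
s = torch.randn(2, 8)          # per-sample, per-input-channel style scales
print(mod_conv(x, s).shape)    # torch.Size([2, 16, 32, 32])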
In the figure below, you can see the discriminator architecture. It first transforms an image of resolution $2^{LOG\_RESOLUTION} \times 2^{LOG\_RESOLUTION}$ into a feature map of the same resolution, then runs it through a series of blocks with residual connections. The resolution is down-sampled by a factor of 2 at each block while the number of features is doubled.
class Discriminator(nn.Module):
    def __init__(self, log_resolution, n_features=64, max_features=256):
        super().__init__()
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
        self.from_rgb = nn.Sequential(
            EqualizedConv2d(3, n_features, 1),
            nn.LeakyReLU(0.2, True),
        )
        n_blocks = len(features) - 1
        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
        self.blocks = nn.Sequential(*blocks)
        final_features = features[-1] + 1
        self.conv = EqualizedConv2d(final_features, final_features, 3)
        self.final = EqualizedLinear(2 * 2 * final_features, 1)

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x):
        x = self.from_rgb(x)
        x = self.blocks(x)
        x = self.minibatch_std(x)
        x = self.conv(x)
        x = x.reshape(x.shape[0], -1)
        return self.final(x)
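A quick sketch to confirm the critic maps a batch of images to a single score per image (assuming the classes above are defined):

critic = Discriminator(LOG_RESOLUTION).to(DEVICE)
imgs = torch.randn(4, 3, 2 ** LOG_RESOLUTION, 2 ** LOG_RESOLUTION, device=DEVICE)
print(critic(imgs).shape)  # torch.Size([4, 1])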
In the figure below, you can see the discriminator block architecture, which consists of two 3×3 convolutions with a residual connection.
class DiscriminatorBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.residual = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),  # down-sampling using avg pool
            EqualizedConv2d(in_features, out_features, kernel_size=1),
        )
        self.block = nn.Sequential(
            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
        )
        self.down_sample = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # down-sampling using avg pool
        self.scale = 1 / sqrt(2)

    def forward(self, x):
        residual = self.residual(x)
        x = self.block(x)
        x = self.down_sample(x)
        return (x + residual) * self.scale
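And a hypothetical shape check for one block: the resolution is halved while the channel count can grow:

block = DiscriminatorBlock(64, 128)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)  # torch.Size([2, 128, 16, 16])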
Now it is time to implement EqualizedLinear, which we used in almost every class above to equalize the learning rate of a linear layer.
class EqualizedLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=0.):
        super().__init__()
        self.weight = EqualizedWeight([out_features, in_features])
        self.bias = nn.Parameter(torch.ones(out_features) * bias)

    def forward(self, x: torch.Tensor):
        return F.linear(x, self.weight(), bias=self.bias)
Now let's implement EqualizedConv2d, which we used earlier to equalize the learning rate of a convolutional layer.
class EqualizedConv2d(nn.Module):
    def __init__(self, in_features, out_features,
                 kernel_size, padding=0):
        super().__init__()
        self.padding = padding
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        self.bias = nn.Parameter(torch.ones(out_features))

    def forward(self, x: torch.Tensor):
        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)
Now let's implement the EqualizedWeight class, used in both the learning-rate-equalized linear layer and the learning-rate-equalized 2D convolution layer.
This is based on the equalized learning rate introduced in the ProGAN paper. Instead of initializing weights at $\mathcal{N}(0, c)$, they initialize them at $\mathcal{N}(0, 1)$ and then multiply them by c when they are used.
class EqualizedWeight(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.c = 1 / sqrt(np.prod(shape[1:]))
        self.weight = nn.Parameter(torch.randn(shape))

    def forward(self):
        return self.weight * self.c
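As a worked example (a sketch with hypothetical layer sizes): for a 3×3 convolution with 256 input channels, the fan-in is 256 * 3 * 3 = 2304, so c = 1 / sqrt(2304) ≈ 0.0208. The stored weights stay N(0, 1) and are rescaled by c on every forward pass:

eq_w = EqualizedWeight([512, 256, 3, 3])
print(round(eq_w.c, 4))     # 0.0208
print(eq_w().std().item())  # close to c, since the raw weights are N(0, 1)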
The path length penalty encourages a fixed-size step in w to result in a fixed-magnitude change in the image. The penalty is $\left(\lVert \mathbf{J}_w^{\top} \mathbf{y} \rVert_2 - a\right)^2$, where $\mathbf{J}_w = \frac{\partial g}{\partial w}$ is the Jacobian of the generator with respect to w, w is sampled from the mapping network, $\mathbf{y}$ are images with noise $\mathcal{N}(0, \mathbf{I})$, and a is an exponential moving average maintained as training progresses.
class PathLengthPenalty(nn.Module):
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)
        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)

    def forward(self, w, x):
        device = x.device
        image_size = x.shape[2] * x.shape[3]
        y = torch.randn(x.shape, device=device)
        output = (x * y).sum() / sqrt(image_size)
        gradients, *_ = torch.autograd.grad(outputs=output,
                                            inputs=w,
                                            grad_outputs=torch.ones(output.shape, device=device),
                                            create_graph=True)
        norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()
        if self.steps > 0:
            a = self.exp_sum_a / (1 - self.beta ** self.steps)
            loss = torch.mean((norm - a) ** 2)
        else:
            loss = norm.new_tensor(0)
        mean = norm.mean().detach()
        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)
        self.steps.add_(1.)
        return loss
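A sketch of how the penalty is used (mirroring the training loop below; it assumes gen, get_w, and get_noise are defined): the returned loss is zero on the first call, while the moving average a is still being initialized.

plp = PathLengthPenalty(beta=0.99).to(DEVICE)
w = get_w(4)                  # w keeps its graph through the mapping network
fake = gen(w, get_noise(4))
loss = plp(w, fake)           # afterwards: mean((||J_w^T y|| - a)^2)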
In the following code snippet, you can find the gradient_penalty function for the WGAN-GP loss.
def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
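A hypothetical usage sketch (random tensors stand in for a real batch and generator output; critic is initialized further down):

real = torch.randn(8, 3, 2 ** LOG_RESOLUTION, 2 ** LOG_RESOLUTION, device=DEVICE)
fake = torch.randn(8, 3, 2 ** LOG_RESOLUTION, 2 ** LOG_RESOLUTION, device=DEVICE)
gp = gradient_penalty(critic, real, fake, device=DEVICE)
loss_term = LAMBDA_GP * gp  # added to the critic loss in train_fn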
This function samples Z randomly and gets W from the mapping network.
def get_w(batch_size):
    z = torch.randn(batch_size, Z_DIM).to(DEVICE)
    w = mapping_network(z)
    # replicate w so that every generator block receives the same style vector
    return w[None, :, :].expand(LOG_RESOLUTION, -1, -1)
This function generates noise for each generator block.
def get_noise(batch_size):
    noise = []
    resolution = 4
    for i in range(LOG_RESOLUTION):
        if i == 0:
            n1 = None
        else:
            n1 = torch.randn(batch_size, 1, resolution, resolution, device=DEVICE)
        n2 = torch.randn(batch_size, 1, resolution, resolution, device=DEVICE)
        noise.append((n1, n2))
        resolution *= 2
    return noise
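Inspecting the returned structure for LOG_RESOLUTION = 7 (a minimal sketch): the first entry has n1 = None because the initial 4×4 block injects noise only once, after its single style block:

for n1, n2 in get_noise(1):
    print(None if n1 is None else tuple(n1.shape), tuple(n2.shape))
# None (1, 1, 4, 4), then matching pairs at 8, 16, ..., up to 256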
In the following code snippet, you can find the generate_examples function, which takes the generator gen, the epoch number, and n=100. The goal of this function is to generate n fake images and save them as results for that epoch.
def generate_examples(gen, epoch, n=100):
    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            w = get_w(1)
            noise = get_noise(1)
            img = gen(w, noise)
            if not os.path.exists(f'saved_examples/epoch{epoch}'):
                os.makedirs(f'saved_examples/epoch{epoch}')
            save_image(img * 0.5 + 0.5, f"saved_examples/epoch{epoch}/img_{i}.png")
    gen.train()
In this section, we will train StyleGAN2.
Let's start by creating the training function, which takes the discriminator/critic, the generator gen, the path_length_penalty that we apply every 16 mini-batches, the loader, and the optimizers for the networks. We begin by looping over all the mini-batches created by the DataLoader, taking only the images since we don't need the labels.
Then we set up training for the discriminator/critic, where we want to maximize E(critic(real)) - E(critic(fake)). This quantity measures how well the critic can distinguish real from fake images.
After that, we set up training for the generator and mapping network, where we want to maximize E(critic(fake)), and we add the path length penalty to this loss every 16 mini-batches.
Finally, we update the progress bar.
def train_fn(
    critic,
    gen,
    path_length_penalty,
    loader,
    opt_critic,
    opt_gen,
    opt_mapping_network,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        w = get_w(cur_batch_size)
        noise = get_noise(cur_batch_size)
        with torch.cuda.amp.autocast():
            fake = gen(w, noise)
            critic_fake = critic(fake.detach())
            critic_real = critic(real)
            gp = gradient_penalty(critic, real, fake, device=DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake)
        loss_gen = -torch.mean(gen_fake)

        if batch_idx % 16 == 0:
            plp = path_length_penalty(w, fake)
            if not torch.isnan(plp):
                loss_gen = loss_gen + plp

        mapping_network.zero_grad()
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        opt_mapping_network.step()

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )
Now let's initialize the loader, the networks, and the optimizers, and put the networks in training mode.
loader = get_loader()
gen = Generator(LOG_RESOLUTION, W_DIM).to(DEVICE)
critic = Discriminator(LOG_RESOLUTION).to(DEVICE)
mapping_network = MappingNetwork(Z_DIM, W_DIM).to(DEVICE)
path_length_penalty = PathLengthPenalty(0.99).to(DEVICE)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_mapping_network = optim.Adam(mapping_network.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
gen.train()
critic.train()
mapping_network.train()
Now let's train the networks with the training loop, saving some fake samples every 50 epochs.
for epoch in range(EPOCHS):
    train_fn(
        critic,
        gen,
        path_length_penalty,
        loader,
        opt_critic,
        opt_gen,
        opt_mapping_network,
    )
    if epoch % 50 == 0:
        generate_examples(gen, epoch)
In this article, we made a clean, simple, and readable implementation of this large project, StyleGAN2, from scratch using PyTorch, trying to replicate the original paper as closely as possible.