本文介绍的是当今最好的 GAN 之一,来自论文《A Style-Based Generator Architecture for Generative Adversarial Networks》的 StyleGAN ,我们将使用 PyTorch 对其进行干净、简单且可读的实现,并尝试尽可能接近原始论文。
如果您没有阅读过 StyleGAN1 论文,或者不知道它是如何工作的,但您想了解它,我强烈建议您参考这篇博文。
我们在本博文中使用的数据集是来自 Kaggle 的数据集,其中包含 16240 件女性上衣,分辨率为 256*192。
我们首先导入 torch,然后从那里导入 nn. 这将帮助我们创建和训练网络,并让我们导入 optim,一个实现各种优化算法(例如 sgd、adam 等)的包。我们从 torchvision 导入数据集和转换来准备数据并应用一些转换。
我们将从 torch.nn 导入 F 函数以使用插值对图像进行上采样,从 torch.utils.data 导入 DataLoader 以创建小批量大小,从 torchvision.utils 导入 save_image 以保存一些假样本,并使用 log2 形式的数学表示,因为我们需要2 的幂的逆表示,用于根据输出分辨率实现自适应小批量大小,NumPy 用于线性代数,os 用于与操作系统交互,tqdm 用于显示进度条,最后 matplotlib.pyplot 用于显示结果并与真值进行比较。
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
DATASET = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #The authors start from 8x8 images instead of 4x4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-3
BATCH_SIZES = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG = 3
Z_DIM = 256
W_DIM = 256
IN_CHANNELS = 256
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
现在让我们创建一个函数get_loader来:
def get_loader(image_size):
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(
[0.5 for _ in range(CHANNELS_IMG)],
[0.5 for _ in range(CHANNELS_IMG)],
),
]
)
batch_size = BATCH_SIZES[int(log2(image_size / 4))]
dataset = datasets.ImageFolder(root=DATASET, transform=transform)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
)
return loader, dataset
现在让我们使用论文中的关键属性来实现 StyleGAN1 生成器和鉴别器(ProGAN 和 StyleGAN1 具有相同的鉴别器架构)。我们将尽力使实现紧凑,但同时保持其可读性和可理解性。具体来说,有以下几个要点:
在本教程中,我们将仅使用 StyleGAN1 生成图像,而不实现风格混合和随机变化,但这应该不难。
让我们定义一个名为 Factors 的变量,其中包含与IN_CHANNELS 相乘的数字,以获得每个图像分辨率中我们想要的通道数。
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
噪声映射网络采用 Z 并将其放入由某些激活分隔的八个完全连接的层。并且不要忘记像作者在 ProGAN 中所做的那样均衡学习率(ProGAN 和 StyleGan 由同一研究人员编写)。
我们首先构建一个名为 WSLinear(加权缩放线性)的类,该类将从 nn.Module 继承。
在init部分,我们发送 in_features 和 out_channels。创建一个线性层,然后我们定义一个比例,该比例等于2的平方根除以in_features,我们将当前列层的偏差复制到一个变量中,因为我们不希望线性层的偏差缩放,然后我们删除它,最后,我们初始化线性层。
在前向部分,我们发送 x,我们要做的就是将 x 与比例相乘,并在重塑后添加偏差。
class WSLinear(nn.Module):
def __init__(
self, in_features, out_features,
):
super(WSLinear, self).__init__()
self.linear = nn.Linear(in_features, out_features)
self.scale = (2 / in_features)**0.5
self.bias = self.linear.bias
self.linear.bias = None
# initialize linear layer
nn.init.normal_(self.linear.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return self.linear(x * self.scale) + self.bias
现在让我们创建 MappingNetwork 类。
class MappingNetwork(nn.Module):
def __init__(self, z_dim, w_dim):
super().__init__()
self.mapping = nn.Sequential(
PixelNorm(),
WSLinear(z_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
)
def forward(self, x):
return self.mapping(x)
现在让我们创建 AdaIN 类:
class AdaIN(nn.Module):
def __init__(self, channels, w_dim):
super().__init__()
self.instance_norm = nn.InstanceNorm2d(channels)
self.style_scale = WSLinear(w_dim, channels)
self.style_bias = WSLinear(w_dim, channels)
def forward(self, x, w):
x = self.instance_norm(x)
style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
return style_scale * x + style_bias
现在让我们创建 InjectNoise 类以将噪声注入生成器
class InjectNoise(nn.Module):
def __init__(self, channels):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))
def forward(self, x):
noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
return x + self.weight * noise
作者在 Karras 等人对 ProGAN 的官方实现的基础上构建了 StyleGAN,他们使用相同的判别器架构、自适应小批量大小、超参数等。因此,有很多类与 ProGAN 实现保持相同。
在本节中,我们将创建与我已在本博文中解释过的 ProGAN 架构保持不变的类。
在下面的代码片段中,您可以找到 WSConv2d(加权缩放卷积层)类,以用于转换层的均衡学习率。
class WSConv2d(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
):
super(WSConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
self.bias = self.conv.bias
self.conv.bias = None
# initialize conv layer
nn.init.normal_(self.conv.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
在下面的代码片段中,您可以找到 PixelNorm 类,用于在噪声映射网络之前对 Z 进行归一化。
class PixelNorm(nn.Module):
def __init__(self):
super(PixelNorm, self).__init__()
self.epsilon = 1e-8
def forward(self, x):
return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
在下面的代码片段中,您可以找到 ConvBock 类,它将帮助我们创建鉴别器。
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = WSConv2d(in_channels, out_channels)
self.conv2 = WSConv2d(out_channels, out_channels)
self.leaky = nn.LeakyReLU(0.2)
def forward(self, x):
x = self.leaky(self.conv1(x))
x = self.leaky(self.conv2(x))
return x
在下面的代码片段中,您可以发现类 Discriminatowich 与 ProGAN 中的类相同。
class Discriminator(nn.Module):
def __init__(self, in_channels, img_channels=3):
super(Discriminator, self).__init__()
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
self.leaky = nn.LeakyReLU(0.2)
# here we work back ways from factors because the discriminator
# should be mirrored from the generator. So the first prog_block and
# rgb layer we append will work for input size 1024x1024, then 512->256-> etc
for i in range(len(factors) - 1, 0, -1):
conv_in = int(in_channels * factors[i])
conv_out = int(in_channels * factors[i - 1])
self.prog_blocks.append(ConvBlock(conv_in, conv_out))
self.rgb_layers.append(
WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
)
# perhaps confusing name "initial_rgb" this is just the RGB layer for 4x4 input size
# did this to "mirror" the generator initial_rgb
self.initial_rgb = WSConv2d(
img_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.rgb_layers.append(self.initial_rgb)
self.avg_pool = nn.AvgPool2d(
kernel_size=2, stride=2
) # down sampling using avg pool
# this is the block for 4x4 input size
self.final_block = nn.Sequential(
# +1 to in_channels because we concatenate from MiniBatch std
WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
nn.LeakyReLU(0.2),
WSConv2d(
in_channels, 1, kernel_size=1, padding=0, stride=1
), # we use this instead of linear layer
)
def fade_in(self, alpha, downscaled, out):
"""Used to fade in downscaled using avg pooling and output from CNN"""
# alpha should be scalar within [0, 1], and upscale.shape == generated.shape
return alpha * out + (1 - alpha) * downscaled
def minibatch_std(self, x):
batch_statistics = (
torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
)
# we take the std for each example (across all channels, and pixels) then we repeat it
# for a single channel and concatenate it with the image. In this way the discriminator
# will get information about the variation in the batch/image
return torch.cat([x, batch_statistics], dim=1)
def forward(self, x, alpha, steps):
# where we should start in the list of prog_blocks, maybe a bit confusing but
# the last is for the 4x4. So example let's say steps=1, then we should start
# at the second to last because input_size will be 8x8. If steps==0 we just
# use the final block
cur_step = len(self.prog_blocks) - steps
# convert from rgb as initial step, this will depend on
# the image size (each will have it's on rgb layer)
out = self.leaky(self.rgb_layers[cur_step](x))
if steps == 0: # i.e, image is 4x4
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
# because prog_blocks might change the channels, for down scale we use rgb_layer
# from previous/smaller size which in our case correlates to +1 in the indexing
downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
out = self.avg_pool(self.prog_blocks[cur_step](out))
# the fade_in is done first between the downscaled and the input
# this is opposite from the generator
out = self.fade_in(alpha, downscaled, out)
for step in range(cur_step + 1, len(self.prog_blocks)):
out = self.prog_blocks[step](out)
out = self.avg_pool(out)
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
在生成器架构中,我们有一些重复的模式,所以让我们首先为其创建一个类,以使我们的代码尽可能干净,让我们将类命名为 GenBlock,它将继承自 nn.Module。
class GenBlock(nn.Module):
def __init__(self, in_channels, out_channels, w_dim):
super(GenBlock, self).__init__()
self.conv1 = WSConv2d(in_channels, out_channels)
self.conv2 = WSConv2d(out_channels, out_channels)
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.inject_noise1 = InjectNoise(out_channels)
self.inject_noise2 = InjectNoise(out_channels)
self.adain1 = AdaIN(out_channels, w_dim)
self.adain2 = AdaIN(out_channels, w_dim)
def forward(self, x, w):
x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
return x
现在我们已经拥有了创建生成器所需的一切。
class Generator(nn.Module):
def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
super(Generator, self).__init__()
self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
self.map = MappingNetwork(z_dim, w_dim)
self.initial_adain1 = AdaIN(in_channels, w_dim)
self.initial_adain2 = AdaIN(in_channels, w_dim)
self.initial_noise1 = InjectNoise(in_channels)
self.initial_noise2 = InjectNoise(in_channels)
self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.initial_rgb = WSConv2d(
in_channels, img_channels, kernel_size=1, stride=1, padding=0
)
self.prog_blocks, self.rgb_layers = (
nn.ModuleList([]),
nn.ModuleList([self.initial_rgb]),
)
for i in range(len(factors) - 1): # -1 to prevent index error because of factors[i+1]
conv_in_c = int(in_channels * factors[i])
conv_out_c = int(in_channels * factors[i + 1])
self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
self.rgb_layers.append(
WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
)
def fade_in(self, alpha, upscaled, generated):
# alpha should be scalar within [0, 1], and upscale.shape == generated.shape
return torch.tanh(alpha * generated + (1 - alpha) * upscaled)
def forward(self, noise, alpha, steps):
w = self.map(noise)
x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
x = self.initial_conv(x)
out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)
if steps == 0:
return self.initial_rgb(x)
for step in range(steps):
upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
out = self.prog_blocks[step](upscaled, w)
# The number of channels in upscale will stay the same, while
# out which has moved through prog_blocks might change. To ensure
# we can convert both to rgb we use different rgb_layers
# (steps-1) and steps for upscaled, out respectively
final_upscaled = self.rgb_layers[steps - 1](upscaled)
final_out = self.rgb_layers[steps](out)
return self.fade_in(alpha, final_upscaled, final_out)
在下面的代码片段中,您可以找到generate_examples函数,该函数采用生成器gen 、识别当前分辨率的步骤数以及数字n=100。该函数的目标是生成n 个假图像并将其保存为结果。
def generate_examples(gen, steps, n=100):
gen.eval()
alpha = 1.0
for i in range(n):
with torch.no_grad():
noise = torch.randn(1, Z_DIM).to(DEVICE)
img = gen(noise, alpha, steps)
if not os.path.exists(f'saved_examples/step{steps}'):
os.makedirs(f'saved_examples/step{steps}')
save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
gen.train()
在下面的代码片段中,您可以找到 WGAN-GP 损失的gradient_penalty 函数。
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
BATCH_SIZE, C, H, W = real.shape
beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * beta + fake.detach() * (1 - beta)
interpolated_images.requires_grad_(True)
# Calculate critic scores
mixed_scores = critic(interpolated_images, alpha, train_step)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
在本节中,我们将训练 StyleGAN
对于训练函数,我们为生成器和批评者发送批评者(即鉴别器)、生成器(生成器)、加载器、数据集、步骤、alpha 和优化器。
我们首先循环使用 DataLoader 创建的所有小批量大小,并且只获取图像,因为我们不需要标签。
然后,当我们想要最大化E(critic(real)) - E(critic(fake))时,我们为判别器\Critic 设置训练。这个方程意味着评论家可以区分真实和虚假图像的程度。
之后,当我们想要最大化E(critic(fake)) 时,我们为生成器设置训练。
最后,我们更新循环和 fade_in 的 alpha 值并确保它在 0 和 1 之间,然后返回它。
def train_fn(
critic,
gen,
loader,
dataset,
step,
alpha,
opt_critic,
opt_gen,
):
loop = tqdm(loader, leave=True)
for batch_idx, (real, _) in enumerate(loop):
real = real.to(DEVICE)
cur_batch_size = real.shape[0]
noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step)
critic_fake = critic(fake.detach(), alpha, step)
gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ LAMBDA_GP * gp
+ (0.001 * torch.mean(critic_real ** 2))
)
critic.zero_grad()
loss_critic.backward()
opt_critic.step()
gen_fake = critic(fake, alpha, step)
loss_gen = -torch.mean(gen_fake)
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
# Update alpha and ensure less than 1
alpha += cur_batch_size / (
(PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
)
alpha = min(alpha, 1)
loop.set_postfix(
gp=gp.item(),
loss_critic=loss_critic.item(),
)
return alpha
现在我们已经拥有了一切,让我们将它们放在一起来训练我们的 StyleGAN。
我们首先初始化生成器、判别器/批评器和优化器,然后将生成器和批评器转换为训练模式,然后循环 PROGRESSIVE_EPOCHS,在每个循环中,我们调用训练函数的纪元数,然后生成一些伪造图像并使用generate_examples函数保存它们,最后,我们进入下一个图像分辨率。
gen = Generator(
Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# initialize optimizers
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
{"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)
gen.train()
critic.train()
# start at step that corresponds to img size that we set in config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step)
print(f"Current image size: {4 * 2 ** step}")
for epoch in range(num_epochs):
print(f"Epoch [{epoch+1}/{num_epochs}]")
alpha = train_fn(
critic,
gen,
loader,
dataset,
step,
alpha,
opt_critic,
opt_gen
)
generate_examples(gen, step)
step += 1 # progress to the next img size
希望您能够遵循所有步骤,并充分了解如何以正确的方式实施 StyleGAN。现在让我们看看在分辨率为 128*x 128 的数据集中训练该模型后获得的结果。