UNet是一种用于图像分割的神经网络,由于这个算法前后两个部分在处理上比较对称,类似一个U形,如下图所示,故称之为Unet,论文链接:U-Net: Convolutional Networks for Biomedical Image Segmentation,全文仅8页。
从此图可以看出,左边的基础操作是两次 3 × 3 3\times3 3×3卷积后池化,连续4次,图像从 572 × 572 572\times572 572×572变成 32 × 32 32\times32 32×32。右侧则调转过来,以两次 3 × 3 3\times3 3×3卷积核一个 2 × 2 2\times2 2×2上采样卷积作为一组,再来四次,最后恢复成 388 × 388 388\times388 388×388的图像。
整理一下上图,其计算顺序依次是
由于两次 3 × 3 3\times3 3×3卷积累计出现多次,故而先将其封装成类,便于后续调用
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
def __init__(self, inSize, outSize):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(inSize, outSize, kernel_size=3, padding=1),
nn.BatchNorm2d(outSize),
nn.ReLU(inplace=True),
nn.Conv2d(outSize, outSize, kernel_size=3, padding=1),
nn.BatchNorm2d(outSize),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
然后分别实现其降采样、上采样以及最终的输出过程,其中降采样没什么好说的,就是两次卷积一次池化,最终输出的 1 × 1 1\times1 1×1卷积当然就更简单了,二者一并实现如下
class Down(nn.Module):
def __init__(self, inSize, outSize):
super().__init__()
self.conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(inSize, outSize))
def forward(self, x):
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, inSize, outSize):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(inSize, outSize, 1)
def forward(self, x):
return self.conv(x)
上采样过程相对来说复杂一点,多了一个拼接操作,故而其forward函数中,除了需要输入被卷积的数据之外,还要输入U形中,与之对应的那部分计算结果
class Up(nn.Module):
def __init__(self, inSize, outSize):
super().__init__()
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
self.conv = DoubleConv(inSize, outSize)
def forward(self, x1, x2):
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
最后,将这几个组分拼接成一个UNet
class UNet(nn.Module):
def __init__(self, nChannel, nClass):
super(UNet, self).__init__()
self.inc = DoubleConv(nChannel, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 512)
self.up1 = Up(1024, 256)
self.up2 = Up(512, 128)
self.up3 = Up(256, 64)
self.up4 = Up(128, 64)
self.outc = OutConv(64, nClass)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
return self.outc(x)
在具体训练之前,需要准备数据集,其中图像存放在image文件夹中,标签存放在label文件夹中,同名的图像和标签文件一一对应。
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset
class ImgData(Dataset):
def __init__(self, data_path):
self.path = data_path
self.imgForder = os.path.join(data_path, "image")
# 加载图像
def loadImg(self, path):
img = np.array(Image.open(path))
return img.reshape(1, *img.shape)
# 根据index读取图片
def __getitem__(self, index):
pImg = os.path.join(self.path, f"image\{index}.png")
pLabel = os.path.join(self.path, f"label\{index}.png")
image = self.loadImg(pImg)
label = self.loadImg(pLabel)
# 数据标签归一化
if label.max() > 1:
label = label / 255
# 随机翻转图像,增加训练样本
flipCode = np.random.randint(3)
if flipCode!=0:
image = np.flip(image, flipCode).copy()
label = np.flip(label, flipCode).copy()
return image, label
def __len__(self):
# 返回训练集大小
return len(os.listdir(self.imgForder))
接下来就是激动人心的训练过程了,UNet采用RMSprop优化算法和BCEWithLogits损失函数,训练函数如下
from torch.utils.data import DataLoader
from torch import optim
import torch.nn as nn
def train(net, device, path, epochs=40, bSize=1, lr=0.00001):
igmData = ImgData(path)
train_loader = DataLoader(igmData, bSize, shuffle=True)
# 优化算法
optimizer = optim.RMSprop(net.parameters(),
lr=lr, weight_decay=1e-8, momentum=0.9)
criterion = nn.BCEWithLogitsLoss() # 损失函数
bestLoss = float('inf') # 最佳loss,初始化为无穷大
# 训练epochs次
for epoch in range(epochs):
net.train() # 训练模式
for image, label in train_loader:
optimizer.zero_grad()
# 将数据拷贝到device中
image = image.to(device=device, dtype=torch.float32)
label = label.to(device=device, dtype=torch.float32)
pred = net(image) # 使用网络参数,输出预测结果
loss = criterion(pred, label) # 计算损失
# 保存loss最小的网络参数
if loss < bestLoss:
bestLoss = loss
torch.save(net.state_dict(), 'best_model.pth')
loss.backward() # 更新参数
optimizer.step()
print(epoch, 'Loss/train', loss.item())
接下来调用训练函数,经过40次训练之后,得到51MB的best_model.pth
模型文件,此即最佳测试结果
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(1, 1)
net.to(device=device)
path = "train/"
train(net, device, path)
所谓预测,无非是重新做一次训练,而且不及损失,只需保存被神经网络处理之后的结果即可,下面是预测一张图像的函数,其输入net即为我们训练好的网络,device为设备。
def predictOne(net, device, pRead, pSave):
img = Image.open(pRead)
img = np.array(img)
img = img.reshape(1, 1, *img.shape)
img = torch.from_numpy(img)
img = img.to(device=device, dtype=torch.float32)
pred = net(img) # 预测
pred[pred >= 0.5] = 255
pred[pred < 0.5] = 0
pred = np.array(pred.data.cpu()[0])[0]
img = Image.fromarray(pred.astype(np.uint8))
img.save(pSave)
最后,批量处理预测数据集,test和predict分别是存放测试文件和预测图像的文件夹。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(1, 1)
net.to(device=device)
net.load_state_dict(torch.load('best_model.pth', map_location=device))
net.eval() # 测试模式
fs = os.listdir('test')
for f in fs:
pRead = os.path.join('test', f)
pSave = os.path.join("predict",f)
predictOne(net, device, pRead, pSave)
预测结果如下,左侧为图像,右侧为标签。