以下是一门实验课的大作业。
数据集介绍:
此数据集包含从土耳其伊兹密尔的一家超市收集的9种不同的海鲜类型,数据集包括镀金头鲷、红鲷、鲈鱼、红鲻鱼、竹荚鱼、黑海鲱鱼、条纹红鲻鱼、鳟鱼、虾的图像样本。(label分别为gilt head bream, red sea bream, sea bass, red mullet, horse mackerel, black sea sprat, striped red mullet, trout, shrimp)
论文材料:
O. Ulucan, D. Karakaya and M. Turkan, "A Large-Scale Dataset for Fish Segmentation and Classification," 2020 Innovations in Intelligent Systems and Applications Conference (ASYU), 2020, pp. 1-5, doi: 10.1109/ASYU50717.2020.9259867.
数据集下载:
https://www.kaggle.com/crowww/a-large-scale-fish-dataset
图像格式:
图像是通过2台不同的相机收集的:柯达 Easyshare Z650 和三星 ST60。因此,图像的分辨率分别为2832x2128和1024x768。在分割和分类过程之前,数据集中的图像在保留纵横比的前提下被调整为590x445。调整图像大小后,数据集中的所有图像都经过了增强(通过翻转和旋转)。在增强过程结束时,每个类的总图像数量变为2000张:其中1000张是RGB图像,另外1000张是它们对应的ground-truth分割标签。
数据集的描述
该数据集包含9种不同的海鲜类型。对于每个类,有1000张增强图像及其对应的ground-truth标签图像。每个类都可以在“Fish_Dataset”文件夹中找到,并带有其真实标签。每个类的所有图像按“00000.png”到“01000.png”编号排序。例如,若要访问数据集中虾的真实标签图像,则应遵循“Fish->Shrimp->Shrimp GT”的路径。
任务要求:
1.图像分割:
训练一个深度神经网络分割出图片中海鲜物体,
1)在测试集上形成如下所示的双色图,每一海鲜种类生成一个,共9个;
2)统计分割结果,打印出训练集的loss和accuracy以及测试集的loss和accuracy。
2. 图像分类
训练深度神经网络对海鲜图片数据集进行分类,
1)在测试集上生成分类结果图9张(每一类各一张)(如下图,可在每个子图的标题上标记真实label和预测label);
2)统计分类结果,打印出训练集的loss和accuracy以及测试集的loss和accuracy。
图像分割采用的是Unet+ResNet34,其中ResNet34采用的是torchvision中的预训练模型。
训练结果如下
首先是制作数据集。
import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from PIL import Image
import os
import matplotlib.pyplot as plt
from torchvision import transforms
class MyData(Dataset):
    """Segmentation dataset: pairs each RGB fish image with its ground-truth mask.

    Expected directory layout (matches the Kaggle fish dataset):
        root_dir/<label>/<label>/*.png     -> input images
        root_dir/<label>/<label> GT/*.png  -> binary ground-truth masks
    """

    def __init__(self, root_dir, label_dir, transfrom):
        # NOTE: the double join is intentional -- the dataset nests each
        # class folder inside itself (e.g. Fish_Dataset/Shrimp/Shrimp).
        self.root_dir = os.path.join(root_dir, label_dir)
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.mpath = os.path.join(self.root_dir, self.label_dir + " GT")
        # BUG FIX: os.listdir order is arbitrary; sort both listings so
        # images stay aligned with their masks by filename.
        self.img_path = sorted(os.listdir(self.path))
        self.mask_path = sorted(os.listdir(self.mpath))
        self.myTransforms = transfrom

    def __getitem__(self, idx):
        """Return (image_tensor, mask_tensor) for sample `idx`; mask is int64."""
        img_item_path = os.path.join(self.path, self.img_path[idx])
        mask_item_path = os.path.join(self.mpath, self.mask_path[idx])
        img = self.myTransforms(Image.open(img_item_path))
        # ToTensor maps white (255) to 1.0, so the long cast yields class
        # ids {0, 1} for a binary black/white mask.
        label = self.myTransforms(Image.open(mask_item_path)).to(torch.long)
        return img, label

    def __len__(self):
        return len(self.img_path)
制作Unet+ResNet34的模型,其中resnet34使用的是预训练权重。
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import ResNet34_Weights
class DecoderBlock(nn.Module):
    """U-Net decoder stage: transposed-conv upsampling, skip concat, double conv.

    Channel wiring examples:
        decoder1: up 1024->512, conv 1024->512
        decoder5: up 64->64,   conv 128->64
    When the up-conv channels are not given they default to the conv channels.
    """

    def __init__(self, conv_in_channels, conv_out_channels, up_in_channels=None, up_out_channels=None):
        super().__init__()
        # Idiomatic None checks (`is None`, not `== None`).
        if up_in_channels is None:
            up_in_channels = conv_in_channels
        if up_out_channels is None:
            up_out_channels = conv_out_channels
        self.up = nn.ConvTranspose2d(up_in_channels, up_out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(conv_in_channels, conv_out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(conv_out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(conv_out_channels, conv_out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(conv_out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        """Upsample x1 (deeper feature), concat with skip feature x2, double-conv."""
        x1 = self.up(x1)
        x = torch.cat([x1, x2], dim=1)
        return self.conv(x)
class UnetResnet34(nn.Module):
    """U-Net with an ImageNet-pretrained ResNet-34 encoder.

    The encoder taps the standard resnet stages; the decoder mirrors them
    with transposed-conv upsampling plus skip connections from each stage.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        backbone = torchvision.models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
        channels = [64, 128, 256, 512]

        # Stem (conv1 + bn1 + relu) and the stem max-pool.
        self.firstlayer = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.maxpool = backbone.maxpool
        # The four residual stages become the encoder.
        self.encoder1 = backbone.layer1
        self.encoder2 = backbone.layer2
        self.encoder3 = backbone.layer3
        self.encoder4 = backbone.layer4

        # Bottleneck: widen to 1024 channels and halve the resolution.
        self.bridge = nn.Sequential(
            nn.Conv2d(channels[3], channels[3] * 2, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels[3] * 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Decoder chain; decoder5 fuses with the stem output, hence its
        # custom up-conv channel counts.
        self.decoder1 = DecoderBlock(conv_in_channels=channels[3] * 2, conv_out_channels=channels[3])
        self.decoder2 = DecoderBlock(conv_in_channels=channels[3], conv_out_channels=channels[2])
        self.decoder3 = DecoderBlock(conv_in_channels=channels[2], conv_out_channels=channels[1])
        self.decoder4 = DecoderBlock(conv_in_channels=channels[1], conv_out_channels=channels[0])
        self.decoder5 = DecoderBlock(
            conv_in_channels=channels[1], conv_out_channels=channels[0], up_in_channels=channels[0],
            up_out_channels=channels[0]
        )

        # Final 2x upsample back to input resolution, then per-pixel logits.
        self.lastlayer = nn.Sequential(
            nn.ConvTranspose2d(in_channels=channels[0], out_channels=channels[0], kernel_size=2, stride=2),
            nn.Conv2d(channels[0], num_classes, kernel_size=3, padding=1, bias=False)
        )

    def forward(self, x):
        """Return (N, num_classes, H, W) logits for input (N, 3, H, W)."""
        stem = self.firstlayer(x)
        enc1 = self.encoder1(self.maxpool(stem))
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        bottleneck = self.bridge(enc4)
        dec = self.decoder1(bottleneck, enc4)
        dec = self.decoder2(dec, enc3)
        dec = self.decoder3(dec, enc2)
        dec = self.decoder4(dec, enc1)
        dec = self.decoder5(dec, stem)
        return self.lastlayer(dec)
训练部分代码。
import torch.optim
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
import time
from unet_utils import *
from Res34_Unet import UnetResnet34
from Dataloader import MyData
# Root folder of the Kaggle "A Large Scale Fish Dataset".
train_root_dir = "Fish_Dataset"
class_kind = ["Black Sea Sprat", "Gilt Head Bream", "Hourse Mackerel", "Red Mullet", "Red Sea Bream", "Sea Bass",
              "Shrimp", "Striped Red Mullet", "Trout"]
myTransforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()]
)
# Build one MyData per class and concatenate once.  The original seeded the
# loop with a bare [] and re-wrapped ConcatDataset on every iteration,
# relying on a plain list passing as a zero-length Dataset.
dataset = ConcatDataset([MyData(train_root_dir, kind, myTransforms) for kind in class_kind])
# Sanity-check preview: show the first ten (image, mask) pairs side by side.
fig, axes = plt.subplots(10, 2)
for row in range(10):
    image, mask = dataset[row]
    # CHW tensor -> HWC array for matplotlib.
    axes[row, 0].imshow(image.permute(1, 2, 0).numpy().reshape(256, 256, 3))
    axes[row, 1].imshow(mask.numpy().reshape(256, 256))  # original size: 445x590
plt.show()
def train_test_split(dataset, test_size=0.2):
    """Randomly partition `dataset` into (train, test) subsets.

    `test_size` is the held-out fraction; the rounding guarantees the two
    lengths always sum to len(dataset).
    """
    n_total = len(dataset)
    n_train = round(n_total * (1 - test_size))
    subsets = random_split(dataset, [n_train, n_total - n_train])
    return subsets[0], subsets[1]
train_dataset, test_dataset = train_test_split(dataset)
train_data_size = len(train_dataset)
test_data_size = len(test_dataset)
print("训练数据集长度为{}".format(train_data_size))
print("测试数据集长度为{}".format(test_data_size))
# Load the datasets with DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)
# Build the network: 2 output classes (background / fish)
model = UnetResnet34(num_classes=2)
model = model.cuda()
# Loss function: per-pixel cross entropy
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()
# Optimizer
learning_rate = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Training bookkeeping
# number of training steps so far
total_train_step = 0
# number of test steps so far
total_test_step = 0
# number of epochs
epoch = 10
# TensorBoard writer
writer = SummaryWriter("final_homework/final_homework/logs_train")
fit_time = time.time()
for i in range(epoch):
    # Per-epoch accumulators: loss / mIoU / pixel accuracy per batch,
    # for the train (t_*) and validation (v_*) phases.
    t_loss_perb, t_iou_perb, t_acc_perb = 0, 0, 0
    v_loss_perb, v_iou_perb, v_acc_perb = 0, 0, 0
    t_loss_pere, v_loss_pere = 0, 0
    since = time.time()
    model.train()
    for data in train_dataloader:
        imgs, masks = data
        imgs = imgs.cuda()
        masks = masks.cuda()
        outputs = model(imgs)
        # BUG FIX: drop only the channel dim, (N, 1, H, W) -> (N, H, W).
        # A bare squeeze() would also drop the batch dim when the last
        # batch has a single sample, breaking CrossEntropyLoss.
        masks = masks.squeeze(dim=1)
        loss = loss_fn(outputs, masks)
        t_iou_perb += mIoU(outputs, masks)
        train_accuracy = pixel_accuracy(outputs, masks)
        t_acc_perb += train_accuracy
        # Optimise
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        t_loss_perb += loss.item()
        total_train_step = total_train_step + 1
        if total_train_step % 10 == 0:
            writer.add_scalar("train_loss", loss.item(), total_train_step)
            writer.add_scalar("pixel_accuracy", train_accuracy, total_train_step)
    # Validation phase: no gradients, eval mode.
    with torch.no_grad():
        model.eval()
        for data in test_dataloader:
            imgs, masks = data
            imgs = imgs.cuda()
            masks = masks.cuda()
            masks = masks.squeeze(dim=1)  # same fix as above
            outputs = model(imgs)
            loss = loss_fn(outputs, masks)
            v_iou_perb += mIoU(outputs, masks)
            v_acc_perb += pixel_accuracy(outputs, masks)
            v_loss_perb += loss.item()
    # Mean over batches for this epoch.
    t_loss_pere = t_loss_perb / len(train_dataloader)
    v_loss_pere = v_loss_perb / len(test_dataloader)
    print("Epoch:{}/{}..".format(i + 1, epoch),
          "Train Loss: {:.3f}..".format(t_loss_pere),
          "Val Loss: {:.3f}..".format(v_loss_pere),
          "Train Score:{:.3f}..".format(t_iou_perb / len(train_dataloader)),
          "Val Score: {:.3f}..".format(v_iou_perb / len(test_dataloader)),
          "Train Acc:{:.3f}..".format(t_acc_perb / len(train_dataloader)),
          "Val Acc:{:.3f}..".format(v_acc_perb / len(test_dataloader)),
          "Time: {:.2f}m".format((time.time() - since) / 60))
# Persist the whole model (architecture + weights) after training.
torch.save(model, 'Res34_Unet.pth')
writer.close()
预测结果可视化部分。
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from Read_Data_predit import MyData
from PIL import Image
# Same preprocessing as training (resize to 256x256, no normalisation).
myTransforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
# Dataset location and the nine class folders.
root_dir = "Fish_Dataset"
kind_name = ["Black Sea Sprat", "Gilt Head Bream", "Hourse Mackerel", "Red Mullet",
             "Red Sea Bream", "Sea Bass", "Shrimp", "Striped Red Mullet", "Trout"]
dataset = []
# One row per class: input image / ground truth / predicted mask.
fig, axes = plt.subplots(9, 3)
myModel = torch.load("Res34_Unet.pth")
myModel.eval()
with torch.no_grad():
    for i, name in enumerate(kind_name):
        img_root_dir = os.path.join(root_dir, name)
        mask_root_dir = os.path.join(root_dir, name)
        # One fixed sample per class for the qualitative check.
        img_item_path = os.path.join(img_root_dir, name, "00010.png")
        mask_item_path = os.path.join(mask_root_dir, name + " GT", "00010.png")
        img = Image.open(img_item_path)
        mask = Image.open(mask_item_path)
        data_transform = myTransforms(img)
        data_transform = data_transform.unsqueeze(dim=0)  # add batch dim
        data_transform = data_transform.cuda()
        output = myModel(data_transform)
        output = torch.squeeze(output)  # (2, 256, 256): channel 0 = background, 1 = fish
        # BUG FIX: pick the higher-scoring class per pixel.  The original
        # thresholded the raw (un-normalised) foreground logit at 0.5,
        # which is not a meaningful cut-off for logits.
        binary_predicted_mask = output.argmax(dim=0).cpu().numpy()
        axes[i, 0].imshow(img.resize((256, 256)))
        axes[i, 1].imshow(mask.resize((256, 256)))
        axes[i, 2].imshow(binary_predicted_mask)
plt.show()
图像分类采用的是ResNet50,使用torchvision自带的模型与预训练权重,因此无需自己搭建模型。
图像分类训练结果:可以看到训练三轮,准确率就可以达到90%。
首先是数据集制作
import torch
from torch.utils.data import Dataset, ConcatDataset, random_split
from PIL import Image
import os
import numpy as np
class MyData(Dataset):
    """Classification dataset: every image under one class folder shares
    the integer label `class_num`."""

    def __init__(self, root_dir, label_dir, class_num, transforms):
        self.root_dir = os.path.join(root_dir, label_dir)
        self.label_dir = label_dir
        self.class_num = class_num
        self.transforms = transforms
        # The fish dataset nests each class folder inside itself.
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        """Return (transformed image, integer class label)."""
        item_path = os.path.join(self.path, self.img_path[idx])
        image = self.transforms(Image.open(item_path))
        return image, self.class_num

    def __len__(self):
        return len(self.img_path)
# root_dir = "Fish_Dataset"
#
# kind_name = ["Black Sea Sprat", "Gilt Head Bream", "Hourse Mackerel", "Red Mullet",
# "Red Sea Bream", "Sea Bass", "Shrimp", "Striped Red Mullet", "Trout"]
# dataset = []
#
#
# def train_test_split(dataset, test_size=0.2):
# length = len(dataset)
# train_length = round(length * (1 - test_size))
# test_length = length - train_length
#
# train_dataset, test_dataset = random_split(dataset, [train_length, test_length])
# return train_dataset, test_dataset
#
#
# for i, name in enumerate(kind_name):
# dataset_temp = MyData(root_dir, name, i)
# dataset = ConcatDataset([dataset, dataset_temp])
#
# train_data, test_data = train_test_split(dataset)
# train_data_size = len(train_data)
# test_data_size = len(test_data)
# print("训练数据集长度为{}".format(train_data_size))
# print("测试数据集长度为{}".format(test_data_size))
训练部分代码。
import time
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from Read_Data import MyData
# Standard ImageNet preprocessing for the pretrained ResNet-50.
myTransforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
# Prepare the dataset: root folder and the nine class names (list index = label).
root_dir = "Fish_Dataset"
kind_name = ["Black Sea Sprat", "Gilt Head Bream", "Hourse Mackerel", "Red Mullet",
             "Red Sea Bream", "Sea Bass", "Shrimp", "Striped Red Mullet", "Trout"]
dataset = []
def train_test_split(dataset, test_size=0.2):
    """Split `dataset` into (train, test) subsets at random.

    `test_size` is the held-out fraction; lengths always sum to len(dataset).
    """
    total = len(dataset)
    train_len = round(total * (1 - test_size))
    test_len = total - train_len
    parts = random_split(dataset, [train_len, test_len])
    return parts[0], parts[1]
# Build the full dataset: one MyData per class, labelled by its list index.
for i, name in enumerate(kind_name):
    dataset_temp = MyData(root_dir, name, i, myTransforms)
    dataset = ConcatDataset([dataset, dataset_temp])
train_data, test_data = train_test_split(dataset)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集长度为{}".format(train_data_size))
print("测试数据集长度为{}".format(test_data_size))
# Load the datasets with DataLoader
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
# Build the network
# Define the model: pretrained ResNet-50 from torchvision
myModel = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
# Replace ResNet-50's final fully-connected layer with a 9-way classifier
# (the original comment said ResNet18; the model above is ResNet-50)
inchannel = myModel.fc.in_features
myModel.fc = nn.Linear(inchannel, 9)
myModel.cuda()
# Loss function
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()
# Optimizer
learning_rate = 0.001
optimizer = torch.optim.SGD(myModel.parameters(), lr=learning_rate)
# Training bookkeeping
# number of training steps so far
total_train_step = 0
# number of test steps so far
total_test_step = 0
# number of epochs
epoch = 10
# TensorBoard writer
writer = SummaryWriter("./ResNet_train")
for i in range(epoch):
    since = time.time()
    print("----------------第{}轮训练开始----------------".format(i+1))
    total_train_loss = 0
    total_train_accuracy = 0
    # BUG FIX: restore train mode at the start of every epoch.  The
    # original never left eval mode after the first test phase, so from
    # epoch 2 onward BatchNorm/Dropout ran frozen during training.
    myModel.train()
    # Training phase.
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.cuda()
        targets = targets.cuda()
        outputs = myModel(imgs)  # renamed from the original 'ouputs' typo
        loss = loss_fn(outputs, targets)
        # Optimise
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss = total_train_loss + loss.item()
        total_train_step = total_train_step + 1
        train_accuracy = (outputs.argmax(1) == targets).sum()
        total_train_accuracy = total_train_accuracy + train_accuracy
    # Evaluation phase.
    total_test_loss = 0
    total_test_accuracy = 0
    myModel.eval()
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.cuda()
            targets = targets.cuda()
            outputs = myModel(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            test_accuracy = (outputs.argmax(1) == targets).sum()
            total_test_step = total_test_step + 1
            total_test_accuracy = total_test_accuracy + test_accuracy
    print("整体训练集上的Loss: {}".format(total_train_loss))
    print("整体训练集上的正确率: {}".format(total_train_accuracy/train_data_size))
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_test_accuracy/test_data_size))
    print("第{}轮耗时{:.2f}min".format(i+1,(time.time()-since)/60))
    total_test_step = total_test_step + 1
# BUG FIX: the original called torch.save(model, ...) but 'model' is never
# defined in this script (NameError); the trained network is 'myModel'.
torch.save(myModel, 'Resnet50_fish.pth')
writer.close()
预测结果与可视化。需要注意,放入模型前要先对图片进行transforms预处理。
import os
import matplotlib.pyplot as plt
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from Read_Data_predit import MyData
from PIL import Image
# ImageNet preprocessing -- must match what the classifier was trained with.
myTransforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
# Prepare the dataset (NA_Fish_Dataset -- presumably the non-augmented
# originals; verify against the dataset layout).
root_dir = "NA_Fish_Dataset"
kind_name = ["Black Sea Sprat", "Gilt Head Bream", "Hourse Mackerel", "Red Mullet",
             "Red Sea Bream", "Sea Bass", "Shrimp", "Striped Red Mullet", "Trout"]
dataset = []
# 3x3 grid: one subplot per class.
fig, axes = plt.subplots(3, 3)
myModel = torch.load("Resnet50_fish.pth")
myModel.eval()
with torch.no_grad():
    for i, name in enumerate(kind_name):
        # One fixed sample per class for the qualitative check.
        img_item_path = os.path.join(root_dir, name, "00001.png")
        img = Image.open(img_item_path)
        axes[i // 3, i % 3].imshow(img)
        data_transform = myTransforms(img)
        data_transform = data_transform.unsqueeze(dim=0)  # add batch dim
        data_transform = data_transform.cuda()
        output = myModel(data_transform)
        pred = int(output.argmax(1).item())  # plain int index, not a tensor
        # Fixed typo in the displayed title: "Predited" -> "Predicted".
        axes[i // 3, i % 3].set_title("True:" + kind_name[i] + "\nPredicted:" + kind_name[pred], loc="center")
plt.show()
Unet工具包
import yaml
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import shutil
import os
import pandas as pd
from torch import nn
def create_df(path):
    """Walk `path` recursively and collect every file's stem (the part
    before the first dot) into a single-column DataFrame ('id')."""
    stems = [fname.split('.')[0]
             for _, _, fnames in os.walk(path)
             for fname in fnames]
    return pd.DataFrame({'id': stems}, index=np.arange(0, len(stems)))
def plot(history, graphType, isTest=False):
    """Plot the per-epoch curve(s) stored in `history` for `graphType`
    (e.g. 'loss').  Train/val curves by default; only the test curve when
    `isTest` is True."""
    if isTest:
        plt.plot(history[f'test_{graphType}'], label='test', marker='*')
    else:
        plt.plot(history[f'train_{graphType}'], label='train', marker='*')
        plt.plot(history[f'val_{graphType}'], label='val', marker='o')
    plt.title(f'{graphType} per epoch')
    plt.ylabel(f'{graphType}')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.show()
def CE_Loss(inputs, target, cls_weights, num_classes=2):
    """Cross-entropy over per-pixel logits.

    inputs: (n, c, h, w) logits; target: (n, ht, wt) int64 class ids.
    If the spatial sizes disagree, logits are resized to the target's.
    Pixels labelled `num_classes` are ignored.  `cls_weights` is accepted
    for interface compatibility but unused (as in the original).
    """
    n, c, h, w = inputs.size()
    nt, ht, wt = target.size()
    # BUG FIX: the original used `and`, so a mismatch in only one spatial
    # dimension skipped the resize and crashed inside the loss.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    temp_target = target.view(-1)
    CE_loss = nn.CrossEntropyLoss(ignore_index=num_classes)(temp_inputs, temp_target)
    return CE_loss
def pixel_accuracy(output, mask):
    """Fraction of pixels whose argmax class matches `mask`.

    output: (n, c, h, w) logits; mask: (n, h, w) class ids.  Returns float.
    """
    with torch.no_grad():
        predicted = F.softmax(output, dim=1).argmax(dim=1)
        matches = torch.eq(predicted, mask).int()
        return float(matches.sum()) / float(matches.numel())
def mIoU(pred_mask, mask, n_classes=2, smooth=1e-10):
    """Mean IoU over the classes present in `mask` (absent classes skipped).

    pred_mask: (n, c, h, w) logits; mask: (n, h, w) class ids.
    """
    with torch.no_grad():
        predicted = F.softmax(pred_mask, dim=1).argmax(dim=1).contiguous().view(-1)
        target = mask.contiguous().view(-1)
        scores = []
        for cls in range(n_classes):
            in_pred = predicted == cls
            in_target = target == cls
            if in_target.long().sum().item() == 0:
                # Class absent from ground truth: excluded from the mean.
                scores.append(np.nan)
            else:
                inter = torch.logical_and(in_pred, in_target).sum().float().item()
                union = torch.logical_or(in_pred, in_target).sum().float().item()
                scores.append((inter + smooth) / (union + smooth))
        return np.nanmean(scores)
def load_train_config(path):
    """Load a YAML training-config file and return it as plain Python data.

    Uses SafeLoader: training configs are plain scalars/lists/maps, and
    yaml.Loader can instantiate arbitrary Python objects from a crafted
    file.  (If a config ever relies on custom YAML tags, revisit this.)
    """
    with open(path) as f:
        data = yaml.load(f, Loader=yaml.SafeLoader)
    return data
def visualize(image, mask, pred_mask, score):
    """Show input image, ground truth, and prediction side by side; the
    third panel's title carries the mIoU `score`."""
    fig, axes = plt.subplots(1, 3, figsize=(20, 10))
    panels = [(image, 'Picture', False),
              (mask, 'Ground truth', True),
              (pred_mask, 'UnetResnet34 | mIoU {:.3f}'.format(score), True)]
    for ax, (content, title, hide_axis) in zip(axes, panels):
        ax.imshow(content)
        ax.set_title(title)
        if hide_axis:
            ax.set_axis_off()
def save_checkpoint(state, current_checkpoint_path, is_best=False, best_model_path=None):
    """Persist `state` to `current_checkpoint_path`; if `is_best`, also
    copy it to `best_model_path` (which must then be provided).

    Raises AssertionError when is_best is True but best_model_path is None.
    """
    torch.save(state, current_checkpoint_path)
    if is_best:
        # `is not None` is the idiomatic None test (was `!= None`).
        assert best_model_path is not None, 'best_model_path should not be None.'
        shutil.copyfile(current_checkpoint_path, best_model_path)
def load_checkpoint(best_model_path, current_checkpoint_path, model, optimizer, scheduler, best_checkpoint=False):
    """Restore model/optimizer/scheduler state from a checkpoint file.

    Reads the best checkpoint when `best_checkpoint` is True, otherwise the
    most recent one.  Returns (model, optimizer, scheduler, val_loss).
    """
    train_loss_key = 'train_loss'
    val_loss_key = 'val_loss'
    path = best_model_path if best_checkpoint else current_checkpoint_path
    model, optimizer, scheduler, epoch, train_loss, val_loss = load_model(
        path, model, optimizer, scheduler, train_loss_key, val_loss_key)
    print(f'optimizer = {optimizer}, start epoch = {epoch}, train loss = {train_loss}, val loss = {val_loss}')
    return model, optimizer, scheduler, val_loss
def load_model(model_path, model, optimizer, scheduler, train_loss_key, val_loss_key):
    """Load a checkpoint dict from `model_path`, restore the three
    state_dicts, and return the saved epoch plus the two loss entries."""
    ckpt = torch.load(model_path)
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    return (model, optimizer, scheduler, ckpt['epoch'],
            ckpt[train_loss_key], ckpt[val_loss_key])
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group
    (None if there are no param groups)."""
    for group in optimizer.param_groups:
        return group['lr']
    return None