还记得这篇文章吗?迁移学习|代码实现
在这篇文章中,我们知道了在构建模型时,可以借助一些非常有名的模型,这些模型在ImageNet数据集上早已经得到了检验。
同时torchvision模块也提供了预训练好的模型。我们只需稍作修改,便可运用到自己的实际任务中!
我们仍然按照这个步骤开始我们的模型的训练
准备一个可迭代的数据集
定义一个神经网络
将数据集输入到神经网络进行处理
计算损失
通过梯度下降算法更新参数
import?torch?
import?torchvision
import?torchvision.transforms?as?transforms
import?torch.nn?as?nn
import?torch.optim?as?optim
import?matplotlib.pyplot?as?plt
from?torchvision?import?models
数据集准备
cifar10_train?=?torchvision.datasets.CIFAR10(
????root?=?'cifar10/',
????train?=?True,
????download?=?True
)
cifar10_test=torchvision.datasets.CIFAR10(
????root?=?'cifar10/',
????train?=?False,
????download?=?True
)
transform?=?transforms.Compose([
????????transforms.ToTensor(),
????????transforms.Resize((224,224))
????])
cifar2_train=[(transform(img),[3,5].index(label))?for?img,label?in?cifar10_train?if?label?in?[3,5]]
cifar2_test=[(transform(img),[3,5].index(label))?for?img,label?in?cifar10_test?if?label?in?[3,5]]
train_loader?=?torch.utils.data.DataLoader(cifar2_train,?batch_size=64,shuffle=True)
test_loader?=?torch.utils.data.DataLoader(cifar2_test,?batch_size=64,shuffle=True)
数据集使用CIFAR-10数据集中的猫和狗。
CIFAR-10数据集类别
种类???????标签
plane???????0
car???????????1
bird?????????2
cat???????????3
deer?????????4
dog??????????5
frog?????????6
horse???????7
ship?????????8
truck????????9
可以看到其中cat和dog的标签分别为3和5
借助:
[3,5].index(label)
我们可以将cat标签变为0,dog标签变为1,从而回到二分类问题。
举个例子:
>>>?[3,5].index(3)
0
>>>?[3,5].index(5)
1
定义模型
参考这篇文章:迁移学习|代码实现
#网络搭建
network=models.resnet18(pretrained=True)
for?param?in?network.parameters():
????param.requires_grad=False
network.fc=nn.Linear(512,2)
#损失函数
criterion=nn.CrossEntropyLoss()
#优化器
optimizer=optim.SGD(network.fc.parameters(),lr=0.01,momentum=0.9)
device=torch.device("cuda"?if?torch.cuda.is_available()?else?"cpu")
network=network.to(device)
训练模型:
for?epoch?in?range(10):
????total_loss?=?0
????total_correct?=?0
????for?batch?in?train_loader:???#?Get?batch
????????images,?labels?=batch
????????images=images.to(device)
????????labels=labels.to(device)
????????????
????????optimizer.zero_grad()??#告诉优化器把梯度属性中权重的梯度归零,否则pytorch会累积梯度
????????preds?=?network(images)
????????loss?=?criterion(preds,?labels)
????????loss.backward()
????????optimizer.step()
????????
????????total_loss?+=?loss.item()
????????_,prelabels=torch.max(preds,dim=1)
????????total_correct?+=?int((prelabels==labels).sum())
????accuracy?=?total_correct/len(cifar2_train)
????print("Epoch:%d??,??Loss:%f??,?Accuracy:%f?"%(epoch,total_loss,accuracy))
Epoch:0??,??Loss:78.549439??,?Accuracy:0.788900
Epoch:1??,??Loss:77.828066??,?Accuracy:0.801500
Epoch:2??,??Loss:66.151785??,?Accuracy:0.828100
Epoch:3??,??Loss:76.204446??,?Accuracy:0.816800
Epoch:4??,??Loss:68.886606??,?Accuracy:0.828100
Epoch:5??,??Loss:71.129405??,?Accuracy:0.821200
Epoch:6??,??Loss:66.096364??,?Accuracy:0.829900
Epoch:7??,??Loss:65.504227??,?Accuracy:0.827700
Epoch:8??,??Loss:76.303878??,?Accuracy:0.817100
Epoch:9??,??Loss:70.546953??,?Accuracy:0.820700
测试模型:
correct=0
total=0
network.eval()
with?torch.no_grad():
????for?batch?in?test_loader:
????????imgs,labels=batch
????????imgs=imgs.cuda()
????????labels=labels.cuda()
????????
????????preds=network(imgs)
????????_,prelabels=torch.max(preds,dim=1)
????????#print(prelabels.size())
????????total=total+labels.size(0)
????????correct=correct+int((prelabels==labels).sum())
????#print(total)
????accuracy=correct/total
????print("Accuracy:?",accuracy)
Accuracy:??0.8025
这里使用的预训练模型是resnet18,我们也可以使用VGG16模型,同时记得改变最后一个全连接层的输出参数,使得其满足我们自己的任务。
除了预训练模型之外,我们还可以对一些超参数进行调整,使最后的效果变得更好!