pytorch学习(六)、网上模型的使用及修改、模型的保存及加载、完整模型的训练和测试过程的小案例

发布时间:2024年01月22日

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

至此,pytorch的基础入门已经结束了。

一、现有网络模型的使用及修改


import torch
import torchvision
from torch import nn

# 调用网络
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)

print(vgg16_true)
# 修改网络的分类层
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)

二、模型的保存及加载

  • 模型的保存:torch.save()
  • 模型的加载:torch.load()

import torch
import torchvision
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)

# 模型的保存方式1(模型结构+模型参数)
torch.save(vgg16_false,'vgg16_method1.pth')
# 加载模型1
model = torch.load('vgg16_method1.pth')
# print(model)

# 模型的保存方式2(模型参数)
torch.save(vgg16_false.state_dict(),'vgg16_method2.pth')
# 加载模型2
model = torchvision.models.vgg16(pretrained=False )
model.load_state_dict(torch.load('vgg16_method2.pth'))
# print(model)

三、完整模型的训练和测试过程(一个完整的小案例)

  • 以pythorch官网的CIFAR10数据集为例。

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import *
import torch
from torch import nn
import torchvision

# 加载数据
train_data = torchvision.datasets.CIFAR10('./dataset/CIFAR10',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10('./dataset/CIFAR10',train=False,transform=torchvision.transforms.ToTensor(),download=True)

#数据长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print('训练数据集的长度为:{}'.format(train_data_size))
print('测试数据集的长度为:{}'.format(test_data_size))

#加载数据集
train_data_loader = DataLoader(train_data,batch_size=64)
test_data_loader = DataLoader(test_data,batch_size=64)

# 创建网络模型
model = Model()

# 损失函数
loss = nn.CrossEntropyLoss()

# 优化器
lr = 1e-2   #学习率
optim = torch.optim.SGD(model.parameters(),lr=lr)

# 设置训练网络的一些参数
# 记录训练的次数,测试的次数,训练的轮数
total_train_cnt = 0
total_test_cnt = 0
epoch = 10

# 添加tensorboard
writer = SummaryWriter('./logs_train')

for i in range(epoch):
    print("--------第{}轮训练开始---------".format(i+1))

    # 训练步骤开始
    model.train()
    for data in train_data_loader:
        imgs,targets = data
        output = model(imgs)
        # 计算损失函数
        out_loss = loss(output,targets)
        # 梯度清零
        optim.zero_grad()
        # 反向传播
        out_loss.backward()
        # 开始优化
        optim.step()

        total_train_cnt+=1
        if total_train_cnt % 100 ==0:
            print("训练次数:{},Loss:{}".format(total_train_cnt,out_loss.item()))
            writer.add_scalar("train_loss",out_loss.item(),total_train_cnt)

    # 测试步骤开始
    model.eval()
    total_test_loss = 0
    # 正确率
    total_accuracy = 0
    with torch.no_grad():
        for data in test_data_loader:
            imgs,targets = data
            output = model(imgs)
            out_loss = loss(output,targets)
            # 总体损失值
            total_test_loss+=out_loss.item()
           # 计算正确率
            accuracy = (output.argmax(1) == targets).sum()
            total_accuracy+=accuracy

    print("整体测试集上的Loss:{}".format(total_test_loss))
    print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))
    writer.add_scalar("test_loss",total_test_loss,total_test_cnt)
    writer.add_scalar("test_accuracy",total_accuracy / test_data_size,total_test_cnt)
    total_test_cnt+=1
    # 模型的保存
    torch.save(model,"model_{}.pth".format(i))
    print("模型已保存!")

writer.close()

文章来源:https://blog.csdn.net/weixin_73044854/article/details/135704049
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。