pytorch之保存模型训练好的参数状态以及直接加载该参数状态来进行预测

发布时间:2024年01月19日

使用 PyTorch 框架中的 torch.save() 函数将模型的参数保存到名为 model.pkl?的文件中。

使用 PyTorch 框架中的 torch.load() 函数从文件model.pkl?中加载已经训练好的模型的状态字典。

在 PyTorch 中,模型的参数通常存储在一个名为“state_dict”的字典对象中。这个字典对象包含了模型中所有可学习的参数及其对应的张量。这样,就可以将模型的参数以二进制格式保存到磁盘上。

1. 训练模型train.py

import torch
import torch.nn as nn

# 定义神经网络模型
class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# 创建模型实例
model = MyModel(input_size=10, hidden_size=20, output_size=2)

# 保存模型状态字典
torch.save(model.state_dict(), 'model.pkl')

2. 加载已经训练好的模型的参数并直接进行预测predict.py


# 加载模型状态字典
model.load_state_dict(torch.load('model.pkl', map_location=torch.device('cpu')))

# 使用模型进行预测
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

文章来源:https://blog.csdn.net/weixin_45947938/article/details/135680594
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。