使用 PyTorch 框架中的 torch.save()
函数将模型的参数保存到名为 model.pkl
?的文件中。
使用 PyTorch 框架中的 torch.load()
函数从文件model.pkl
?中加载已经训练好的模型的状态字典。
在 PyTorch 中,模型的参数通常存储在一个名为“state_dict”的字典对象中。这个字典对象包含了模型中所有可学习的参数及其对应的张量。这样,就可以将模型的参数以二进制格式保存到磁盘上。
import torch
import torch.nn as nn
# 定义神经网络模型
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
# 创建模型实例
model = MyModel(input_size=10, hidden_size=20, output_size=2)
# 保存模型状态字典
torch.save(model.state_dict(), 'model.pkl')
# 加载模型状态字典
model.load_state_dict(torch.load('model.pkl', map_location=torch.device('cpu')))
# 使用模型进行预测
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)