在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor
为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.ReLU(),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict()
方法加载参数。
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
(5): ReLU()
)
)
**注意:**请务必在推理之前调用
model.eval()
方法,以将 dropout 和批量归一化层设置为评估模式。否则,您将看到不一致的推理结果。
优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能非常耗时。Open Neural Network Exchange (ONNX) 开放神经网络交换运行时为您提供了一种解决方案,可在任何硬件、云或边缘设备上进行一次训练并加速推理。
ONNX 是许多供应商支持的通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言(Java, JavaScript, C# 和 ML.NET)和框架上对模型进行推理。
PyTorch 还具有本机 ONNX 导出支持。然而,考虑到 PyTorch 执行图的动态特性,导出过程必须遍历执行图以生成持久的 ONNX 模型。因此,应将适当大小的测试变量传递到导出例程中(在我们的例子中,我们将创建正确大小的虚拟零张量。您可以从训练数据集的shape
函数中获取大小:tensor.shape
):
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)
我们将使用测试数据集作为示例数据,从 ONNX 模型进行推理以进行预测。
test_data = datasets.FashionMNIST(
root = "data",
train = False,
download = True,
transform = ToTensor()
)
classes = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]
我们使用 onnxruntime.InferenceSession
创建推理会话。要推断 ONNX 模型,请调用 run
并传入您想要返回的输出列表(如果您需要所有输出,请保留为空)和输入值的映射。结果是输出列表。
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
result = session.run([output_name], {input_name:x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: {actual}')
Predicted: "Ankle boot", Actual: Ankle boot
ONNX 模型使您能够在不同平台上以不同编程语言运行推理。
什么是 PyTorch 模型 state_dict?
它是模型的内部状态字典,用于存储已学习的参数。