最近在捣鼓K210端的算法部署,不得不吐槽官方文档真的不行,乱七八蕉的。。。
这个帖子主要讲述一下模型转换的步骤,我这里常用的框架是pytorch,相较于tensorflow转换步骤更繁琐一点。
?
模型的话,我这里用的是openmmlab的开源库mmpretrain实现的MobileNetV3-Small,具体的模型可以根据自己的任务更换。
from mmpretrain.models.backbones import MobileNetV3
import torch
import torch.nn as nn
import math
import numpy as np
import onnxruntime
import onnx
from onnx import load_model, save_model
def MobileNet():
return MobileNetV3(arch='small')
model = MobileNet().cuda()
modelfile =".checkpoints/epoch_90.pth"
#####加载权重
checkpoint = torch.load(modelfile, map_location='cuda:0')
state = checkpoint['state_dict']
state_keys = list(state.keys())
for i, key in enumerate(state_keys):
if "backbone." in key and not("neck." in key) and not("head." in key) :
newkey = key.replace("backbone.","")
state[newkey] = state.pop(key)
else:
state.pop(key)
model_dict_load = model.state_dict()
model_dict_load.update(state)
model.load_state_dict(model_dict_load)
model.eval()
#####转换.onnx格式
x = torch.rand(1, 3, 224, 224).cuda()
export_onnx_file = "./output/mobilenet_small.onnx"# 【改】输出ONNX权重地址
torch.onnx.export(model,
x,
export_onnx_file,
opset_version=12, # ONNX算子的版本,不设置默认为13.0
)
导出的.onnx文件可以使用Netron(网页端和本地端)可视化。
这一步是可选的,因为nncase提供了直接从.onnx-->.kmodel的配置,但是只能是float32数据类型。如果需要将模型量化到(u)int8或int16,就需要先将.onnx-->.tflite。
这里推荐给大家一个比较好用的转换工具:onnx2tflite
import sys
# 注意需要把onnx2tflite库下载下来放到当前目录,不然会提示找不到converter
sys.path.append("onnx2tflite")
from converter import onnx_converter
onnx_path = "./mobilenet_small.onnx" # 需要转换的onnx文件位置
onnx_converter(
onnx_model_path = onnx_path,
need_simplify = True,
output_path = "./tflite_model/", # 输出的tflite存储路径
target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
weight_quant = False, #只权重量化
int8_model = False,#权重、输入输出全量化
int8_mean=[123.675, 116.28, 103.53],# 量化时有用
int8_std=[58.395,57.12, 57.375],# 量化时有用
image_root = './test' # 校准数据集,量化时有用
注意,这一步其实就可以选择是否对模型量化,但是nncase不支持量化模型作为输入,因此在这一步我们不能量化,输出的.tflite模型仍然是float32的。
到这一步,nncase官方github库提供了相应的脚本,根据自己的配置进行修改就可以了。
目前nncase 2.x.x的版本不支持K210,需要pip install nncase的1.x.x的版本,个人发现1.6.0的版本速度很快,推荐大家使用这个。
pip install nncase==1.6.0.20220505
可能是网络结构中用到了reshape,可以将onnx_converter中的need_simplify设置为True.
或者是:
# 可以借助onnx-simplifier
pip install onnx-simplifier
python -m onnxsim mobilenet_small.onnx mobilenet_small_sim.onnx
nncase版本过低,1.3.0往后才支持tflite的split算子,可以升级一下版本。
?