pytorch之导出ONNX相关问题

发布时间:2023年12月20日

torch.onnx.export

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)

  1. model:PyTorch 模型,即要导出的模型对象;
  2. args:一个包含输入张量的元组。这是用于运行模型以确定图结构的输入;
  3. f:导出的 ONNX 模型的文件路径或文件对象。可以是字符串路径或者类似文件对象的对象,用于保存导出的模型;
  4. export_params:指定是否导出模型参数。如果设置为 True,则导出模型的权重。默认为 True;
  5. verbose:控制是否显示导出过程的详细信息。如果设置为 True,将显示详细信息,默认为 False;
  6. training:指定导出模型时是否保留训练模式。如果设置为 True,则导出的模型将保留训练模式,默认为 False。
  7. input_names:指定输入节点的名称列表。默认为 None,此时将使用自动生成的名称;
  8. output_names:指定输出节点的名称列表。默认为 None,此时将使用自动生成的名称;
  9. operator_export_type:
  10. opset_version:指定要使用的 ONNX 版本。默认为 torch.onnx.constant_opset_version,通常无需手动指定;
  11. do_constant_folding:控制是否进行常量折叠。如果设置为 True,将尝试在导出过程中进行常量折叠,默认为 False;
  12. dynamic_axes:
  13. keep_initializers_as_inputs:
  14. custom_opsets:
  15. export_modules_as_functions:
  16. autograd_inlining:
    例子:
torch.onnx.export(
        model,
        im,
        f,
        verbose=False,
        opset_version=opset,
        training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
        do_constant_folding=not train,
        input_names=['images'],
        output_names=['p3', 'p4', 'p5'],
        dynamic_axes={
            'images': {
                0: 'batch',
                2: 'height',
                3: 'width'},  # shape(1,3,640,640)
            'p3': {
                0: 'batch',
                2: 'height',
                3: 'width'},  # shape(1,25200,4)
            'p4': {
                0: 'batch',
                2: 'height',
                3: 'width'},
            'p5': {
                0: 'batch',
                2: 'height',
                3: 'width'}
        } if dynamic else None)

onnxsim.simplify

onnxsim.simplify参数:

  1. onnx_model:原始的 ONNX 模型对象;
  2. input_shapes:一个字典,指定输入张量的形状,这对于一些需要知道输入形状的优化非常有用,默认为 None;
  3. dynamic_input_shape:于指定是否允许动态输入形状。动态输入形状是指输入形状在模型推理时可以根据实际输入数据进行调整,而不是在模型导出时就固定好。
    onnxsim.simplify返回值:
  4. model_onnx:简化后的 ONNX 模型对象。这是一个经过优化和简化的新模型,可以用于推理;
  5. check:一个字符串,表示模型简化过程中的检查结果。这个字符串包含一些关于模型简化的信息,模型没有问题,字符串为空。
# Checks
model_onnx = onnx.load(f)  # load onnx model
onnx.checker.check_model(model_onnx)  # check onnx model

# Simplify
if simplify:
    # try:
    check_requirements(('onnx-simplifier',))
    import onnxsim

    LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
    model_onnx, check = onnxsim.simplify(model_onnx,
                                            dynamic_input_shape=dynamic,
                                            input_shapes={'images': list(im.shape)} if dynamic else None)
    assert check, 'assert check failed'
    onnx.save(model_onnx, f)
文章来源:https://blog.csdn.net/qq_45032341/article/details/134985026
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。