Shortcuts

ONNX 简介 || 将 PyTorch 模型导出到 ONNX || 扩展 ONNX 导出器的操作符支持 || 将带有控制流的模型导出到 ONNX

扩展 ONNX 导出器的操作符支持

Created On: Oct 06, 2023 | Last Updated: Mar 05, 2025 | Last Verified: Nov 05, 2024

作者: Ti-Tai Wang, Justin Chu

概述

本教程描述了如何为不受支持的 PyTorch 操作符创建 ONNX 实现或用您自己的实现替换现有实现。

我们将介绍三种需要扩展 ONNX 导出器操作符支持的场景:

  • 重写现有 PyTorch 操作符的实现

  • 使用自定义的 ONNX 操作符

  • 支持自定义 PyTorch 操作符

你将学习到的内容:

  • 如何在 ONNX 中重写或新增对 PyTorch 操作符的支持。

  • 如何为特殊运行时集成自定义 ONNX 操作符。

  • 如何实现和转换自定义的 PyTorch 操作符为 ONNX。

前提条件

在开始本教程之前,请确保您已完成以下先决条件:

重写现有 PyTorch 操作符的实现

虽然 ONNX 导出器团队尽力支持所有 PyTorch 操作符,但某些操作符可能仍未被支持。在本节中,我们将演示如何将不支持的 PyTorch 操作符添加到 ONNX 注册表。

备注

实现不支持的 PyTorch 操作符的步骤与用自定义实现替换现有 PyTorch 操作符的步骤相同。由于在本教程中我们实际上没有一个不支持的 PyTorch 操作符可以使用,我们将用这种方式替换 torch.ops.aten.add.Tensor 的实现,就像如果该操作符未被 ONNX 导出器实现一样。

当由于不支持的操作符导致模型无法导出到 ONNX 时,ONNX 导出器将显示类似以下的错误消息:

No decompositions registered for [...]

错误消息表明不支持的 PyTorch 操作符是 torch.ops.aten.add.Tensor。该操作符是类型 <class &apos;torch._ops.OpOverload&apos;>, 这是我们将用作目标以注册自定义实现的操作符。

import torch
import onnxscript

# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op


# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add.Tensor(input_x, input_y)


# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
    if alpha != 1.0:
        alpha = op.CastLike(alpha, other)
        other = op.Mul(other, alpha)
    # To distinguish the custom implementation from the builtin one, we switch the order of the inputs
    return op.Add(other, self)


x = torch.tensor([1.0])
y = torch.tensor([2.0])

# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
    Model().eval(),
    (x, y),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.add.Tensor: custom_aten_add,
    },
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅

现在让我们检查模型并验证模型正在使用自定义实现。

print(onnx_program.model)
<
    ir_version=10,
    opset_imports={'': 18},
    producer_name='pytorch',
    producer_version='2.7.0+cu126',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input_x"<FLOAT,[1]>,
        %"input_y"<FLOAT,[1]>
    ),
    outputs=(
        %"add"<FLOAT,[1]>
    ),
) {
    0 |  # node_Add_0
         %"add"<FLOAT,[1]> ⬅️ ::Add(%"input_y", %"input_x")
    return %"add"<FLOAT,[1]>
}

转换是使用我们的自定义实现: 在节点 node_Add_0 中, input_y 现在在前,input_x 在后。

我们可以使用 ONNX Runtime 运行模型并通过直接对输入张量调用 torch.onnx.ONNXProgram 验证结果。

使用自定义的 ONNX 操作符

在这种情况下,我们使用标准的 PyTorch 操作符创建一个模型,但运行时(例如 Microsoft’s ONNX Runtime)可以为该内核提供一个自定义实现,从而有效地替换现有的实现。

在以下示例中,我们使用由 ONNX Runtime 提供的 com.microsoft.Gelu 操作符,它不同于来自 ONNX 规范的 Gelu

class GeluModel(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)


# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT


@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
    return microsoft_op.Gelu(self)


onnx_program = torch.onnx.export(
    GeluModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.gelu.default: custom_aten_gelu,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
[torch.onnx] Obtain model graph for `GeluModel()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `GeluModel()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅

让我们检查模型并验证模型使用的是来自命名空间 com.microsoft 的 op_type Gelu

print(onnx_program.model)
<
    ir_version=10,
    opset_imports={'com.microsoft': 1, '': 18},
    producer_name='pytorch',
    producer_version='2.7.0+cu126',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input_x"<FLOAT,[1]>
    ),
    outputs=(
        %"gelu"<FLOAT,[1]>
    ),
) {
    0 |  # n0
         %"gelu"<FLOAT,[1]> ⬅️ com.microsoft::Gelu(%"input_x")
    return %"gelu"<FLOAT,[1]>
}

类似于前一个示例,我们可以使用 ONNX Runtime 运行模型并验证结果。

result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))

支持自定义 PyTorch 操作符

在此情况下,该操作符是用户实现并注册到 PyTorch 的操作符。

在以下示例中,我们希望使用一个自定义操作符,该操作符接受一个张量输入,并返回一个输出。该操作符将输入加上它自己,并返回已舍入的结果。

首先,我们假设自定义操作符已通过 torch.library.custom_op() 实现并注册。您可以参考 在 Python 中创建新的自定义操作符 以获取关于如何创建自定义操作符的详细指南。

# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
    return torch.round(input + input)


@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
    return torch.empty_like(tensor_x)


class AddAndRoundModel(torch.nn.Module):
    def forward(self, input):
        return add_and_round_op(input)


# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
    return op.Round(op.Add(input, input))


onnx_program = torch.onnx.export(
    AddAndRoundModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)
[torch.onnx] Obtain model graph for `AddAndRoundModel()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `AddAndRoundModel()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 18},
            producer_name='pytorch',
            producer_version='2.7.0+cu126',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"input"<FLOAT,[1]>
            ),
            outputs=(
                %"add_and_round_op"<FLOAT,[1]>
            ),
        ) {
            0 |  # node_Add_0
                 %"val_0"<FLOAT,[1]> ⬅️ ::Add(%"input", %"input")
            1 |  # node_Round_1
                 %"add_and_round_op"<FLOAT,[1]> ⬅️ ::Round(%"val_0")
            return %"add_and_round_op"<FLOAT,[1]>
        }


    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, input: "f32[1]"):
                    input_1 = input

                     # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/onnx_registry_tutorial.py:215 in forward, code: return add_and_round_op(input)
                    add_and_round_op: "f32[1]" = torch.ops.mylibrary.add_and_round_op.default(input_1);  input_1 = None
                    return (add_and_round_op,)

        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_and_round_op'), target=None)])
        Range constraints: {}

)

转换是使用我们的自定义实现将 torch.export.ExportedProgram` 中的 torch.ops.mylibrary.add_and_round_op.default 操作符转换为 ONNX 操作符 AddRound

最后,我们验证结果。

总结

恭喜!在本教程中,我们探索了 custom_translation_table 选项,并学习了如何使用 ONNX Script 为不支持或现有的 PyTorch 操作符创建自定义实现。

最后,我们利用 ONNX Runtime 执行模型并将结果与 PyTorch 进行比较,为我们提供了对在 ONNX 生态系统中处理不支持的操作符的全面理解。

进一步阅读

以下列表包含从基础示例到高级场景的教程,不一定按列出的顺序排列。可以直接跳到您感兴趣的特定主题,或者坐稳了逐步浏览所有内容,了解关于ONNX导出器的一切。

脚本的总运行时间: ( 0 分钟 2.710 秒)

画廊由 Sphinx-Gallery 生成

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源