备注
点击 这里 下载完整的示例代码
ONNX 简介 || 将 PyTorch 模型导出到 ONNX || 扩展 ONNX 导出器的操作符支持 || 将带有控制流的模型导出到 ONNX
扩展 ONNX 导出器的操作符支持¶
Created On: Oct 06, 2023 | Last Updated: Mar 05, 2025 | Last Verified: Nov 05, 2024
作者: Ti-Tai Wang, Justin Chu
概述¶
本教程描述了如何为不受支持的 PyTorch 操作符创建 ONNX 实现或用您自己的实现替换现有实现。
我们将介绍三种需要扩展 ONNX 导出器操作符支持的场景:
重写现有 PyTorch 操作符的实现
使用自定义的 ONNX 操作符
支持自定义 PyTorch 操作符
你将学习到的内容:
如何在 ONNX 中重写或新增对 PyTorch 操作符的支持。
如何为特殊运行时集成自定义 ONNX 操作符。
如何实现和转换自定义的 PyTorch 操作符为 ONNX。
前提条件¶
在开始本教程之前,请确保您已完成以下先决条件:
torch >= 2.6
目标 PyTorch 操作符
完成 ONNX Script 教程 之后再继续
使用 ONNX Script 实现操作符
重写现有 PyTorch 操作符的实现¶
虽然 ONNX 导出器团队尽力支持所有 PyTorch 操作符,但某些操作符可能仍未被支持。在本节中,我们将演示如何将不支持的 PyTorch 操作符添加到 ONNX 注册表。
备注
实现不支持的 PyTorch 操作符的步骤与用自定义实现替换现有 PyTorch 操作符的步骤相同。由于在本教程中我们实际上没有一个不支持的 PyTorch 操作符可以使用,我们将用这种方式替换 torch.ops.aten.add.Tensor
的实现,就像如果该操作符未被 ONNX 导出器实现一样。
当由于不支持的操作符导致模型无法导出到 ONNX 时,ONNX 导出器将显示类似以下的错误消息:
No decompositions registered for [...]
错误消息表明不支持的 PyTorch 操作符是 torch.ops.aten.add.Tensor
。该操作符是类型 <class 'torch._ops.OpOverload'>
, 这是我们将用作目标以注册自定义实现的操作符。
import torch
import onnxscript
# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op
# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
def forward(self, input_x, input_y):
return torch.ops.aten.add.Tensor(input_x, input_y)
# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
# To distinguish the custom implementation from the builtin one, we switch the order of the inputs
return op.Add(other, self)
x = torch.tensor([1.0])
y = torch.tensor([2.0])
# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
Model().eval(),
(x, y),
dynamo=True,
custom_translation_table={
torch.ops.aten.add.Tensor: custom_aten_add,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
现在让我们检查模型并验证模型正在使用自定义实现。
print(onnx_program.model)
<
ir_version=10,
opset_imports={'': 18},
producer_name='pytorch',
producer_version='2.7.0+cu126',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"input_x"<FLOAT,[1]>,
%"input_y"<FLOAT,[1]>
),
outputs=(
%"add"<FLOAT,[1]>
),
) {
0 | # node_Add_0
%"add"<FLOAT,[1]> ⬅️ ::Add(%"input_y", %"input_x")
return %"add"<FLOAT,[1]>
}
转换是使用我们的自定义实现: 在节点 node_Add_0
中, input_y
现在在前,input_x
在后。
我们可以使用 ONNX Runtime 运行模型并通过直接对输入张量调用 torch.onnx.ONNXProgram
验证结果。
result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))
使用自定义的 ONNX 操作符¶
在这种情况下,我们使用标准的 PyTorch 操作符创建一个模型,但运行时(例如 Microsoft’s ONNX Runtime)可以为该内核提供一个自定义实现,从而有效地替换现有的实现。
在以下示例中,我们使用由 ONNX Runtime 提供的 com.microsoft.Gelu
操作符,它不同于来自 ONNX 规范的 Gelu
。
class GeluModel(torch.nn.Module):
def forward(self, input_x):
return torch.ops.aten.gelu(input_x)
# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)
# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT
@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
return microsoft_op.Gelu(self)
onnx_program = torch.onnx.export(
GeluModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.aten.gelu.default: custom_aten_gelu,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
[torch.onnx] Obtain model graph for `GeluModel()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `GeluModel()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
让我们检查模型并验证模型使用的是来自命名空间 com.microsoft
的 op_type Gelu
。
print(onnx_program.model)
<
ir_version=10,
opset_imports={'com.microsoft': 1, '': 18},
producer_name='pytorch',
producer_version='2.7.0+cu126',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"input_x"<FLOAT,[1]>
),
outputs=(
%"gelu"<FLOAT,[1]>
),
) {
0 | # n0
%"gelu"<FLOAT,[1]> ⬅️ com.microsoft::Gelu(%"input_x")
return %"gelu"<FLOAT,[1]>
}
类似于前一个示例,我们可以使用 ONNX Runtime 运行模型并验证结果。
result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))
支持自定义 PyTorch 操作符¶
在此情况下,该操作符是用户实现并注册到 PyTorch 的操作符。
在以下示例中,我们希望使用一个自定义操作符,该操作符接受一个张量输入,并返回一个输出。该操作符将输入加上它自己,并返回已舍入的结果。
首先,我们假设自定义操作符已通过 torch.library.custom_op()
实现并注册。您可以参考 在 Python 中创建新的自定义操作符 以获取关于如何创建自定义操作符的详细指南。
# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
return torch.round(input + input)
@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
return torch.empty_like(tensor_x)
class AddAndRoundModel(torch.nn.Module):
def forward(self, input):
return add_and_round_op(input)
# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
return op.Round(op.Add(input, input))
onnx_program = torch.onnx.export(
AddAndRoundModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)
[torch.onnx] Obtain model graph for `AddAndRoundModel()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `AddAndRoundModel()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNXProgram(
model=
<
ir_version=10,
opset_imports={'': 18},
producer_name='pytorch',
producer_version='2.7.0+cu126',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"input"<FLOAT,[1]>
),
outputs=(
%"add_and_round_op"<FLOAT,[1]>
),
) {
0 | # node_Add_0
%"val_0"<FLOAT,[1]> ⬅️ ::Add(%"input", %"input")
1 | # node_Round_1
%"add_and_round_op"<FLOAT,[1]> ⬅️ ::Round(%"val_0")
return %"add_and_round_op"<FLOAT,[1]>
}
,
exported_program=
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, input: "f32[1]"):
input_1 = input
# File: /data1/lin/pytorch-tutorials/beginner_source/onnx/onnx_registry_tutorial.py:215 in forward, code: return add_and_round_op(input)
add_and_round_op: "f32[1]" = torch.ops.mylibrary.add_and_round_op.default(input_1); input_1 = None
return (add_and_round_op,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_and_round_op'), target=None)])
Range constraints: {}
)
转换是使用我们的自定义实现将 torch.export.ExportedProgram`
中的 torch.ops.mylibrary.add_and_round_op.default
操作符转换为 ONNX 操作符 Add
和 Round
。
最后,我们验证结果。
总结¶
恭喜!在本教程中,我们探索了 custom_translation_table
选项,并学习了如何使用 ONNX Script 为不支持或现有的 PyTorch 操作符创建自定义实现。
最后,我们利用 ONNX Runtime 执行模型并将结果与 PyTorch 进行比较,为我们提供了对在 ONNX 生态系统中处理不支持的操作符的全面理解。
进一步阅读¶
以下列表包含从基础示例到高级场景的教程,不一定按列出的顺序排列。可以直接跳到您感兴趣的特定主题,或者坐稳了逐步浏览所有内容,了解关于ONNX导出器的一切。
脚本的总运行时间: ( 0 分钟 2.710 秒)