备注
点击 这里 下载完整示例代码
自定义 Python 算子¶
Created On: Jun 18, 2024 | Last Updated: Mar 19, 2025 | Last Verified: Nov 05, 2024
如何将用 Python 编写的自定义算子集成到 PyTorch 中
如何使用
torch.library.opcheck
测试自定义算子
PyTorch 2.4 或更高版本
PyTorch 提供了一个非常丰富的算子库,这些算子可用于 Tensor(例如 torch.add
、torch.sum
等)。然而,您可能希望在 PyTorch 中使用一个由第三方库编写的新自定义算子。本教程展示了如何封装 Python 函数,使其行为类似 PyTorch 原生算子。您可能希望在 PyTorch 中创建自定义算子的原因包括:
将任意 Python 函数作为
torch.compile
的不透明可调用(即防止torch.compile
跟踪函数内部)。为任意 Python 函数添加训练支持。
使用 torch.library.custom_op()
创建 Python 自定义算子。在无 Python 环境中,可以使用 C++ TORCH_LIBRARY
API 来创建 C++ 自定义算子。有关更多详情,请参见 自定义算子简介。
请注意,如果您的操作可以用现有 PyTorch 算子的组合来表达,则通常无需使用自定义算子 API —— 所有内容(例如 torch.compile
、训练支持)应该都能正常工作。
示例:将 PIL 的 crop 封装为一个自定义算子¶
假设我们正在使用 PIL 的 crop
操作。
import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt
def crop(pic, box):
img = to_pil_image(pic.cpu())
cropped_img = img.crop(box)
return pil_to_tensor(cropped_img).to(pic.device) / 255.
def display(img):
plt.imshow(img.numpy().transpose((1, 2, 0)))
img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)
cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)
crop
无法开箱即用地被 torch.compile
处理:当它无法处理时,torch.compile
导致 “图分裂”,而图分裂对性能不利。以下代码通过报错来展示这一点(fullgraph=True
时 torch.compile
引发错误,如果发生图分裂)。
@torch.compile(fullgraph=True)
def f(img):
return crop(img, (10, 10, 50, 50))
# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)
为了将 crop
黑箱化并与 torch.compile
一起使用,我们需要做两件事:
将函数封装为一个 PyTorch 自定义算子。
为算子添加一个 “
FakeTensor
内核”(也称为 “元内核”)。给定一些FakeTensors
输入(不包含存储的虚拟 Tensor),该函数应该返回具有正确 Tensor 元数据(形状/步幅/dtype
/设备)的虚拟 Tensor。
from typing import Sequence
# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
img = to_pil_image(pic.cpu())
cropped_img = img.crop(box)
return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)
# Use register_fake to add a ``FakeTensor`` kernel for the operator
@crop.register_fake
def _(pic, box):
channels = pic.shape[0]
x0, y0, x1, y1 = box
result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1)
# The result should have the same metadata (shape/strides/``dtype``/device)
# as running the ``crop`` function above.
return result
完成后,crop
现在可以无图分裂地工作:
@torch.compile(fullgraph=True)
def f(img):
return crop(img, (10, 10, 50, 50))
cropped_img = f(img)
display(img)
display(cropped_img)
为 crop 添加训练支持¶
使用 torch.library.register_autograd
为算子添加训练支持。优先使用此方法而非直接用 torch.autograd.Function
;有些情况与 PyTorch 算子注册 API 组合可能导致(并且确实导致)与 torch.compile
组合时发生静默错误。
如果您不需要训练支持,则无需使用 torch.library.register_autograd
。如果您最终用一个``custom_op``进行训练,而它没有 autograd 注册,我们会抛出一条错误消息。
crop
的梯度公式本质上是 PIL.paste``(推导过程留给读者练习)。首先将 ``paste
封装为一个自定义算子:
@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
assert im1.device == im2.device
assert im1.dtype == im2.dtype
im1_pil = to_pil_image(im1.cpu())
im2_pil = to_pil_image(im2.cpu())
PIL.Image.Image.paste(im1_pil, im2_pil, coord)
return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)
@paste.register_fake
def _(im1, im2, coord):
assert im1.device == im2.device
assert im1.dtype == im2.dtype
return torch.empty_like(im1)
然后使用 register_autograd
为 crop
指定梯度公式:
def backward(ctx, grad_output):
grad_input = grad_output.new_zeros(ctx.pic_shape)
grad_input = paste(grad_input, grad_output, ctx.coords)
return grad_input, None
def setup_context(ctx, inputs, output):
pic, box = inputs
ctx.coords = box[:2]
ctx.pic_shape = pic.shape
crop.register_autograd(backward, setup_context=setup_context)
注意反向必须是 PyTorch 可理解操作的组合,因此我们将 paste 封装为一个自定义算子,而不是直接使用 PIL 的 paste。
img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)
这是正确的梯度,裁剪区域为 1(白色),未使用的区域为 0(黑色)。
测试 Python 自定义算子¶
使用 torch.library.opcheck
测试自定义算子是否注册正确。这不会测试梯度是否数学上正确;请单独编写测试(可以是手动测试或 torch.autograd.gradcheck
)。
使用 opcheck
时,传入一组示例输入以进行测试。如果您的算子支持训练,则示例如中应该包含需要梯度的 Tensor。如果算子支持多设备,则示例应该包括来自每个设备的 Tensor。
examples = [
[torch.randn(3, 64, 64), [0, 0, 10, 10]],
[torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
[torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
[torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]
for example in examples:
torch.library.opcheck(crop, example)
可变 Python 自定义算子¶
您还可以将一个会改变其输入的 Python 函数封装为自定义算子。这种会改变输入的函数很常见,因为这是许多底层内核的编写方式;例如,一个计算 sin
的内核可能会接收输入 Tensor 和输出 Tensor,并将 input.sin()
写入输出 Tensor。
我们用 numpy.sin
举例说明一个可变 Python 自定义算子。
import numpy as np
@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.device == output.device
assert input.device.type == "cpu"
input_np = input.numpy()
output_np = output.numpy()
np.sin(input_np, out=output_np)
由于该算子不返回任何内容,因此无须为其注册 FakeTensor
内核(元内核)以使其与 torch.compile
一起工作。
@torch.compile(fullgraph=True)
def f(x):
out = torch.empty(3)
numpy_sin(x, out)
return out
x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())
以下是一个 opcheck
运行,告诉我们确实正确地注册了算子。例如,如果我们忘记将输出包含到 mutates_args
中,则 opcheck
会抛出错误。
example_inputs = [
[torch.randn(3), torch.empty(3)],
[torch.randn(0, 3), torch.empty(0, 3)],
[torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]
for example in example_inputs:
torch.library.opcheck(numpy_sin, example)
总结¶
在本教程中,我们学习了如何使用 torch.library.custom_op
在 Python 中创建一个自定义算子,使其可以与 PyTorch 子系统(例如 torch.compile
和autograd)一起工作。
本教程提供了自定义算子基本介绍。更详细的信息,请参见: