PyTorch自定义操作符¶
Created On: Jun 18, 2024 | Last Updated: Jan 06, 2025 | Last Verified: Nov 05, 2024
PyTorch 提供了一个丰富的运算符库,可以操作张量(例如 torch.add
、torch.sum
等)。然而,您可能希望为 PyTorch 引入一个新的自定义操作,并使其与 torch.compile
、autograd 和 torch.vmap
等系统协同工作。要实现这一点,您需要通过 Python 的 torch.library 文档 或 C++ 的 TORCH_LIBRARY
API 在 PyTorch 中注册自定义操作。
从 Python 编写自定义操作¶
请参阅 自定义 Python 算子。
如果有以下需求,您可能希望从 Python(而不是 C++)编写自定义操作:
您有一个希望 PyTorch 视为不透明可调用对象的 Python 函数,特别是在涉及
torch.compile
和torch.export
时。您有一些 C++/CUDA 内核的 Python 绑定,并希望这些绑定能与 PyTorch 的系统(如
torch.compile
或torch.autograd
)协同工作。您正在使用 Python(而不是 AOTInductor 这样的仅 C++ 环境)。
将自定义 C++ 和/或 CUDA 代码与 PyTorch 集成¶
请参阅 自定义 C++ 和 CUDA 算子。
如果有以下需求,您可能希望从 C++(而不是 Python)编写自定义操作:
您有自定义的 C++ 和/或 CUDA 代码。
您计划使用此代码与
AOTInductor
一起进行无 Python 环境的推理。
自定义操作手册¶
关于教程和此页面未涵盖的信息,请参阅 `自定义操作手册 <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU>`_(我们正在努力将信息迁移到我们的文档网站上)。我们建议您先阅读上述教程之一,然后将自定义操作手册用作参考;这不是一本从头到尾阅读的手册。
什么时候应该创建自定义操作?¶
如果您的操作可以通过组合内置的 PyTorch 操作符表示,请将其写为 Python 函数并调用它,而不是创建自定义操作。如果您调用的是 PyTorch 无法理解的某些库(例如,自定义的 C/C++ 代码、自定义 CUDA 内核或 Python 对 C/C++/CUDA 扩展的绑定),请使用操作符注册 API 创建自定义操作。
为什么应该创建自定义操作?¶
可以通过获取张量的数据指针并将其传递给 pybind 的内核来使用 C/C++/CUDA 内核。然而,这种方法无法与 PyTorch 的子系统(如 autograd、torch.compile、vmap 等)集成。为了让一个操作可以与 PyTorch 的子系统协同工作,必须通过操作符注册 API 注册这个操作。