Shortcuts

模型集成

Created On: Mar 15, 2023 | Last Updated: Jan 16, 2024 | Last Verified: Nov 05, 2024

本教程说明如何使用 torch.vmap 对模型集成进行向量化。

什么是模型集成?

模型集成将多个模型的预测结果组合在一起。传统上,这是通过分别运行每个模型处理某些输入,然后合并预测结果来实现的。然而,如果运行的模型具有相同的架构,则可能利用 torch.vmap 将它们组合。vmap 是一个函数转换工具,可将函数映射到输入张量的维度之一。其用途之一是消除循环并通过向量化加速处理。

让我们演示如何使用简单MLP进行集成。

备注

本教程需要PyTorch 2.0.0或更高版本。

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

我们生成一批虚拟数据,并假设正在处理MNIST数据集。因此,虚拟图像尺寸为28乘28,我们有一个大小为64的小批量。此外,假设我们希望组合来自10个不同模型的预测结果。

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

我们有几种生成预测的选择。可能我们希望为每个模型提供一个不同的随机化数据小批量。或者,也可能希望为每个模型运行相同的数据小批量(例如,如果我们想测试不同模型初始化的效果)。

选项1:为每个模型使用不同的小批量

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

选项2:使用相同的小批量

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

使用 vmap 对集成进行向量化

让我们使用 vmap 加速循环。首先需要为 vmap 准备模型。

首先,通过堆叠每个参数来组合模型状态。例如,model[i].fc1.weight 的形状为 [784, 128];我们将每个模型的`.fc1.weight`堆叠起来以生成一个形状为 ``[10, 784, 128]``的大权重。

PyTorch提供了便捷函数 torch.func.stack_module_state 来完成这一操作。

from torch.func import stack_module_state

params, buffers = stack_module_state(models)

接下来,我们需要定义一个用于 vmap 的函数。该函数应该接收参数、缓冲区和输入,根据这些参数、缓冲区和输入运行模型。我们将使用 torch.func.functional_call 来帮助实现:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

选项1:通过为每个模型使用不同的小批量获得预测。

默认情况下,vmap 会将输入函数映射到传递进来函数输入的第一个维度。在使用 stack_module_state 后,每个参数和缓冲区的前方都有一个附加维度大小为 ‘num_models’,同时小批量也有一个 ‘num_models’ 的维度。

print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
[10, 10, 10, 10, 10, 10]

选项2:通过为所有10个模型使用相同的小批量数据获得预测。

vmap 有一个 in_dims 参数,用于指定要映射的维度。通过使用 None,我们告诉 vmap 我们希望同一个小批量应用于所有10个模型。

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

简短说明:使用 vmap 转换函数时,对某些类型函数有一定限制。最适合转换的函数是纯函数:一种输出只依赖于输入且没有副作用(例如,改变)的函数。vmap 无法处理任意Python数据结构的变异操作,但能够处理许多PyTorch的原位操作。

性能

想知道性能数据吗?下面是性能表现的结果。

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f060c4c4f10>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
  1.04 ms
  1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f060c4b7be0>
vmap(fmodel)(params, buffers, minibatches)
  363.37 us
  1 measurement, 100 runs , 1 thread

使用 vmap 有显著的速度提升!

通常使用 vmap 的向量化应比在循环中运行函数更快,并且与手动批处理具有竞争力。但仍有一些例外情况,比如某种操作可能没有实现 vmap 规则,或者底层内核没有针对较老硬件(如GPU)进行优化。如果发现这些情况,请通过在GitHub上打开问题让我们知道。

脚本总运行时间: ( 0分钟 1.504秒)

通过Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源