Shortcuts

逐样本梯度(Per-sample-gradients)

Created On: Mar 15, 2023 | Last Updated: Apr 24, 2024 | Last Verified: Nov 05, 2024

这是什么?

逐样本梯度计算是为一批数据中的每个样本计算其梯度。这是在差分隐私、元学习和优化研究中有用的量。

备注

本教程需要PyTorch 2.0.0或更高版本。

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple CNN and loss function:

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

让我们生成一批虚拟数据,并假装我们正在处理MNIST数据集。虚拟图像大小为28x28,我们使用了一个64大小的迷你批次。

device = 'cuda'

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

在常规的模型训练中,会将迷你批次前向传播到模型中,然后调用.backward()来计算梯度。这将生成整个迷你批次的’平均’梯度:

model = SimpleCNN().to(device=device)
predictions = model(data)  # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss.backward()  # back propagate the 'average' gradient of this mini-batch

与上述方法相反,逐样本梯度计算等效于:

  • 对于数据的每个单独样本,执行一次前向和反向传播,以获得单个(逐样本)梯度。

def compute_grad(sample, target):
    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = loss_fn(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


def compute_sample_grads(data, targets):
    """ manually process each sample with per sample gradient """
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads

per_sample_grads = compute_sample_grads(data, targets)

sample_grads[0]``是模型``conv1.weight``的逐样本梯度。``model.conv1.weight.shape``为``[32, 1, 3, 3];注意批次中每个样本都有一个梯度,总共有64个。

print(per_sample_grads[0].shape)

逐样本梯度,更高效的方法,使用函数变换

我们可以通过使用函数变换来高效地计算逐样本梯度。

``torch.func``函数变换API作用于函数之上。我们的策略是定义一个计算损失的函数,然后应用变换以构建一个计算逐样本梯度的函数。

我们将使用``torch.func.functional_call``函数将一个``nn.Module``视为一个函数。

首先,让我们将``model``中的状态提取到两个字典中:参数和缓冲区。我们将分离它们,因为我们不会使用常规的PyTorch自动求导(例如Tensor.backward(),torch.autograd.grad)。

from torch.func import functional_call, vmap, grad

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

接下来,让我们定义一个函数来计算模型的损失,给定单个输入,而不是一批输入。重要的是,这个函数必须接受参数、输入和目标,因为我们将对它们进行变换。

注意 - 由于模型最初是为处理批次而设计的,我们将使用``torch.unsqueeze``来添加一个批次维度。

def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

现在,让我们使用``grad``变换创建一个新函数,该函数计算``compute_loss``第一个参数(即``params``)的梯度。

ft_compute_grad = grad(compute_loss)

ft_compute_grad``函数计算单个(样本,目标)对的梯度。我们可以使用``vmap``使其计算整个批次样本和目标的梯度。请注意,``in_dims=(None, None, 0, 0),因为我们希望在数据和目标的第0维上映射``ft_compute_grad``,并为每个样本使用相同的``params``和缓冲区。

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

最后,让我们使用转换后的函数来计算逐样本梯度:

ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

我们可以双重检查,使用``grad``和``vmap``得到的结果与手动对每个样本单独处理的结果是否匹配:

for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

一个快速说明:关于``vmap``可以变换的函数类型存在一些限制。最适合转换的函数是纯函数:其输出仅由输入决定,并且没有副作用(如修改操作)。``vmap``无法处理任意Python数据结构的修改,但它可以处理许多PyTorch的就地操作。

性能比较

好奇``vmap``的性能比较如何?

目前,最佳结果是在新款GPU(如A100(Ampere))上获得的,在该示例中我们看到了高达25倍的速度提升,但以下是我们构建机器上的一些结果:

def get_perf(first, first_descriptor, second, second_descriptor):
    """takes torch.benchmark objects and compares delta of second vs first."""
    second_res = second.times[0]
    first_res = first.times[0]

    gain = (first_res-second_res)/first_res
    if gain < 0: gain *=-1
    final_gain = gain*100

    print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")

from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)

print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')

get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")

在 PyTorch 中计算每样本梯度还有其他优化的解决方案(例如:https://github.com/pytorch/opacus),这些方法性能也比普通方法要好。不过,很酷的一点是,通过组合 vmapgrad 可以带来不错的性能提升。

总体来说,与在循环中运行函数相比,使用 vmap 进行矢量化应该更快,并且与手动批处理性能相当。不过也有一些例外,例如我们没有针对某些操作实现 vmap 规则,或者底层内核未针对较旧硬件(例如 GPU)优化。如果您遇到这些情况,请通过 GitHub 提交问题告诉我们。

**脚本的总运行时间:**(0分钟0.000秒)

通过Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源