
Training a Classifier

Created On: Mar 24, 2017 | Last Updated: Dec 20, 2024 | Last Verified: Not Verified

This is it. You have seen how to define neural networks, compute loss, and make updates to the weights of the network.

Now you might be thinking,

What about data?

Generally, when you have to deal with image, text, audio, or video data, you can use standard Python packages that load the data into a numpy array. You can then convert this array into a torch.*Tensor (a minimal sketch of this workflow follows the list below).

  • For images, packages such as Pillow and OpenCV are useful

  • For audio, packages such as scipy and librosa

  • For text, either raw Python or Cython based loading, or NLTK and SpaCy are useful
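As a minimal sketch of that workflow for a single image (the file name some_image.png is only a placeholder, not part of this tutorial), the image can be read with Pillow, turned into a numpy array, and converted to a tensor:

import numpy as np
import torch
from PIL import Image

# load an image with Pillow and copy it into a numpy array (H x W x C, uint8)
img = np.array(Image.open('some_image.png').convert('RGB'))

# convert to a float tensor in C x H x W order, scaled to [0, 1]
tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
print(tensor.shape)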

Specifically for vision, we have created a package called torchvision, which has data loaders for common datasets such as ImageNet, CIFAR10, and MNIST, as well as data transformers for images, viz., torchvision.datasets and torch.utils.data.DataLoader.

This provides a huge convenience and avoids writing boilerplate code.

For this tutorial, we will use the CIFAR10 dataset. It has the classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels.

cifar10

Training an image classifier

We will do the following steps in order:

  1. Load and normalize the CIFAR10 training and test datasets using torchvision

  2. Define a Convolutional Neural Network

  3. Define a loss function

  4. Train the network on the training data

  5. Test the network on the test data

1. Load and normalize CIFAR10

Using torchvision, it is extremely easy to load CIFAR10.

import torch
import torchvision
import torchvision.transforms as transforms

The output of torchvision datasets are PILImage images of range [0, 1]. We transform them to Tensors of normalized range [-1, 1]: normalizing each channel with mean 0.5 and standard deviation 0.5 maps a value x to (x - 0.5) / 0.5.

Note

If you get a BrokenPipeError when running on Windows, try setting the num_workers of torch.utils.data.DataLoader() to 0.

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
100%|##########| 170M/170M [00:14<00:00, 11.9MB/s]

Let us show some of the training images, for fun.

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
cifar10 tutorial
plane truck deer  frog

2. Define a Convolutional Neural Network

Copy the neural network from the Neural Networks section before and modify it to take 3-channel images (instead of the 1-channel images it was defined for).

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

3. Define a loss function and optimizer

Let's use a classification Cross-Entropy loss and SGD with momentum.

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. Train the network

This is when things start to get interesting. We simply have to loop over our data iterator, feed the inputs to the network, and optimize.

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
[1,  2000] loss: 2.175
[1,  4000] loss: 1.840
[1,  6000] loss: 1.680
[1,  8000] loss: 1.613
[1, 10000] loss: 1.530
[1, 12000] loss: 1.494
[2,  2000] loss: 1.403
[2,  4000] loss: 1.391
[2,  6000] loss: 1.348
[2,  8000] loss: 1.307
[2, 10000] loss: 1.296
[2, 12000] loss: 1.279
Finished Training

Let's quickly save our trained model:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

See here for more details on saving PyTorch models.

5. Test the network on the test data

We have trained the network for 2 passes over the training dataset. But we need to check whether the network has learnt anything at all.

We will check this by predicting the class label that the neural network outputs and comparing it against the ground truth. If the prediction is correct, we add the sample to the list of correct predictions.

Okay, first step. Let us display an image from the test set to get familiar.

dataiter = iter(testloader)
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
cifar10 tutorial
GroundTruth:  cat   ship  ship  plane

Next, let's load back in our saved model (note: saving and re-loading the model wasn't necessary here; we only did it to illustrate how to do so).

net = Net()
net.load_state_dict(torch.load(PATH, weights_only=True))
<All keys matched successfully>

Okay, now let us see what the neural network thinks these examples above are:

outputs = net(images)

The outputs are energies for the 10 classes. The higher the energy for a class, the more the network thinks that the image is of that particular class. So, let's get the index of the highest energy:

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(4)))
Predicted:  cat   car   ship  plane

The results seem pretty good.

Let us look at how the network performs on the whole dataset.

correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
Accuracy of the network on the 10000 test images: 54 %

That looks way better than chance, which is 10% accuracy (randomly picking one of 10 classes). It seems the network learnt something.

Hmmm, which classes performed well, and which did not:

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Accuracy for class: plane is 62.4 %
Accuracy for class: car   is 71.4 %
Accuracy for class: bird  is 45.8 %
Accuracy for class: cat   is 50.9 %
Accuracy for class: deer  is 39.3 %
Accuracy for class: dog   is 43.6 %
Accuracy for class: frog  is 53.9 %
Accuracy for class: horse is 62.6 %
Accuracy for class: ship  is 54.1 %
Accuracy for class: truck is 63.5 %

Okay, so what next?

How do we run these neural networks on the GPU?

Training on GPU

Just like how you transfer a Tensor onto the GPU, you transfer the neural net onto the GPU.

Let's first define our device as the first visible CUDA device, if we have CUDA available:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)
cuda:0

The rest of this section assumes that device is a CUDA device.

Then these methods will recursively go over all modules and convert their parameters and buffers to CUDA tensors:

net.to(device)

Remember that you will also have to send the inputs and targets to the GPU at every step:

inputs, labels = data[0].to(device), data[1].to(device)
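Putting the two changes together, here is a minimal sketch of the earlier training loop adapted for the GPU; it assumes net, trainloader, criterion, and optimizer are defined exactly as above.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net.to(device)  # recursively moves all parameters and buffers to the GPU

for epoch in range(2):
    for data in trainloader:
        # send the inputs and labels of every mini-batch to the same device as the model
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()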

Why don't I notice a massive speedup compared to CPU? Because your network is really small.

Exercise: Try increasing the width of your network (argument 2 of the first nn.Conv2d, and argument 1 of the second nn.Conv2d; they need to be the same number) and see what kind of speedup you get. One possible widening is sketched below.
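As an illustration of that exercise (the channel count of 32 is an arbitrary choice, not a prescribed value), the intermediate channels could be widened like this, reusing the torch, nn, and F imports from above. Since conv2 still outputs 16 channels, fc1 is unchanged.

class WiderNet(nn.Module):
    def __init__(self):
        super().__init__()
        # widen the intermediate channels from 6 to 32:
        # argument 2 of conv1 and argument 1 of conv2 must match
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)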

Goals achieved:

  • Understanding PyTorch's Tensor library and neural networks at a high level.

  • Train a small neural network to classify images

Training on multiple GPUs

If you want to see even more massive speedup using all of your GPUs, please check out Optional: Data Parallelism.
