Training a Classifier¶
Created On: Mar 24, 2017 | Last Updated: Dec 20, 2024 | Last Verified: Not Verified
This is it. You have seen how to define neural networks, compute loss and make updates to the weights of the network.
Now you might be thinking,
What about data?¶
Generally, when you have to deal with image, text, audio or video data, you can use standard Python packages that load the data into a numpy array. You can then convert this array into a torch.*Tensor (a minimal sketch of this workflow follows the list below).
For images, packages such as Pillow and OpenCV are useful
For audio, packages such as scipy and librosa
For text, either raw Python- or Cython-based loading, or NLTK and SpaCy are useful
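As a minimal sketch of that workflow (the file name some_image.png is a hypothetical placeholder, not part of this tutorial), an image can be loaded with Pillow into a numpy array and then converted to a Tensor:
from PIL import Image
import numpy as np
import torch

img = Image.open('some_image.png')       # hypothetical image file on disk
arr = np.asarray(img)                    # H x W x C uint8 numpy array
tensor = torch.from_numpy(arr.copy())    # .copy() gives the tensor its own memory
print(tensor.shape, tensor.dtype)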
Specifically for vision, we have created a package called torchvision, which has data loaders for common datasets such as ImageNet, CIFAR10, MNIST, etc. and data transformers for images, viz., torchvision.datasets and torch.utils.data.DataLoader.
This provides a huge convenience and avoids writing boilerplate code.
For this tutorial, we will use the CIFAR10 dataset. It has the classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels.

[figure: cifar10]
Training an image classifier¶
We will do the following steps in order:
1. Load and normalize the CIFAR10 training and test datasets using torchvision
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data
1. Load and normalize CIFAR10¶
Using torchvision, it's extremely easy to load CIFAR10.
import torch
import torchvision
import torchvision.transforms as transforms
The output of torchvision datasets are PILImage images of range [0, 1]. We transform them to Tensors of normalized range [-1, 1]: with a mean and standard deviation of 0.5 per channel, Normalize computes (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1].
Note
If you are running on Windows and you get a BrokenPipeError, try setting num_workers of torch.utils.data.DataLoader() to 0.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  0%|          | 0.00/170M [00:00<?, ?B/s]
...
100%|##########| 170M/170M [00:14<00:00, 11.9MB/s]
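As an optional sanity check (an addition, not part of the original tutorial code), you can confirm that the loaded batches have the expected shape and really lie in [-1, 1]:
images, labels = next(iter(trainloader))
print(images.shape)                               # torch.Size([4, 3, 32, 32])
print(images.min().item(), images.max().item())   # approximately -1.0 and 1.0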
Let us show some of the training images, for fun.
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

plane truck deer frog
2. Define a Convolutional Neural Network¶
Copy the neural network from the Neural Networks section before and modify it to take 3-channel images (instead of 1-channel images as it was defined).
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
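As a quick check (an addition, not from the original tutorial), a dummy CIFAR10-sized input should produce one logit per class:
dummy_input = torch.randn(1, 3, 32, 32)   # batch of 1, 3 channels, 32x32 pixels
print(net(dummy_input).shape)             # torch.Size([1, 10])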
3. Define a Loss function and optimizer¶
Let's use a Classification Cross-Entropy loss and SGD with momentum.
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
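As a minimal illustration (not part of the tutorial's training code), nn.CrossEntropyLoss expects raw logits of shape (batch, num_classes) and integer class labels of shape (batch,):
dummy_outputs = torch.randn(4, 10)             # e.g. logits for a batch of 4
dummy_labels = torch.tensor([3, 8, 0, 5])      # arbitrary class indices
print(criterion(dummy_outputs, dummy_labels))  # a scalar loss tensor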
4. Train the network¶
This is when things start to get interesting. We simply have to loop over our data iterator, feed the inputs to the network, and optimize.
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
[1, 2000] loss: 2.175
[1, 4000] loss: 1.840
[1, 6000] loss: 1.680
[1, 8000] loss: 1.613
[1, 10000] loss: 1.530
[1, 12000] loss: 1.494
[2, 2000] loss: 1.403
[2, 4000] loss: 1.391
[2, 6000] loss: 1.348
[2, 8000] loss: 1.307
[2, 10000] loss: 1.296
[2, 12000] loss: 1.279
Finished Training
Let's quickly save our trained model:
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
See here for more details on saving PyTorch models.
5. Test the network on the test data¶
We have trained the network for 2 passes over the training dataset. But we need to check if the network has learnt anything at all.
We will check this by predicting the class label that the neural network outputs, and checking it against the ground-truth. If the prediction is correct, we add the sample to the list of correct predictions.
Okay, first step. Let us display an image from the test set to get familiar.
dataiter = iter(testloader)
images, labels = next(dataiter)
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

GroundTruth: cat ship ship plane
Next, let's load back in our saved model (note: saving and re-loading the model wasn't necessary here, we only did it to illustrate how to do so):
net = Net()
net.load_state_dict(torch.load(PATH, weights_only=True))
<All keys matched successfully>
Okay, now let us see what the neural network thinks these examples above are:
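outputs = net(images)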
The outputs are energies for the 10 classes. The higher the energy for a class, the more the network thinks that the image is of that particular class. So, let's get the index of the highest energy:
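_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(4)))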
Predicted: cat car ship plane
The results seem pretty good.
Let us look at how the network performs on the whole dataset.
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
Accuracy of the network on the 10000 test images: 54 %
That looks way better than chance, which is 10% accuracy (randomly picking a class out of 10 classes). Seems like the network learnt something.
Hmmm, what are the classes that performed well, and the classes that did not perform well:
# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Accuracy for class: plane is 62.4 %
Accuracy for class: car is 71.4 %
Accuracy for class: bird is 45.8 %
Accuracy for class: cat is 50.9 %
Accuracy for class: deer is 39.3 %
Accuracy for class: dog is 43.6 %
Accuracy for class: frog is 53.9 %
Accuracy for class: horse is 62.6 %
Accuracy for class: ship is 54.1 %
Accuracy for class: truck is 63.5 %
Okay, so what next?
How do we run these neural networks on the GPU?
Training on GPU¶
Just like how you transfer a Tensor onto the GPU, you transfer the neural net onto the GPU.
Let's first define our device as the first visible cuda device if we have CUDA available:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
cuda:0
The rest of this section assumes that device is a CUDA device.
Then these methods will recursively go over all modules and convert their parameters and buffers to CUDA tensors:
net.to(device)
Remember that you will have to send the inputs and targets at every step to the GPU too:
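inputs, labels = data[0].to(device), data[1].to(device)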
Why don't I notice MASSIVE speedup compared to CPU? Because your network is really small.
Exercise: Try increasing the width of your network (argument 2 of the first nn.Conv2d, and argument 1 of the second nn.Conv2d -- they need to be the same number), and see what kind of speedup you get. One possible widened variant is sketched below.
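A minimal sketch of one such widened variant (the channel count 32 is an arbitrary choice for illustration, not a value from this tutorial):
class WideNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Widened: conv1's output channels and conv2's input channels
        # are raised from 6 to 32 (they must be the same number).
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = WideNet().to(device)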
Goals achieved:
Understanding PyTorch's Tensor library and neural networks at a high level.
Train a small neural network to classify images