
Introduction || Tensors || Autograd || Building Models || TensorBoard Support || Training Models || Model Understanding

Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on `YouTube <https://www.youtube.com/watch?v=jF43_wj_DCQ>`__.

Introduction

In past videos, we've discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we'll be adding some new tools to your inventory:

  • We'll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We'll discuss specific loss functions and when to use them

  • We'll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we'll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The ``Dataset`` and ``DataLoader`` classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The ``Dataset`` is responsible for accessing and processing single instances of data.

The ``DataLoader`` pulls instances of data from the ``Dataset`` (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The ``DataLoader`` works with all kinds of datasets, regardless of the type of data they contain.
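As a minimal sketch of the ``Dataset`` contract (not part of this tutorial's code; the class and variable names below are hypothetical), a map-style dataset only needs to implement ``__len__`` and ``__getitem__``, and any such object can then be wrapped in a ``DataLoader``:

import torch
from torch.utils.data import Dataset, DataLoader

class InMemoryPairs(Dataset):
    """Hypothetical map-style Dataset wrapping pre-loaded tensors."""
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        # Number of individual data instances
        return len(self.features)

    def __getitem__(self, idx):
        # Return a single (input, label) pair; the DataLoader handles batching
        return self.features[idx], self.labels[idx]

ds = InMemoryPairs(torch.rand(100, 3), torch.randint(0, 2, (100,)))
loader = DataLoader(ds, batch_size=4, shuffle=True)
inputs, labels = next(iter(loader))   # inputs: shape [4, 3], labels: shape [4]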

For this tutorial, we'll be using the Fashion-MNIST dataset provided by TorchVision. We use ``torchvision.transforms.Normalize()`` to zero-center and normalize the distribution of the image tile content, and we download both training and validation data splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let's visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
(image: a grid of four sample training images)
Coat  Sneaker  Sandal  Dress

The Model

The model we'll use in this example is a variant of LeNet-5 - it should be familiar if you've watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
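As an optional sanity check (not in the original tutorial), you can push a batch of dummy inputs through the untrained model and confirm that it produces one score per class for each input:

# Hypothetical shape check: a batch of 4 single-channel 28x28 images
dummy_batch = torch.rand(4, 1, 28, 28)
print(model(dummy_batch).shape)   # expected: torch.Size([4, 10])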

Loss Function

For this example, we'll be using a cross-entropy loss. For demonstration purposes, we'll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.1361, 0.2725, 0.0907, 0.0082, 0.4228, 0.7388, 0.4125, 0.5239, 0.7478,
         0.2745],
        [0.9673, 0.0197, 0.7789, 0.9434, 0.9163, 0.0952, 0.0651, 0.4475, 0.3295,
         0.4615],
        [0.0884, 0.4782, 0.7056, 0.8047, 0.7331, 0.6193, 0.0116, 0.1595, 0.9667,
         0.9834],
        [0.9127, 0.2350, 0.8619, 0.5442, 0.4277, 0.9194, 0.4962, 0.0656, 0.8193,
         0.5803]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.5408973693847656
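One detail worth noting: ``torch.nn.CrossEntropyLoss`` expects raw, unnormalized scores (logits) and applies log-softmax internally. As a short sketch of that equivalence (not part of the tutorial's code), the same value can be reproduced with ``log_softmax`` followed by the negative log-likelihood loss:

import torch.nn.functional as F

# Equivalent computation: log-softmax over the class dimension, then NLL loss
manual_loss = F.nll_loss(F.log_softmax(dummy_outputs, dim=1), dummy_labels)
print(manual_loss.item())   # should match loss.item() above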

Optimizer

For this example, we'll be using simple `stochastic gradient descent <https://pytorch.org/docs/stable/optim.html>`__ with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum nudges the optimizer in the direction of strongest gradient over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ? (A sketch for swapping these in appears after the code below.)

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
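If you want to experiment with the variations suggested above, swapping in a different optimizer is a one-line change. A sketch of a few alternatives follows (the learning rates are illustrative, not tuned, and only one optimizer would drive training at a time):

# Illustrative alternatives from torch.optim (hyperparameters not tuned)
optimizer_adam    = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer_adagrad = torch.optim.Adagrad(model.parameters(), lr=0.01)
optimizer_asgd    = torch.optim.ASGD(model.parameters(), lr=0.01)   # averaged SGD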

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer's gradients

  • Performs an inference - that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step - that is, adjust the model's learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • It reports on the loss for every 1000 batches.

  • Finally, it reports the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

Per-Epoch Activity

There are a couple of things we'll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this

  • Save a copy of the model

Here, we'll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard (e.g. ``tensorboard --logdir=runs``) and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.859723868340254
  batch 2000 loss: 0.7959158919476904
  batch 3000 loss: 0.6875963613968342
  batch 4000 loss: 0.5840691136796958
  batch 5000 loss: 0.601861707421951
  batch 6000 loss: 0.5326420218963176
  batch 7000 loss: 0.5101273439126089
  batch 8000 loss: 0.4920893395068124
  batch 9000 loss: 0.45734640963794665
  batch 10000 loss: 0.4638966364953085
  batch 11000 loss: 0.4171614876713138
  batch 12000 loss: 0.4280265563330031
  batch 13000 loss: 0.4263071584069985
  batch 14000 loss: 0.4181580100507708
  batch 15000 loss: 0.42254892712844594
LOSS train 0.42254892712844594 valid 0.4007205665111542
EPOCH 2:
  batch 1000 loss: 0.3819776755711646
  batch 2000 loss: 0.3976921703386761
  batch 3000 loss: 0.37814622786745894
  batch 4000 loss: 0.34922556551004524
  batch 5000 loss: 0.34570265819173074
  batch 6000 loss: 0.3456719932517153
  batch 7000 loss: 0.34924780977396586
  batch 8000 loss: 0.35666742020787207
  batch 9000 loss: 0.33993772263024585
  batch 10000 loss: 0.3685673026727891
  batch 11000 loss: 0.3419286552860576
  batch 12000 loss: 0.33543308569528746
  batch 13000 loss: 0.33815630553881054
  batch 14000 loss: 0.3307116681025509
  batch 15000 loss: 0.3606875884762267
LOSS train 0.3606875884762267 valid 0.3472764194011688
EPOCH 3:
  batch 1000 loss: 0.30639552796969655
  batch 2000 loss: 0.3228924395108479
  batch 3000 loss: 0.31235727290814974
  batch 4000 loss: 0.3145854706429818
  batch 5000 loss: 0.29222741585151746
  batch 6000 loss: 0.31345307654022325
  batch 7000 loss: 0.3163325209467439
  batch 8000 loss: 0.31647060124657583
  batch 9000 loss: 0.3141732101450179
  batch 10000 loss: 0.299438458041157
  batch 11000 loss: 0.30074762070312866
  batch 12000 loss: 0.29664640293602135
  batch 13000 loss: 0.3008534389541455
  batch 14000 loss: 0.3202067443535416
  batch 15000 loss: 0.3115060192462843
LOSS train 0.3115060192462843 valid 0.33710092306137085
EPOCH 4:
  batch 1000 loss: 0.27951306609585846
  batch 2000 loss: 0.29738874051375025
  batch 3000 loss: 0.28281411081092667
  batch 4000 loss: 0.28956373607418934
  batch 5000 loss: 0.28590615158868604
  batch 6000 loss: 0.2835932457768722
  batch 7000 loss: 0.27401943222882985
  batch 8000 loss: 0.2870593838140994
  batch 9000 loss: 0.27542450427323456
  batch 10000 loss: 0.28151207812947177
  batch 11000 loss: 0.27685102666457534
  batch 12000 loss: 0.2891393224728527
  batch 13000 loss: 0.2877112180689146
  batch 14000 loss: 0.281157614742333
  batch 15000 loss: 0.2708676661506033
LOSS train 0.2708676661506033 valid 0.30610325932502747
EPOCH 5:
  batch 1000 loss: 0.2748342432942154
  batch 2000 loss: 0.2569611764227302
  batch 3000 loss: 0.24103483629236688
  batch 4000 loss: 0.2569761824092675
  batch 5000 loss: 0.2605063391393851
  batch 6000 loss: 0.2804330715298693
  batch 7000 loss: 0.25689922451737585
  batch 8000 loss: 0.2736012614666906
  batch 9000 loss: 0.2698981114967537
  batch 10000 loss: 0.26709053535146543
  batch 11000 loss: 0.2632426246944069
  batch 12000 loss: 0.27027270019240224
  batch 13000 loss: 0.2651329760397275
  batch 14000 loss: 0.280449686764925
  batch 15000 loss: 0.2526728889870938
LOSS train 0.2526728889870938 valid 0.28390395641326904

To load a saved version of the model:

saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))  # PATH: file path of a saved checkpoint, e.g. model_path from the loop above

Once you've loaded the model, it's ready for whatever you need it for - more training, inference, or analysis.
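For example, here is a short sketch (not part of the original tutorial) of running inference on one validation batch with the reloaded model, reusing the ``validation_loader`` and ``classes`` defined earlier:

# Hypothetical inference pass with the reloaded model
saved_model.eval()                      # disable training-only behavior
with torch.no_grad():
    vinputs, vlabels = next(iter(validation_loader))
    preds = saved_model(vinputs).argmax(dim=1)
print('Predicted:', [classes[p] for p in preds])
print('Actual:   ', [classes[l] for l in vlabels])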

Do note that if your model has constructor parameters that affect model structure, you'll need to provide them and configure the model identically to the state in which it was saved.

Other Resources

**Total running time of the script:** (27 minutes 16.034 seconds)

Gallery generated by Sphinx-Gallery
