
Spatial Transformer Networks Tutorial

Created On: Nov 08, 2017 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024

Author: Ghassen HAMROUNI


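In this tutorial, you will learn how to augment your network using a visual attention mechanism called spatial transformer networks. You can read more about spatial transformer networks in the DeepMind paper (<https://arxiv.org/abs/1506.02025>).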

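Spatial transformer networks are a generalization of differentiable attention to any spatial transformation. A spatial transformer network (STN for short) lets a neural network learn how to perform spatial transformations on the input image in order to enhance the geometric invariance of the model. For example, it can crop a region of interest, scale an image, and correct its orientation. This is a useful mechanism because CNNs are not invariant to rotation and scale, nor to more general affine transformations.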
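For the affine case used in this tutorial, the paper writes the transformation as a 2×3 matrix A_θ that maps each target (output) pixel coordinate (x_i^t, y_i^t) to a source (input) sampling location (x_i^s, y_i^s), with both expressed in normalized coordinates where -1 and 1 are the image edges. The three matrices on the second line are illustrative instances chosen for this note, not values from the tutorial: with 0 < s < 1, A_zoom samples only a window of relative size s centred at (t_x, t_y), which crops and magnifies that region, and A_rot rotates the sampling grid by an angle α.

```latex
\[
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix}
  = A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
  = \begin{bmatrix}
      \theta_{11} & \theta_{12} & \theta_{13} \\
      \theta_{21} & \theta_{22} & \theta_{23}
    \end{bmatrix}
    \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
\]
\[
A_{\mathrm{id}} = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}, \qquad
A_{\mathrm{zoom}} = \begin{bmatrix} s & 0 & t_x \\ 0 & s & t_y \end{bmatrix}, \qquad
A_{\mathrm{rot}} = \begin{bmatrix} \cos\alpha & -\sin\alpha & 0 \\ \sin\alpha & \cos\alpha & 0 \end{bmatrix}
\]
```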

One of the best things about STNs is that they can simply be plugged into any existing CNN with very little modification.
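To illustrate the "plug into any existing CNN" point, here is a minimal sketch of a self-contained STN block placed in front of an arbitrary classifier. The class name `SpatialTransformer`, its layer sizes, and the toy `backbone` are hypothetical illustrations rather than this tutorial's model; the adaptive pooling is an assumption that lets the block accept any input size. Like the tutorial's model, it starts from the identity transform, so wrapping a backbone does not change its behaviour until training moves theta.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialTransformer(nn.Module):
    """Hypothetical drop-in STN block: predicts a 2x3 affine theta and warps x."""

    def __init__(self, in_channels: int):
        super().__init__()
        # Small localization network; the layer sizes are illustrative assumptions
        self.localization = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d(4),   # fixed 4x4 output for any input size
            nn.Flatten(),
            nn.Linear(8 * 4 * 4, 6),   # regresses the 6 affine parameters
        )
        # Start from the identity transform so the block initially changes nothing
        self.localization[-1].weight.data.zero_()
        self.localization[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        theta = self.localization(x).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)


# A toy CNN classifier standing in for "any existing CNN"
backbone = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
    nn.Flatten(), nn.Linear(4, 10))

# Plugging the STN in is just prepending it to the backbone
stn_model = nn.Sequential(SpatialTransformer(in_channels=1), backbone)

x = torch.randn(2, 1, 28, 28)
out = stn_model(x)
print(out.shape)  # torch.Size([2, 10])
```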

# License: BSD
# Author: Ghassen Hamrouni

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion()   # interactive mode

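Loading the data

In this post we experiment with the classic MNIST dataset, using a standard convolutional network augmented with a spatial transformer network.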

# Send a browser-like User-agent header; some download servers reject urllib's default one
import urllib.request
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)

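Depicting spatial transformer networks

Spatial transformer networks boil down to three main components:

  • The localization network is a regular CNN that regresses the transformation parameters. The transformation is never supervised explicitly (the dataset has no transformation labels); instead, the network automatically learns whatever spatial transformations improve overall accuracy.

  • The grid generator generates a grid of coordinates in the input image, one for each pixel of the output image.

  • The sampler samples the input image at the grid coordinates computed from the transformation parameters, producing the transformed output (F.grid_sample uses bilinear interpolation by default).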
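The grid generator described above can be sketched in a few lines of NumPy. The function name `affine_grid_np` is hypothetical; it assumes a single 2×3 `theta` and the pixel-centre convention that `F.affine_grid` uses with `align_corners=False`, and, like `F.affine_grid`, it stores the (x, y) source coordinates in the last axis.

```python
import numpy as np


def affine_grid_np(theta: np.ndarray, height: int, width: int) -> np.ndarray:
    """Hypothetical NumPy grid generator for one 2x3 affine matrix theta.

    Returns an array of shape (height, width, 2) holding, for every output
    pixel, the (x, y) source location in normalized [-1, 1] coordinates.
    Pixel centres are used (the align_corners=False convention).
    """
    # Normalized coordinates of the output pixel centres along each axis
    xs = (2 * np.arange(width) + 1) / width - 1
    ys = (2 * np.arange(height) + 1) / height - 1
    grid_x, grid_y = np.meshgrid(xs, ys)           # each of shape (height, width)
    ones = np.ones_like(grid_x)
    # Homogeneous target coordinates (x_t, y_t, 1), shape (height, width, 3)
    target = np.stack([grid_x, grid_y, ones], axis=-1)
    # Apply theta to every target coordinate: (H, W, 3) @ (3, 2) -> (H, W, 2)
    return target @ theta.T


identity = np.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])
zoom = np.array([[0.5, 0.0, 0.0],
                 [0.0, 0.5, 0.0]])

g_id = affine_grid_np(identity, 4, 4)
g_zoom = affine_grid_np(zoom, 4, 4)
# Identity corners sit at (-0.75, -0.75) and (0.75, 0.75) for a 4x4 grid;
# the zoom grid covers only the central half of the input
print(g_id[0, 0], g_id[-1, -1])
print(g_zoom[0, 0], g_zoom[-1, -1])
```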


Note

We need a version of PyTorch recent enough to include the affine_grid and grid_sample functions (in torch.nn.functional).
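A minimal sketch of the two functions on their own, assuming a random 28×28 single-channel tensor in place of a real MNIST image: `F.affine_grid` plays the grid generator and `F.grid_sample` the sampler. The identity `theta` is the same transform the model's `fc_loc` bias is initialized to, so it should reproduce the input (up to floating-point error); the 0.5-scale `theta` is an illustrative zoom onto the centre of the image.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 1, 28, 28)   # one fake 28x28 single-channel image, shape (N, C, H, W)

# Identity: the transform Net.fc_loc starts from (bias [1, 0, 0, 0, 1, 0])
theta_id = torch.tensor([[[1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]]])
# Scale 0.5: each output pixel samples from the central half of the input (zoom in)
theta_zoom = torch.tensor([[[0.5, 0.0, 0.0],
                            [0.0, 0.5, 0.0]]])

# Grid generator: (N, H, W, 2) normalized source coordinates
grid_id = F.affine_grid(theta_id, x.size(), align_corners=False)
grid_zoom = F.affine_grid(theta_zoom, x.size(), align_corners=False)

# Sampler: bilinear interpolation of x at those coordinates
out_id = F.grid_sample(x, grid_id, align_corners=False)
out_zoom = F.grid_sample(x, grid_zoom, align_corners=False)

print(tuple(grid_id.shape), tuple(out_id.shape))   # (1, 28, 28, 2) (1, 1, 28, 28)
print(torch.allclose(out_id, x, atol=1e-4))        # identity leaves the image unchanged
```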

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

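        # align_corners=False has been the default since PyTorch 1.3.0; passing it
        # explicitly keeps that behavior and silences the UserWarning about it
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)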

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)
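The hard-coded sizes `10 * 3 * 3` (input of `fc_loc`) and `320` (input of `fc1`) follow from how the convolution and pooling layers shrink a 28×28 MNIST image. The helper `conv_out` below is a hypothetical illustration of the standard output-size formula for `Conv2d`/`MaxPool2d` with dilation 1 and floor rounding (PyTorch's default).

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d / MaxPool2d layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# Localization network on a 28x28 input
s = conv_out(28, 7)        # Conv2d(1, 8, kernel_size=7)   -> 22
s = conv_out(s, 2, 2)      # MaxPool2d(2, stride=2)        -> 11
s = conv_out(s, 5)         # Conv2d(8, 10, kernel_size=5)  -> 7
s = conv_out(s, 2, 2)      # MaxPool2d(2, stride=2)        -> 3
loc_features = 10 * s * s  # 10 channels of 3x3 -> 90, i.e. 10 * 3 * 3

# Classification branch on the (same-size) transformed image
c = conv_out(28, 5)        # conv1, kernel_size=5            -> 24
c = conv_out(c, 2, 2)      # F.max_pool2d(..., 2)            -> 12
c = conv_out(c, 5)         # conv2, kernel_size=5            -> 8
c = conv_out(c, 2, 2)      # F.max_pool2d(..., 2)            -> 4
cls_features = 20 * c * c  # 20 channels of 4x4 -> 320

print(loc_features, cls_features)  # 90 320
```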

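Training the model

Now, let's use the SGD algorithm to train the model. The network learns the classification task in a supervised way; at the same time, the STN is learned automatically, end to end, from the same classification loss.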
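"End to end" means the classification loss alone produces gradients for the localization parameters, because `affine_grid` and `grid_sample` are differentiable with respect to `theta`. The sketch below checks this with toy stand-ins (a single `nn.Linear` as the localization head and another as the classifier, both assumptions for brevity), initialized to the identity like `Net.fc_loc`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 1, 28, 28)            # fake batch of 4 MNIST-sized images
target = torch.tensor([0, 1, 2, 3])     # fake labels

# Toy localization head, initialized to the identity transform like Net.fc_loc
fc_loc = nn.Linear(28 * 28, 6)
fc_loc.weight.data.zero_()
fc_loc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
# Toy stand-in for the CNN classifier
classifier = nn.Linear(28 * 28, 10)

theta = fc_loc(x.view(4, -1)).view(-1, 2, 3)
grid = F.affine_grid(theta, x.size(), align_corners=False)
warped = F.grid_sample(x, grid, align_corners=False)
log_probs = F.log_softmax(classifier(warped.view(4, -1)), dim=1)
loss = F.nll_loss(log_probs, target)
loss.backward()

# Only the classification loss was used, yet the STN parameters receive gradients
print(fc_loc.weight.grad.abs().sum().item() > 0,
      fc_loc.bias.grad.abs().sum().item() > 0)
```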

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
#
# A simple test procedure to measure the STN performances on MNIST.
#


def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss (reduction='sum' replaces the deprecated size_average=False)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))
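`test()` sums per-sample losses with `reduction='sum'` and divides by the dataset size, which gives the exact per-sample average even though the last batch is smaller (10000 test images = 156 batches of 64 plus one batch of 16). The sketch below checks that on random log-probabilities standing in for model outputs; the batch sizes mirror that split, and the data are fabricated for illustration only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Two fake batches of log-probabilities (sizes 64 and 16) with random labels
batches = []
for n in (64, 16):
    log_probs = F.log_softmax(torch.randn(n, 10), dim=1)
    labels = torch.randint(0, 10, (n,))
    batches.append((log_probs, labels))

# test() style: sum per-sample losses, then divide by the number of samples
total = sum(F.nll_loss(lp, y, reduction='sum').item() for lp, y in batches)
avg_from_sum = total / sum(len(y) for _, y in batches)

# Averaging per-batch means instead would over-weight the smaller last batch
mean_of_means = sum(F.nll_loss(lp, y).item() for lp, y in batches) / len(batches)

# Reference: mean over all samples at once
all_lp = torch.cat([lp for lp, _ in batches])
all_y = torch.cat([y for _, y in batches])
exact = F.nll_loss(all_lp, all_y).item()

print(avg_from_sum, mean_of_means, exact)
```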

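Visualizing the STN results

Now, we will inspect the results of our learned visual attention mechanism.

We define a small helper function that converts image tensors into displayable NumPy images, so that we can visualize the transformations.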

def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo transforms.Normalize((0.1307,), (0.3081,)) applied by the data loaders
    mean = np.array([0.1307])
    std = np.array([0.3081])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.


def visualize_stn():
    with torch.no_grad():
        # Get a batch of test data
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')

for epoch in range(1, 20 + 1):
    train(epoch)
    test()

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()
[Figure: "Dataset Images" (left) and "Transformed Images" (right)]
/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/functional.py:5082: UserWarning:

Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.

/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/functional.py:5015: UserWarning:

Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.

Train Epoch: 1 [0/60000 (0%)]   Loss: 2.315377
Train Epoch: 1 [32000/60000 (53%)]      Loss: 0.805421
/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/_reduction.py:51: UserWarning:

size_average and reduce args will be deprecated, please use reduction='sum' instead.


Test set: Average loss: 0.2426, Accuracy: 9317/10000 (93%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.580400
Train Epoch: 2 [32000/60000 (53%)]      Loss: 0.477957

Test set: Average loss: 0.1199, Accuracy: 9636/10000 (96%)

Train Epoch: 3 [0/60000 (0%)]   Loss: 0.240618
Train Epoch: 3 [32000/60000 (53%)]      Loss: 0.168760

Test set: Average loss: 0.0947, Accuracy: 9712/10000 (97%)

Train Epoch: 4 [0/60000 (0%)]   Loss: 0.110410
Train Epoch: 4 [32000/60000 (53%)]      Loss: 0.183183

Test set: Average loss: 0.0850, Accuracy: 9726/10000 (97%)

Train Epoch: 5 [0/60000 (0%)]   Loss: 0.463134
Train Epoch: 5 [32000/60000 (53%)]      Loss: 0.134187

Test set: Average loss: 0.0665, Accuracy: 9801/10000 (98%)

Train Epoch: 6 [0/60000 (0%)]   Loss: 0.275155
Train Epoch: 6 [32000/60000 (53%)]      Loss: 0.342379

Test set: Average loss: 0.0759, Accuracy: 9767/10000 (98%)

Train Epoch: 7 [0/60000 (0%)]   Loss: 0.135867
Train Epoch: 7 [32000/60000 (53%)]      Loss: 0.076504

Test set: Average loss: 0.0579, Accuracy: 9828/10000 (98%)

Train Epoch: 8 [0/60000 (0%)]   Loss: 0.032235
Train Epoch: 8 [32000/60000 (53%)]      Loss: 0.296299

Test set: Average loss: 0.0536, Accuracy: 9838/10000 (98%)

Train Epoch: 9 [0/60000 (0%)]   Loss: 0.161988
Train Epoch: 9 [32000/60000 (53%)]      Loss: 0.034477

Test set: Average loss: 0.0525, Accuracy: 9836/10000 (98%)

Train Epoch: 10 [0/60000 (0%)]  Loss: 0.078196
Train Epoch: 10 [32000/60000 (53%)]     Loss: 0.130281

Test set: Average loss: 0.0928, Accuracy: 9743/10000 (97%)

Train Epoch: 11 [0/60000 (0%)]  Loss: 0.353375
Train Epoch: 11 [32000/60000 (53%)]     Loss: 0.418211

Test set: Average loss: 0.0429, Accuracy: 9868/10000 (99%)

Train Epoch: 12 [0/60000 (0%)]  Loss: 0.036706
Train Epoch: 12 [32000/60000 (53%)]     Loss: 0.090855

Test set: Average loss: 0.0426, Accuracy: 9874/10000 (99%)

Train Epoch: 13 [0/60000 (0%)]  Loss: 0.242624
Train Epoch: 13 [32000/60000 (53%)]     Loss: 0.192313

Test set: Average loss: 0.0431, Accuracy: 9861/10000 (99%)

Train Epoch: 14 [0/60000 (0%)]  Loss: 0.114481
Train Epoch: 14 [32000/60000 (53%)]     Loss: 0.027952

Test set: Average loss: 0.0411, Accuracy: 9873/10000 (99%)

Train Epoch: 15 [0/60000 (0%)]  Loss: 0.014810
Train Epoch: 15 [32000/60000 (53%)]     Loss: 0.044349

Test set: Average loss: 0.0489, Accuracy: 9857/10000 (99%)

Train Epoch: 16 [0/60000 (0%)]  Loss: 0.017412
Train Epoch: 16 [32000/60000 (53%)]     Loss: 0.157169

Test set: Average loss: 0.0461, Accuracy: 9869/10000 (99%)

Train Epoch: 17 [0/60000 (0%)]  Loss: 0.082671
Train Epoch: 17 [32000/60000 (53%)]     Loss: 0.015023

Test set: Average loss: 0.0376, Accuracy: 9891/10000 (99%)

Train Epoch: 18 [0/60000 (0%)]  Loss: 0.115485
Train Epoch: 18 [32000/60000 (53%)]     Loss: 0.086408

Test set: Average loss: 0.0697, Accuracy: 9800/10000 (98%)

Train Epoch: 19 [0/60000 (0%)]  Loss: 0.092634
Train Epoch: 19 [32000/60000 (53%)]     Loss: 0.016713

Test set: Average loss: 0.0408, Accuracy: 9875/10000 (99%)

Train Epoch: 20 [0/60000 (0%)]  Loss: 0.033407
Train Epoch: 20 [32000/60000 (53%)]     Loss: 0.013921

Test set: Average loss: 0.0522, Accuracy: 9840/10000 (98%)

Total running time of the script: ( 3 minutes 12.415 seconds)
