
DCGAN Tutorial

Created On: Jul 31, 2018 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024

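Author: Nathan Inkawhich

Introduction

This tutorial introduces DCGANs through an example. We will train a generative adversarial network (GAN) to generate new celebrity images after showing it pictures of many real celebrities. Most of the code here comes from the DCGAN implementation in pytorch/examples; this document explains that implementation in detail and sheds light on how and why the model works. No prior knowledge of GANs is required, although a first-timer may need to spend some time reasoning about what is actually happening under the hood. To save time, it also helps to have a GPU or two. Let's start from the beginning.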

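Generative Adversarial Networks

What is a GAN?

GANs are a framework for teaching a deep learning model to capture the distribution of the training data, so that new data can be generated from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They consist of two distinct models, a generator and a discriminator. The generator's job is to produce "fake" images that look like the training images. The discriminator's job is to look at an image and output whether it is a real training image or a fake image from the generator. During training, the generator constantly tries to outsmart the discriminator by generating better and better fakes, while the discriminator works to become a better detective and correctly classify real and fake images. The equilibrium of this game is reached when the generator produces fakes that look as if they came directly from the training data, and the discriminator is left to guess, always at 50% confidence, whether the generator's output is real or fake.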

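Now, let's define some notation used throughout the tutorial, starting with the discriminator. Let \(x\) be data representing an image. \(D(x)\) is the discriminator network, which outputs the (scalar) probability that \(x\) came from the training data rather than from the generator. Since we are dealing with images, the input to \(D(x)\) is an image of CHW size 3x64x64. Intuitively, \(D(x)\) should be high when \(x\) comes from the training data and low when \(x\) comes from the generator. \(D(x)\) can also be thought of as a traditional binary classifier.

For the generator's notation, let \(z\) be a latent space vector sampled from a standard normal distribution. \(G(z)\) represents the generator function, which maps the latent vector \(z\) to data space. The goal of \(G\) is to estimate the distribution that the training data comes from (\(p_{data}\)) so that it can generate fake samples from that estimated distribution (\(p_g\)).

So, \(D(G(z))\) is the (scalar) probability that the output of the generator \(G\) is a real image. As described in Goodfellow's paper, \(D\) and \(G\) play a minimax game in which \(D\) tries to maximize the probability that it correctly classifies reals and fakes (\(logD(x)\)), and \(G\) tries to minimize the probability that \(D\) predicts its outputs are fake (\(log(1-D(G(z)))\)). From the paper, the GAN loss function is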

\[\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z)))\big] \]

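In theory, the solution to this minimax game is where \(p_g = p_{data}\) and the discriminator guesses randomly whether its inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in practice models do not always train to this point.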
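
To make the equilibrium concrete, here is a minimal sketch that estimates \(V(D,G)\) from lists of discriminator outputs. The function name ``value_function`` and the toy output values are assumptions made for illustration; they are not part of the tutorial code.

```python
import math

def value_function(d_real, d_fake):
    """Sample estimate of V(D, G) from discriminator outputs.

    d_real holds D(x) for real samples, d_fake holds D(G(z)) for fakes.
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term

# A discriminator that cannot tell real from fake outputs 0.5 everywhere,
# which gives V = log(0.5) + log(0.5) = -log(4).
v_equilibrium = value_function([0.5] * 4, [0.5] * 4)
print(v_equilibrium, -math.log(4))

# A confident and correct discriminator pushes V up toward its maximum of 0.
v_confident = value_function([0.99] * 4, [0.01] * 4)
print(v_confident)
```

The first value is what \(D\) is left with once \(p_g = p_{data}\); the second shows why \(D\) prefers to separate the two batches whenever it can.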

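What is a DCGAN?

A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks (https://arxiv.org/pdf/1511.06434.pdf). The discriminator is made up of strided convolution layers (https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d), batch norm layers (https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm2d), and LeakyReLU activations (https://pytorch.org/docs/stable/nn.html#torch.nn.LeakyReLU). The input is a 3x64x64 image and the output is a scalar probability that the input comes from the real data distribution. The generator is made up of convolutional-transpose layers (https://pytorch.org/docs/stable/nn.html#torch.nn.ConvTranspose2d), batch norm layers, and ReLU activations (https://pytorch.org/docs/stable/nn.html#relu). The input is a latent vector \(z\) drawn from a standard normal distribution, and the output is a 3x64x64 RGB image. The strided convolutional-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image. In the paper, the authors also give tips on how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which are explained in the coming sections.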

#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results
Random Seed:  999

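Inputs

Let's define some inputs for the run:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section.

  • workers - the number of worker subprocesses used to load the data with the ``DataLoader``.

  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128.

  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here (https://github.com/pytorch/examples/issues/70) for more details.

  • nc - the number of color channels in the input images. For color images this is 3.

  • nz - the length of the latent vector.

  • ngf - relates to the depth of the feature maps carried through the generator.

  • ndf - sets the depth of the feature maps propagated through the discriminator.

  • num_epochs - the number of training epochs to run. Training for longer will probably lead to better results but will also take much longer.

  • lr - the learning rate for training. As described in the DCGAN paper, this number should be 0.0002.

  • beta1 - the beta1 hyperparameter for the Adam optimizers. As described in the paper, this number should be 0.5.

  • ngpu - the number of GPUs available. If this is 0, the code will run in CPU mode. If this number is greater than 0, it will run on that number of GPUs.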

# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

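Data

In this tutorial we will use the Celeb-A Faces dataset (http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), which can be downloaded from the linked site or from Google Drive (https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg). The dataset downloads as a file named ``img_align_celeba.zip``. Once downloaded, create a directory named ``celeba`` and extract the zip file into that directory. Then set the ``dataroot`` input for this notebook to the ``celeba`` directory you just created. The resulting directory structure should be: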

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

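This is an important step because we will be using the ``ImageFolder`` dataset class, which requires there to be subdirectories in the dataset root folder. Now we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.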

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
[Figure: Training Images]
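
The ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` transform above is what puts the training images in the \([-1,1]\) range. A minimal sketch of the per-channel arithmetic follows, assuming pixel values have already been scaled to [0, 1] by ``ToTensor``. The helper names ``normalize`` and ``denormalize`` are hypothetical and do not come from torchvision.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel arithmetic of Normalize: (pixel - mean) / std."""
    return (pixel - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Hypothetical inverse helper that maps [-1, 1] back to [0, 1]."""
    return value * std + mean

# The ends and midpoint of the [0, 1] range map to -1, 0 and 1.
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```

With mean 0.5 and standard deviation 0.5, the interval [0, 1] maps exactly onto [-1, 1].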

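Implementation

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization

In the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. The ``weights_init`` function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. Note that in the code below the batch normalization scale weights are drawn around a mean of 1.0 rather than 0, and the batch normalization biases are set to 0. The function is applied to the models immediately after initialization.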

# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
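
``weights_init`` decides what to do purely from the class name of each submodule. The sketch below replays that string matching on the layer names used in this tutorial, without touching any tensors. The helper ``init_rule`` and its return strings are assumptions made for illustration.

```python
def init_rule(classname):
    """Mirror the branch structure of weights_init for a given class name."""
    if classname.find('Conv') != -1:
        # Matches both Conv2d and ConvTranspose2d.
        return "weight ~ normal(mean=0.0, std=0.02)"
    elif classname.find('BatchNorm') != -1:
        return "weight ~ normal(mean=1.0, std=0.02), bias = 0"
    # Activations and container modules fall through untouched.
    return "left unchanged"

layer_names = ["ConvTranspose2d", "Conv2d", "BatchNorm2d", "ReLU",
               "LeakyReLU", "Tanh", "Sigmoid", "Sequential", "Generator"]
for name in layer_names:
    print(f"{name:16s} {init_rule(name)}")
```

Because the match is a substring test, any custom module whose class name happens to contain ``Conv`` or ``BatchNorm`` would also be reinitialized, which is worth keeping in mind when extending the model.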

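Generator

The generator, \(G\), is designed to map the latent space vector \(z\) to data space. Since our data are images, converting \(z\) to data space means ultimately creating an RGB image with the same size as the training images (i.e. 3x64x64). In practice, this is accomplished through a series of strided two-dimensional convolutional-transpose layers, each paired with a 2D batch norm layer and a ReLU activation. The output of the generator is fed through a tanh function to return it to the input data range of \([-1,1]\). It is worth noting the batch norm functions after the convolutional-transpose layers, as this is a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training. An image of the generator from the DCGAN paper is shown below.

[Figure: dcgan_generator (generator architecture from the DCGAN paper)]

Notice how the inputs we set in the Inputs section (``nz``, ``ngf``, and ``nc``) influence the generator architecture in code. ``nz`` is the length of the z input vector, ``ngf`` relates to the size of the feature maps propagated through the generator, and ``nc`` is the number of channels in the output image (3 for RGB images). Below is the code for the generator.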

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)
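
The ``state size`` comments in the generator can be checked by hand. The sketch below assumes the standard output-size relation for a transposed convolution with no dilation and no output padding, out = (in - 1) * stride - 2 * padding + kernel, and applies it to the five ``ConvTranspose2d`` layers. The helper name ``conv_transpose_out`` is hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for each ConvTranspose2d layer in the generator.
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector enters as an nz x 1 x 1 volume
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)
```

The first layer expands 1x1 to 4x4, and each later layer with kernel 4, stride 2, and padding 1 exactly doubles the spatial size until it reaches 64x64.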

Now we can instantiate the generator and apply the ``weights_init`` function. Check out the printed model to see how the generator object is structured.

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

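Discriminator

As mentioned, the discriminator, \(D\), is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake). Here, \(D\) takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary for the problem, but the use of strided convolution, BatchNorm, and LeakyReLU is significant. The DCGAN paper mentions that it is good practice to use strided convolution rather than pooling to downsample, because it lets the network learn its own pooling function. The batch norm and LeakyReLU functions also promote healthy gradient flow, which is critical for the learning process of both \(G\) and \(D\).

# Discriminator Code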

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
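
To see how the strided convolutions do the downsampling that pooling would otherwise do, the sketch below traces the spatial size through the five ``Conv2d`` layers. It assumes the standard relation out = floor((in + 2 * padding - kernel) / stride) + 1 for a convolution without dilation. The helper name ``conv_out`` is hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for each Conv2d layer in the discriminator.
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are nc x 64 x 64
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)
```

Each stride-2 layer halves the resolution, and the final 4x4 convolution collapses the remaining 4x4 map to a single value per image, which the Sigmoid turns into a probability.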

Now, as with the generator, we can create the discriminator, apply the ``weights_init`` function, and print the model's structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

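Loss Functions and Optimizers

With \(D\) and \(G\) set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (``BCELoss``) function, which is defined in PyTorch as:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] \]

Notice how this function provides the calculation of both log components in the objective function (i.e. \(log(D(x))\) and \(log(1-D(G(z)))\)). We can specify which part of the BCE equation to use with the \(y\) input. This is done in the training loop coming up soon, but it is important to understand how we can choose which component we wish to calculate just by changing \(y\) (i.e. the ground-truth (GT) labels).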
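
The label-selects-the-term behavior can be checked with a few lines of plain Python. This is a sketch of the per-element formula only: the function name ``bce`` and the sample output 0.8 are assumptions for illustration, and ``nn.BCELoss()`` with its default settings additionally averages the per-element values over the batch.

```python
import math

def bce(x, y):
    """Per-element binary cross entropy: -[y*log(x) + (1-y)*log(1-x)]."""
    return -(y * math.log(x) + (1.0 - y) * math.log(1.0 - x))

d_out = 0.8  # an example discriminator output

# With label y = 1 only the -log(x) term survives.
loss_real = bce(d_out, 1.0)
# With label y = 0 only the -log(1 - x) term survives.
loss_fake = bce(d_out, 0.0)

print(loss_real, -math.log(d_out))
print(loss_fake, -math.log(1.0 - d_out))
```

Minimizing ``loss_real`` therefore maximizes \(log(x)\), and minimizing ``loss_fake`` maximizes \(log(1-x)\), which is how a single loss function serves both halves of the objective.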

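Next, we define our real label as 1 and our fake label as 0. These labels are used when calculating the losses of \(D\) and \(G\), and this is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for \(D\) and one for \(G\). As specified in the DCGAN paper, both are Adam optimizers with a learning rate of 0.0002 and Beta1 = 0.5. To keep track of the generator's learning progression, we will generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically feed this fixed_noise into \(G\), and over the iterations we will see images form out of the noise.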

# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

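Training

Finally, now that we have all parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings can lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow's paper (https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf), while abiding by some of the best practices shown in ganhacks (https://github.com/soumith/ganhacks). Namely, we will "construct different mini-batches for real and fake" images, and also adjust G's objective function to maximize \(log(D(G(z)))\). Training is split into two main parts: Part 1 updates the discriminator and Part 2 updates the generator.

Part 1 - Train the Discriminator

Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In Goodfellow's terms, we wish to "update the discriminator by ascending its stochastic gradient". Practically, we want to maximize \(log(D(x)) + log(1-D(G(z)))\). Following the separate mini-batch suggestion from ganhacks (https://github.com/soumith/ganhacks), we calculate this in two steps. First, we construct a batch of real samples from the training set, forward pass it through \(D\), calculate the loss (\(log(D(x))\)), then calculate the gradients in a backward pass. Second, we construct a batch of fake samples with the current generator, forward pass this batch through \(D\), calculate the loss (\(log(1-D(G(z)))\)), and *accumulate* the gradients with a backward pass. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the discriminator's optimizer.

Part 2 - Train the Generator

As stated in the original paper, we want to train the generator by minimizing \(log(1-D(G(z)))\) in an effort to generate better fakes. As mentioned, Goodfellow showed that this does not provide sufficient gradients, especially early in the learning process. As a fix, we instead wish to maximize \(log(D(G(z)))\). In the code we accomplish this by classifying the generator output from Part 1 with the discriminator, computing G's loss using real labels as GT, computing G's gradients in a backward pass, and finally updating G's parameters with an optimizer step. It may seem counter-intuitive to use the real labels as GT labels for the loss function, but this allows us to use the \(log(x)\) part of the ``BCELoss`` (rather than the \(log(1-x)\) part), which is exactly what we want.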
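
The gradient argument for the modified generator loss can be illustrated numerically. The sketch below assumes the discriminator output is a sigmoid of a pre-activation value, called ``a`` here (a name introduced only for this example), and compares the derivative of each objective with respect to ``a``.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the original generator objective."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), the modified generator loss."""
    return sigmoid(a) - 1.0

# a = -5 models early training, where D confidently rejects the fakes.
for a in (-5.0, 0.0, 5.0):
    print(f"a={a:+.1f}  D(G(z))={sigmoid(a):.4f}  "
          f"original={grad_saturating(a):+.4f}  "
          f"modified={grad_non_saturating(a):+.4f}")
```

When \(D(G(z))\) is close to 0, the original objective yields a gradient close to 0, while the modified loss yields a gradient with magnitude close to 1, so the generator still receives a useful learning signal.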

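Finally, we will do some statistic reporting, and periodically (every 500 iterations in the code below, plus the final iteration) we will push our fixed_noise batch through the generator to visually track the progress of G's training. The training statistics reported are:

  • Loss_D - discriminator loss, calculated as the sum of the losses for the all-real and all-fake batches (\(log(D(x)) + log(1 - D(G(z)))\)). The printed value is computed with ``BCELoss``, so it is the negative of this quantity.

  • Loss_G - generator loss, calculated as \(log(D(G(z)))\). As with Loss_D, the printed ``BCELoss`` value is the negative of this log term.

  • D(x) - the average output (across the batch) of the discriminator for the all-real batch. This should start close to 1 and then theoretically converge to 0.5 as G gets better. Think about why this is.

  • D(G(z)) - the average discriminator output for the all-fake batch. The first number is before D is updated and the second number is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better. Think about why this is.

Note: This step might take a while, depending on how many epochs you run and whether you removed some data from the dataset.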
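
For a rough sense of how much work the loop does, the arithmetic below counts iterations, printed log lines, and saved image grids. The dataset size of 202,599 images is an assumption (the published size of the aligned CelebA set, not stated in this tutorial); the remaining numbers come from the inputs and the loop conditions used here.

```python
import math

dataset_size = 202_599  # assumption: full aligned CelebA set
batch_size = 128
num_epochs = 5

# The DataLoader keeps the last partial batch by default, hence the ceiling.
batches_per_epoch = math.ceil(dataset_size / batch_size)
total_iters = num_epochs * batches_per_epoch

# Stats are printed whenever i % 50 == 0 within an epoch.
log_lines_per_epoch = len(range(0, batches_per_epoch, 50))

# A grid is saved every 500 iterations and once more on the final iteration.
snapshots = sum(
    1 for it in range(total_iters)
    if it % 500 == 0 or it == total_iters - 1
)

print(batches_per_epoch, total_iters, log_lines_per_epoch, snapshots)
```

Under that assumption the batch count per epoch is 1583, which agrees with the ``/1583`` shown in the training log.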

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.9253  Loss_G: 4.2295  D(x): 0.4681    D(G(z)): 0.5867 / 0.0234
[0/5][50/1583]  Loss_D: 0.1787  Loss_G: 19.3262 D(x): 0.9021    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.8667  Loss_G: 11.6657 D(x): 0.8051    D(G(z)): 0.0001 / 0.0008
[0/5][150/1583] Loss_D: 1.3525  Loss_G: 12.1388 D(x): 0.9505    D(G(z)): 0.6232 / 0.0000
[0/5][200/1583] Loss_D: 0.2267  Loss_G: 3.0958  D(x): 0.9012    D(G(z)): 0.0884 / 0.0801
[0/5][250/1583] Loss_D: 0.5070  Loss_G: 5.4072  D(x): 0.8867    D(G(z)): 0.2253 / 0.0176
[0/5][300/1583] Loss_D: 1.6674  Loss_G: 3.8464  D(x): 0.3278    D(G(z)): 0.0075 / 0.0428
[0/5][350/1583] Loss_D: 0.4727  Loss_G: 3.7289  D(x): 0.7323    D(G(z)): 0.0545 / 0.0510
[0/5][400/1583] Loss_D: 0.6952  Loss_G: 2.7046  D(x): 0.6052    D(G(z)): 0.0549 / 0.0986
[0/5][450/1583] Loss_D: 0.5542  Loss_G: 3.0846  D(x): 0.7670    D(G(z)): 0.1568 / 0.0664
[0/5][500/1583] Loss_D: 1.0535  Loss_G: 2.1153  D(x): 0.4753    D(G(z)): 0.0334 / 0.1758
[0/5][550/1583] Loss_D: 0.4631  Loss_G: 3.7431  D(x): 0.7778    D(G(z)): 0.1181 / 0.0379
[0/5][600/1583] Loss_D: 0.8609  Loss_G: 3.8593  D(x): 0.5462    D(G(z)): 0.0152 / 0.0599
[0/5][650/1583] Loss_D: 0.4497  Loss_G: 3.3994  D(x): 0.8336    D(G(z)): 0.1825 / 0.0557
[0/5][700/1583] Loss_D: 1.1335  Loss_G: 6.6631  D(x): 0.9307    D(G(z)): 0.5644 / 0.0042
[0/5][750/1583] Loss_D: 0.9320  Loss_G: 2.4560  D(x): 0.5228    D(G(z)): 0.0341 / 0.1436
[0/5][800/1583] Loss_D: 0.3762  Loss_G: 3.7155  D(x): 0.7771    D(G(z)): 0.0367 / 0.0394
[0/5][850/1583] Loss_D: 0.4964  Loss_G: 3.8047  D(x): 0.8701    D(G(z)): 0.2378 / 0.0425
[0/5][900/1583] Loss_D: 0.8660  Loss_G: 2.5298  D(x): 0.5728    D(G(z)): 0.0613 / 0.1060
[0/5][950/1583] Loss_D: 0.6210  Loss_G: 5.1183  D(x): 0.8608    D(G(z)): 0.3201 / 0.0113
[0/5][1000/1583]        Loss_D: 0.4514  Loss_G: 3.0069  D(x): 0.8198    D(G(z)): 0.1793 / 0.0816
[0/5][1050/1583]        Loss_D: 0.9952  Loss_G: 2.4624  D(x): 0.5886    D(G(z)): 0.2386 / 0.1327
[0/5][1100/1583]        Loss_D: 0.3753  Loss_G: 3.4336  D(x): 0.8538    D(G(z)): 0.1545 / 0.0491
[0/5][1150/1583]        Loss_D: 1.2277  Loss_G: 8.3980  D(x): 0.9757    D(G(z)): 0.6354 / 0.0004
[0/5][1200/1583]        Loss_D: 0.4791  Loss_G: 4.4621  D(x): 0.9104    D(G(z)): 0.2816 / 0.0174
[0/5][1250/1583]        Loss_D: 0.9103  Loss_G: 2.8637  D(x): 0.6830    D(G(z)): 0.3030 / 0.0814
[0/5][1300/1583]        Loss_D: 0.4706  Loss_G: 3.4780  D(x): 0.7263    D(G(z)): 0.0537 / 0.0564
[0/5][1350/1583]        Loss_D: 0.6167  Loss_G: 2.9633  D(x): 0.7014    D(G(z)): 0.1246 / 0.0843
[0/5][1400/1583]        Loss_D: 0.4402  Loss_G: 3.6633  D(x): 0.8045    D(G(z)): 0.1519 / 0.0405
[0/5][1450/1583]        Loss_D: 0.6162  Loss_G: 3.2270  D(x): 0.7262    D(G(z)): 0.1551 / 0.0698
[0/5][1500/1583]        Loss_D: 0.6604  Loss_G: 5.0311  D(x): 0.8329    D(G(z)): 0.3040 / 0.0142
[0/5][1550/1583]        Loss_D: 0.6295  Loss_G: 2.5956  D(x): 0.7470    D(G(z)): 0.1827 / 0.1079
[1/5][0/1583]   Loss_D: 0.4805  Loss_G: 3.0063  D(x): 0.7542    D(G(z)): 0.0939 / 0.0826
[1/5][50/1583]  Loss_D: 0.5213  Loss_G: 3.1918  D(x): 0.7881    D(G(z)): 0.1799 / 0.0701
[1/5][100/1583] Loss_D: 0.8115  Loss_G: 6.1316  D(x): 0.8694    D(G(z)): 0.4134 / 0.0039
[1/5][150/1583] Loss_D: 0.5532  Loss_G: 4.4881  D(x): 0.8302    D(G(z)): 0.2354 / 0.0201
[1/5][200/1583] Loss_D: 0.3384  Loss_G: 4.1441  D(x): 0.9209    D(G(z)): 0.2018 / 0.0273
[1/5][250/1583] Loss_D: 0.4416  Loss_G: 3.1567  D(x): 0.7591    D(G(z)): 0.0979 / 0.0693
[1/5][300/1583] Loss_D: 0.6491  Loss_G: 4.7122  D(x): 0.9065    D(G(z)): 0.3628 / 0.0185
[1/5][350/1583] Loss_D: 0.4252  Loss_G: 3.0034  D(x): 0.8016    D(G(z)): 0.1348 / 0.0796
[1/5][400/1583] Loss_D: 0.5872  Loss_G: 4.5848  D(x): 0.9056    D(G(z)): 0.3416 / 0.0181
[1/5][450/1583] Loss_D: 0.5208  Loss_G: 3.1924  D(x): 0.6910    D(G(z)): 0.0759 / 0.0614
[1/5][500/1583] Loss_D: 0.6373  Loss_G: 2.3228  D(x): 0.6159    D(G(z)): 0.0194 / 0.1504
[1/5][550/1583] Loss_D: 0.6092  Loss_G: 2.8430  D(x): 0.7126    D(G(z)): 0.1594 / 0.0798
[1/5][600/1583] Loss_D: 2.6392  Loss_G: 1.6702  D(x): 0.1509    D(G(z)): 0.0036 / 0.2652
[1/5][650/1583] Loss_D: 0.6055  Loss_G: 4.3314  D(x): 0.9293    D(G(z)): 0.3700 / 0.0255
[1/5][700/1583] Loss_D: 0.6743  Loss_G: 4.4243  D(x): 0.9419    D(G(z)): 0.4105 / 0.0184
[1/5][750/1583] Loss_D: 0.6271  Loss_G: 2.9614  D(x): 0.7595    D(G(z)): 0.2399 / 0.0767
[1/5][800/1583] Loss_D: 0.4085  Loss_G: 3.9207  D(x): 0.8878    D(G(z)): 0.2179 / 0.0311
[1/5][850/1583] Loss_D: 0.4656  Loss_G: 2.7812  D(x): 0.7499    D(G(z)): 0.1113 / 0.0902
[1/5][900/1583] Loss_D: 0.6448  Loss_G: 4.0957  D(x): 0.8971    D(G(z)): 0.3655 / 0.0274
[1/5][950/1583] Loss_D: 0.5962  Loss_G: 4.6224  D(x): 0.9161    D(G(z)): 0.3504 / 0.0151
[1/5][1000/1583]        Loss_D: 0.4554  Loss_G: 3.5027  D(x): 0.8349    D(G(z)): 0.2024 / 0.0460
[1/5][1050/1583]        Loss_D: 0.3777  Loss_G: 3.6027  D(x): 0.8371    D(G(z)): 0.1483 / 0.0401
[1/5][1100/1583]        Loss_D: 0.8256  Loss_G: 4.3474  D(x): 0.9606    D(G(z)): 0.4579 / 0.0239
[1/5][1150/1583]        Loss_D: 0.6338  Loss_G: 1.8006  D(x): 0.6571    D(G(z)): 0.1016 / 0.2144
[1/5][1200/1583]        Loss_D: 0.4544  Loss_G: 3.9609  D(x): 0.8648    D(G(z)): 0.2375 / 0.0275
[1/5][1250/1583]        Loss_D: 0.4300  Loss_G: 3.2581  D(x): 0.8453    D(G(z)): 0.1992 / 0.0594
[1/5][1300/1583]        Loss_D: 0.3428  Loss_G: 2.8011  D(x): 0.9327    D(G(z)): 0.2062 / 0.0917
[1/5][1350/1583]        Loss_D: 0.6456  Loss_G: 1.5907  D(x): 0.6620    D(G(z)): 0.1341 / 0.2518
[1/5][1400/1583]        Loss_D: 1.0552  Loss_G: 5.5392  D(x): 0.9393    D(G(z)): 0.5810 / 0.0081
[1/5][1450/1583]        Loss_D: 0.5158  Loss_G: 3.9685  D(x): 0.9226    D(G(z)): 0.3171 / 0.0272
[1/5][1500/1583]        Loss_D: 0.5365  Loss_G: 3.8893  D(x): 0.9286    D(G(z)): 0.3350 / 0.0297
[1/5][1550/1583]        Loss_D: 1.7469  Loss_G: 7.0958  D(x): 0.9607    D(G(z)): 0.7453 / 0.0017
[2/5][0/1583]   Loss_D: 0.4801  Loss_G: 2.5083  D(x): 0.7563    D(G(z)): 0.1414 / 0.1120
[2/5][50/1583]  Loss_D: 0.8642  Loss_G: 4.0698  D(x): 0.8873    D(G(z)): 0.4655 / 0.0247
[2/5][100/1583] Loss_D: 0.5755  Loss_G: 3.8060  D(x): 0.9221    D(G(z)): 0.3580 / 0.0294
[2/5][150/1583] Loss_D: 0.5431  Loss_G: 2.7516  D(x): 0.7336    D(G(z)): 0.1651 / 0.0892
[2/5][200/1583] Loss_D: 0.5343  Loss_G: 3.0836  D(x): 0.8583    D(G(z)): 0.2747 / 0.0657
[2/5][250/1583] Loss_D: 0.4806  Loss_G: 2.7586  D(x): 0.8156    D(G(z)): 0.2104 / 0.0845
[2/5][300/1583] Loss_D: 1.3261  Loss_G: 0.8489  D(x): 0.3586    D(G(z)): 0.0284 / 0.4896
[2/5][350/1583] Loss_D: 0.5982  Loss_G: 3.3485  D(x): 0.8648    D(G(z)): 0.3249 / 0.0514
[2/5][400/1583] Loss_D: 0.6146  Loss_G: 3.5353  D(x): 0.9260    D(G(z)): 0.3638 / 0.0412
[2/5][450/1583] Loss_D: 0.6543  Loss_G: 2.2284  D(x): 0.6189    D(G(z)): 0.0859 / 0.1610
[2/5][500/1583] Loss_D: 0.4549  Loss_G: 2.8017  D(x): 0.7619    D(G(z)): 0.1257 / 0.0871
[2/5][550/1583] Loss_D: 0.5540  Loss_G: 1.4729  D(x): 0.6413    D(G(z)): 0.0471 / 0.2910
[2/5][600/1583] Loss_D: 2.1852  Loss_G: 5.1836  D(x): 0.9680    D(G(z)): 0.8271 / 0.0115
[2/5][650/1583] Loss_D: 0.6494  Loss_G: 2.2610  D(x): 0.7654    D(G(z)): 0.2731 / 0.1346
[2/5][700/1583] Loss_D: 0.8246  Loss_G: 1.7544  D(x): 0.5155    D(G(z)): 0.0515 / 0.2318
[2/5][750/1583] Loss_D: 0.5312  Loss_G: 1.7904  D(x): 0.7456    D(G(z)): 0.1756 / 0.1989
[2/5][800/1583] Loss_D: 0.6807  Loss_G: 3.6964  D(x): 0.8213    D(G(z)): 0.3498 / 0.0347
[2/5][850/1583] Loss_D: 0.5764  Loss_G: 3.3782  D(x): 0.8822    D(G(z)): 0.3260 / 0.0465
[2/5][900/1583] Loss_D: 0.5902  Loss_G: 1.6969  D(x): 0.7025    D(G(z)): 0.1623 / 0.2254
[2/5][950/1583] Loss_D: 0.7378  Loss_G: 2.8792  D(x): 0.7970    D(G(z)): 0.3450 / 0.0788
[2/5][1000/1583]        Loss_D: 0.9063  Loss_G: 0.9848  D(x): 0.5777    D(G(z)): 0.2088 / 0.4277
[2/5][1050/1583]        Loss_D: 1.9781  Loss_G: 0.4740  D(x): 0.1853    D(G(z)): 0.0076 / 0.6719
[2/5][1100/1583]        Loss_D: 0.5326  Loss_G: 3.6264  D(x): 0.9161    D(G(z)): 0.3214 / 0.0376
[2/5][1150/1583]        Loss_D: 0.6537  Loss_G: 2.6539  D(x): 0.8052    D(G(z)): 0.3127 / 0.0880
[2/5][1200/1583]        Loss_D: 0.4548  Loss_G: 2.6971  D(x): 0.9091    D(G(z)): 0.2740 / 0.0922
[2/5][1250/1583]        Loss_D: 0.8103  Loss_G: 1.0119  D(x): 0.5446    D(G(z)): 0.0925 / 0.4174
[2/5][1300/1583]        Loss_D: 0.4992  Loss_G: 2.3328  D(x): 0.7667    D(G(z)): 0.1788 / 0.1199
[2/5][1350/1583]        Loss_D: 0.5945  Loss_G: 1.8714  D(x): 0.7370    D(G(z)): 0.2142 / 0.2019
[2/5][1400/1583]        Loss_D: 0.5062  Loss_G: 2.9554  D(x): 0.8657    D(G(z)): 0.2759 / 0.0672
[2/5][1450/1583]        Loss_D: 0.5050  Loss_G: 2.6050  D(x): 0.7379    D(G(z)): 0.1417 / 0.0925
[2/5][1500/1583]        Loss_D: 0.4741  Loss_G: 2.5782  D(x): 0.8164    D(G(z)): 0.2093 / 0.0978
[2/5][1550/1583]        Loss_D: 2.4340  Loss_G: 0.5105  D(x): 0.1405    D(G(z)): 0.0178 / 0.6642
[3/5][0/1583]   Loss_D: 0.5847  Loss_G: 1.8185  D(x): 0.6761    D(G(z)): 0.1390 / 0.2010
[3/5][50/1583]  Loss_D: 0.6756  Loss_G: 1.3954  D(x): 0.6495    D(G(z)): 0.1602 / 0.2877
[3/5][100/1583] Loss_D: 0.9389  Loss_G: 3.5586  D(x): 0.9247    D(G(z)): 0.5097 / 0.0413
[3/5][150/1583] Loss_D: 0.8383  Loss_G: 4.1223  D(x): 0.9423    D(G(z)): 0.4889 / 0.0257
[3/5][200/1583] Loss_D: 0.7028  Loss_G: 1.1357  D(x): 0.5806    D(G(z)): 0.0670 / 0.3769
[3/5][250/1583] Loss_D: 0.8205  Loss_G: 1.5882  D(x): 0.6002    D(G(z)): 0.1722 / 0.2446
[3/5][300/1583] Loss_D: 0.5772  Loss_G: 1.6588  D(x): 0.7126    D(G(z)): 0.1792 / 0.2325
[3/5][350/1583] Loss_D: 0.9131  Loss_G: 2.5469  D(x): 0.7485    D(G(z)): 0.4025 / 0.1120
[3/5][400/1583] Loss_D: 0.7285  Loss_G: 1.4736  D(x): 0.5603    D(G(z)): 0.0560 / 0.2728
[3/5][450/1583] Loss_D: 0.8201  Loss_G: 4.7116  D(x): 0.9017    D(G(z)): 0.4620 / 0.0132
[3/5][500/1583] Loss_D: 0.6197  Loss_G: 2.9266  D(x): 0.8219    D(G(z)): 0.3099 / 0.0695
[3/5][550/1583] Loss_D: 0.5623  Loss_G: 1.9462  D(x): 0.7680    D(G(z)): 0.2197 / 0.1803
[3/5][600/1583] Loss_D: 0.9292  Loss_G: 1.1519  D(x): 0.4686    D(G(z)): 0.0416 / 0.3883
[3/5][650/1583] Loss_D: 0.5886  Loss_G: 2.6010  D(x): 0.7573    D(G(z)): 0.2352 / 0.0955
[3/5][700/1583] Loss_D: 0.4422  Loss_G: 2.2618  D(x): 0.8097    D(G(z)): 0.1775 / 0.1356
[3/5][750/1583] Loss_D: 0.6118  Loss_G: 2.8917  D(x): 0.8792    D(G(z)): 0.3475 / 0.0695
[3/5][800/1583] Loss_D: 0.5473  Loss_G: 1.9801  D(x): 0.7403    D(G(z)): 0.1811 / 0.1653
[3/5][850/1583] Loss_D: 0.6400  Loss_G: 2.7699  D(x): 0.8352    D(G(z)): 0.3345 / 0.0804
[3/5][900/1583] Loss_D: 0.4683  Loss_G: 2.7304  D(x): 0.8466    D(G(z)): 0.2339 / 0.0828
[3/5][950/1583] Loss_D: 1.0093  Loss_G: 5.9043  D(x): 0.9404    D(G(z)): 0.5629 / 0.0047
[3/5][1000/1583]        Loss_D: 0.5349  Loss_G: 2.1615  D(x): 0.7366    D(G(z)): 0.1640 / 0.1430
[3/5][1050/1583]        Loss_D: 1.2765  Loss_G: 0.6246  D(x): 0.3708    D(G(z)): 0.0978 / 0.5693
[3/5][1100/1583]        Loss_D: 0.5150  Loss_G: 3.0683  D(x): 0.8824    D(G(z)): 0.2883 / 0.0633
[3/5][1150/1583]        Loss_D: 0.6427  Loss_G: 2.3221  D(x): 0.6846    D(G(z)): 0.1879 / 0.1345
[3/5][1200/1583]        Loss_D: 1.4129  Loss_G: 0.7791  D(x): 0.3089    D(G(z)): 0.0284 / 0.5125
[3/5][1250/1583]        Loss_D: 0.8410  Loss_G: 0.9022  D(x): 0.5329    D(G(z)): 0.1067 / 0.4592
[3/5][1300/1583]        Loss_D: 0.6202  Loss_G: 2.1313  D(x): 0.6066    D(G(z)): 0.0584 / 0.1690
[3/5][1350/1583]        Loss_D: 1.2004  Loss_G: 0.5458  D(x): 0.3802    D(G(z)): 0.0438 / 0.6172
[3/5][1400/1583]        Loss_D: 0.4449  Loss_G: 3.0477  D(x): 0.8518    D(G(z)): 0.2230 / 0.0635
[3/5][1450/1583]        Loss_D: 0.4845  Loss_G: 2.1626  D(x): 0.7261    D(G(z)): 0.1237 / 0.1463
[3/5][1500/1583]        Loss_D: 0.6804  Loss_G: 2.9004  D(x): 0.8089    D(G(z)): 0.3330 / 0.0715
[3/5][1550/1583]        Loss_D: 0.5228  Loss_G: 2.0975  D(x): 0.7766    D(G(z)): 0.1999 / 0.1542
[4/5][0/1583]   Loss_D: 0.5070  Loss_G: 3.0032  D(x): 0.8567    D(G(z)): 0.2696 / 0.0691
[4/5][50/1583]  Loss_D: 0.9748  Loss_G: 4.4468  D(x): 0.9213    D(G(z)): 0.5354 / 0.0186
[4/5][100/1583] Loss_D: 0.6213  Loss_G: 2.1751  D(x): 0.6348    D(G(z)): 0.0899 / 0.1572
[4/5][150/1583] Loss_D: 0.6697  Loss_G: 1.7084  D(x): 0.6313    D(G(z)): 0.1257 / 0.2225
[4/5][200/1583] Loss_D: 0.8032  Loss_G: 1.5958  D(x): 0.5151    D(G(z)): 0.0426 / 0.2646
[4/5][250/1583] Loss_D: 0.9996  Loss_G: 1.2608  D(x): 0.4645    D(G(z)): 0.0573 / 0.3387
[4/5][300/1583] Loss_D: 0.4582  Loss_G: 2.7660  D(x): 0.8346    D(G(z)): 0.2170 / 0.0858
[4/5][350/1583] Loss_D: 0.3809  Loss_G: 3.5536  D(x): 0.8834    D(G(z)): 0.2013 / 0.0390
[4/5][400/1583] Loss_D: 0.6527  Loss_G: 2.2881  D(x): 0.7386    D(G(z)): 0.2494 / 0.1337
[4/5][450/1583] Loss_D: 0.5231  Loss_G: 1.8814  D(x): 0.7282    D(G(z)): 0.1458 / 0.1877
[4/5][500/1583] Loss_D: 0.5383  Loss_G: 1.5842  D(x): 0.7036    D(G(z)): 0.1342 / 0.2415
[4/5][550/1583] Loss_D: 0.9516  Loss_G: 1.0133  D(x): 0.4710    D(G(z)): 0.0501 / 0.4142
[4/5][600/1583] Loss_D: 1.0705  Loss_G: 0.8433  D(x): 0.4447    D(G(z)): 0.1189 / 0.4909
[4/5][650/1583] Loss_D: 0.6983  Loss_G: 1.8885  D(x): 0.6748    D(G(z)): 0.2142 / 0.1789
[4/5][700/1583] Loss_D: 0.3807  Loss_G: 2.7729  D(x): 0.8701    D(G(z)): 0.1975 / 0.0779
[4/5][750/1583] Loss_D: 0.4087  Loss_G: 2.6961  D(x): 0.8605    D(G(z)): 0.2062 / 0.0937
[4/5][800/1583] Loss_D: 0.6979  Loss_G: 4.2029  D(x): 0.8687    D(G(z)): 0.3796 / 0.0221
[4/5][850/1583] Loss_D: 0.6469  Loss_G: 3.2722  D(x): 0.8828    D(G(z)): 0.3646 / 0.0546
[4/5][900/1583] Loss_D: 0.5220  Loss_G: 1.9707  D(x): 0.7330    D(G(z)): 0.1533 / 0.1705
[4/5][950/1583] Loss_D: 0.4789  Loss_G: 2.9207  D(x): 0.8481    D(G(z)): 0.2463 / 0.0680
[4/5][1000/1583]        Loss_D: 0.5314  Loss_G: 1.9446  D(x): 0.6925    D(G(z)): 0.1124 / 0.1768
[4/5][1050/1583]        Loss_D: 0.5690  Loss_G: 3.1526  D(x): 0.8775    D(G(z)): 0.3241 / 0.0542
[4/5][1100/1583]        Loss_D: 0.4210  Loss_G: 2.5976  D(x): 0.8832    D(G(z)): 0.2351 / 0.0944
[4/5][1150/1583]        Loss_D: 0.4784  Loss_G: 2.1561  D(x): 0.7820    D(G(z)): 0.1739 / 0.1473
[4/5][1200/1583]        Loss_D: 0.5640  Loss_G: 1.6350  D(x): 0.6674    D(G(z)): 0.0848 / 0.2409
[4/5][1250/1583]        Loss_D: 1.1821  Loss_G: 5.9182  D(x): 0.9561    D(G(z)): 0.6136 / 0.0045
[4/5][1300/1583]        Loss_D: 0.5865  Loss_G: 4.5427  D(x): 0.9453    D(G(z)): 0.3724 / 0.0155
[4/5][1350/1583]        Loss_D: 1.4747  Loss_G: 5.4916  D(x): 0.9590    D(G(z)): 0.7031 / 0.0066
[4/5][1400/1583]        Loss_D: 0.4061  Loss_G: 3.5037  D(x): 0.8808    D(G(z)): 0.2233 / 0.0375
[4/5][1450/1583]        Loss_D: 0.5928  Loss_G: 1.4196  D(x): 0.6548    D(G(z)): 0.0989 / 0.2901
[4/5][1500/1583]        Loss_D: 0.5381  Loss_G: 3.6485  D(x): 0.8691    D(G(z)): 0.2970 / 0.0338
[4/5][1550/1583]        Loss_D: 0.7174  Loss_G: 2.1455  D(x): 0.7092    D(G(z)): 0.2569 / 0.1469
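
If the statistics are only available as printed text, they can be recovered with a small parser. The sketch below is one possible approach that assumes the exact line format produced by the ``print`` call in the training loop. The names ``LOG_PATTERN`` and ``parse_log_line`` are hypothetical.

```python
import re

# Matches lines such as:
# [0/5][50/1583]  Loss_D: 0.1787  Loss_G: 19.3262 D(x): 0.9021  D(G(z)): 0.0000 / 0.0000
LOG_PATTERN = re.compile(
    r"\[(\d+)/(\d+)\]\[(\d+)/(\d+)\]\s+"
    r"Loss_D: ([\d.]+)\s+Loss_G: ([\d.]+)\s+"
    r"D\(x\): ([\d.]+)\s+D\(G\(z\)\): ([\d.]+) / ([\d.]+)"
)

def parse_log_line(line):
    """Return the fields of one training log line as a dict, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    groups = match.groups()
    return {
        "epoch": int(groups[0]),
        "batch": int(groups[2]),
        "loss_d": float(groups[4]),
        "loss_g": float(groups[5]),
        "d_x": float(groups[6]),
        "d_g_z1": float(groups[7]),  # D(G(z)) before the D update
        "d_g_z2": float(groups[8]),  # D(G(z)) after the D update
    }

sample = "[4/5][1550/1583]        Loss_D: 0.7174  Loss_G: 2.1455  D(x): 0.7092    D(G(z)): 0.2569 / 0.1469"
record = parse_log_line(sample)
print(record)
```

Lines that are not statistics, such as ``Starting Training Loop...``, return ``None`` and can simply be skipped.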

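Results

Finally, let's check out how we did. Here, we will look at three different results. First, we will see how D's and G's losses changed during training. Second, we will visualize G's output on the fixed_noise batch as training progressed. Third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of D's and G's losses versus training iterations.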

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
[Figure: Generator and Discriminator Loss During Training]

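Visualization of G's progression

Remember how we saved the generator's output on the fixed_noise batch during training (every 500 iterations and on the final iteration)? Now we can visualize the training progression of G with an animation. Press the play button to start the animation.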

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
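[Figure: dcgan faces tutorial (animation of generated faces)]

Real Images vs. Fake Images

Finally, let's take a look at some real images and fake images side by side.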

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
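[Figure: Real Images, Fake Images]

Where to Go Next

We have reached the end of our journey, but there are several places you could go from here. You could:

  • Train for longer to see how good the results get

  • Modify this model to take a different dataset, possibly changing the image size and the model architecture

  • Check out some other cool GAN projects here

  • Create GANs that generate music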

Total running time of the script: (5 minutes 49.207 seconds)
