.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/basics/buildmodel_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_buildmodel_tutorial.py:


`Learn the Basics `_ ||
`Quickstart `_ ||
`Tensors `_ ||
`Datasets & DataLoaders `_ ||
`Transforms `_ ||
**Build Model** ||
`Autograd `_ ||
`Optimization `_ ||
`Save & Load Model `_

Build the Neural Network
========================

Neural networks are composed of layers/modules that perform operations on data.
The `torch.nn `_ namespace provides all the building blocks you need to
build your own neural network. Every module in PyTorch subclasses `nn.Module `_.
A neural network is itself a module that consists of other modules (layers). This nested structure makes it
easy to build and manage complex architectures.

In the following sections, we'll build a neural network to classify images in the FashionMNIST dataset.

.. GENERATED FROM PYTHON SOURCE LINES 24-32

.. code-block:: default

    import os
    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

.. GENERATED FROM PYTHON SOURCE LINES 33-37

Get Device for Training
-----------------------
We want to be able to train our model on an `accelerator `__
such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we use it; otherwise, we fall back to the CPU.

.. GENERATED FROM PYTHON SOURCE LINES 37-41

.. code-block:: default

    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
    print(f"Using {device} device")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Using cuda device
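
``torch.accelerator`` is only present in recent PyTorch releases. The following minimal sketch, which is not part of the original tutorial, adds an assumed ``hasattr`` fallback to ``torch.cuda`` for older versions and then checks that both a tensor and a module end up on the selected device. The names ``x``, ``layer`` and ``y`` are hypothetical and used only for illustration.

.. code-block:: python

    import torch

    # Assumption: torch.accelerator is missing on older PyTorch releases,
    # so fall back to a CUDA-only check there (our addition, not the tutorial's).
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        device = torch.accelerator.current_accelerator().type
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Tensors can be allocated directly on the chosen device...
    x = torch.ones(2, 3, device=device)

    # ...while .to(device) moves a module's parameters (weights and biases).
    layer = torch.nn.Linear(3, 4).to(device)
    y = layer(x)

    print(f"Using {device} device")
    print(x.device.type, layer.weight.device.type, y.device.type)

Keeping the model and its inputs on the same device matters: passing a CPU tensor to a model whose parameters live on an accelerator raises a runtime error.
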

.. GENERATED FROM PYTHON SOURCE LINES 42-47

Define the Class
-------------------------
We define our neural network by subclassing ``nn.Module`` and initialize the neural network layers in ``__init__``.
Every ``nn.Module`` subclass implements the operations on input data in the ``forward`` method.

.. GENERATED FROM PYTHON SOURCE LINES 47-65

.. code-block:: default

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits

.. GENERATED FROM PYTHON SOURCE LINES 66-68

We create an instance of ``NeuralNetwork``, move it to the ``device``, and print its structure.

.. GENERATED FROM PYTHON SOURCE LINES 68-73

.. code-block:: default

    model = NeuralNetwork().to(device)
    print(model)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )

.. GENERATED FROM PYTHON SOURCE LINES 74-80

To use the model, we pass it the input data. This executes the model's ``forward`` method, along with some
`background operations `_ such as any registered hooks.
Do not call ``model.forward()`` directly!

Calling the model on the input returns a 2-dimensional tensor of raw predictions (logits): dim=0 indexes the
samples in the batch, and dim=1 holds the 10 raw predicted values, one per class. We get the prediction
probabilities by passing the logits through an instance of the ``nn.Softmax`` module.

.. GENERATED FROM PYTHON SOURCE LINES 80-88
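
The warning against calling ``model.forward()`` directly can be checked with a small experiment. This sketch, which is not from the original tutorial, registers a forward hook (one of the operations that ``nn.Module.__call__`` runs around ``forward``) on a small ``nn.Sequential`` stand-in model, so that it runs on its own; ``record_call`` and ``calls`` are hypothetical names chosen for illustration.

.. code-block:: python

    import torch
    from torch import nn

    # Hypothetical stand-in model, so this snippet runs on its own.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    calls = []

    def record_call(module, inputs, output):
        # Forward hooks receive the module, its positional inputs (a tuple),
        # and its output; here we only record the output shape.
        calls.append(tuple(output.shape))

    handle = model.register_forward_hook(record_call)

    x = torch.rand(1, 28, 28)
    out_call = model(x)             # goes through __call__, so the hook runs
    out_forward = model.forward(x)  # bypasses __call__, so the hook is skipped

    print(calls)                                  # [(1, 10)]
    print(torch.allclose(out_call, out_forward))  # True

    handle.remove()  # detach the hook once it is no longer needed

Both calls compute the same logits, but only ``model(x)`` triggers the hook registered on ``model``, which is why code that relies on such hooks (for example, for logging or feature extraction) can silently stop working when ``forward`` is called directly.
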

.. code-block:: default

    X = torch.rand(1, 28, 28, device=device)
    logits = model(X)
    pred_probab = nn.Softmax(dim=1)(logits)
    y_pred = pred_probab.argmax(1)
    print(f"Predicted class: {y_pred}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Predicted class: tensor([3], device='cuda:0')

.. GENERATED FROM PYTHON SOURCE LINES 89-91

--------------

.. GENERATED FROM PYTHON SOURCE LINES 94-100

Model Layers
-------------------------

Let's break down the layers in the FashionMNIST model. To illustrate, we will take a sample minibatch of
3 images of size 28x28 and see what happens to it as we pass it through the network.

.. GENERATED FROM PYTHON SOURCE LINES 100-104

.. code-block:: default

    input_image = torch.rand(3, 28, 28)
    print(input_image.size())


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    torch.Size([3, 28, 28])

.. GENERATED FROM PYTHON SOURCE LINES 105-110

nn.Flatten
^^^^^^^^^^^^^^^^^^^^^^
We initialize the `nn.Flatten `_ layer to convert each 2D 28x28 image into a contiguous array of
784 pixel values (the minibatch dimension at dim=0 is maintained).

.. GENERATED FROM PYTHON SOURCE LINES 110-115

.. code-block:: default

    flatten = nn.Flatten()
    flat_image = flatten(input_image)
    print(flat_image.size())


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    torch.Size([3, 784])

.. GENERATED FROM PYTHON SOURCE LINES 116-121

nn.Linear
^^^^^^^^^^^^^^^^^^^^^^
The `linear layer `_ is a module that applies a linear transformation :math:`y = xA^T + b` to the input,
using its stored weights :math:`A` and biases :math:`b`.

.. GENERATED FROM PYTHON SOURCE LINES 121-126

.. code-block:: default

    layer1 = nn.Linear(in_features=28*28, out_features=20)
    hidden1 = layer1(flat_image)
    print(hidden1.size())


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    torch.Size([3, 20])

.. GENERATED FROM PYTHON SOURCE LINES 127-135

nn.ReLU
^^^^^^^^^^^^^^^^^^^^^^
Non-linear activations are what create the complex mappings between the model's inputs and outputs.
They are applied after linear transformations to introduce *nonlinearity*, helping neural networks
learn a wide variety of phenomena.

In this model, we use `nn.ReLU `_ between our linear layers, but there are other activations
you can use to introduce non-linearity in your model.

.. GENERATED FROM PYTHON SOURCE LINES 135-142

.. code-block:: default

    print(f"Before ReLU: {hidden1}\n\n")
    hidden1 = nn.ReLU()(hidden1)
    print(f"After ReLU: {hidden1}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Before ReLU: tensor([[ 0.1082,  0.2607, -0.0803,  0.4393, -0.2751, -0.1598, -0.3396, -0.1267,
              0.1498,  0.5499, -0.3696, -0.2498,  0.1291,  0.2859, -0.4803, -0.3475,
              0.0810, -0.2506, -0.2422,  0.2068],
            [ 0.2251,  0.6316, -0.1201,  0.6227,  0.0540,  0.0446, -0.3771, -0.0794,
              0.2245,  0.4890,  0.1999, -0.0085,  0.0879,  0.4191, -0.6633, -0.0096,
              0.2844,  0.1656, -0.1447,  0.2600],
            [ 0.0810,  0.4832, -0.0078,  0.2615, -0.2108, -0.2037,  0.0895, -0.1110,
             -0.1653,  0.5774, -0.4137, -0.2449, -0.0132,  0.0149, -0.6887, -0.0839,
              0.1992,  0.1747, -0.3002,  0.0147]], grad_fn=)


    After ReLU: tensor([[0.1082, 0.2607, 0.0000, 0.4393, 0.0000, 0.0000, 0.0000, 0.0000, 0.1498,
             0.5499, 0.0000, 0.0000, 0.1291, 0.2859, 0.0000, 0.0000, 0.0810, 0.0000,
             0.0000, 0.2068],
            [0.2251, 0.6316, 0.0000, 0.6227, 0.0540, 0.0446, 0.0000, 0.0000, 0.2245,
             0.4890, 0.1999, 0.0000, 0.0879, 0.4191, 0.0000, 0.0000, 0.2844, 0.1656,
             0.0000, 0.2600],
            [0.0810, 0.4832, 0.0000, 0.2615, 0.0000, 0.0000, 0.0895, 0.0000, 0.0000,
             0.5774, 0.0000, 0.0000, 0.0000, 0.0149, 0.0000, 0.0000, 0.1992, 0.1747,
             0.0000, 0.0147]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 143-148

nn.Sequential
^^^^^^^^^^^^^^^^^^^^^^
`nn.Sequential `_ is an ordered container of modules. The data is passed through all the modules
in the same order as defined. You can use sequential containers to put together a quick network like ``seq_modules``.

.. GENERATED FROM PYTHON SOURCE LINES 148-158
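
The ordering guarantee described above, together with the transformation that ``nn.Linear`` applies, can be verified directly. In this sketch (not from the original tutorial; ``seq``, ``by_hand`` and ``manual`` are hypothetical names), a container is compared against calling the same modules one after another, and against a hand-written :math:`\mathrm{ReLU}(xA^T + b)` that uses the layer's stored ``weight`` and ``bias``.

.. code-block:: python

    import torch
    from torch import nn

    torch.manual_seed(0)  # only to make the random weights reproducible

    flatten = nn.Flatten()
    linear = nn.Linear(28 * 28, 20)
    relu = nn.ReLU()
    seq = nn.Sequential(flatten, linear, relu)

    x = torch.rand(3, 28, 28)

    # The container applies its modules in the order they were passed in.
    by_container = seq(x)
    by_hand = relu(linear(flatten(x)))

    # nn.Linear computes x @ A^T + b, where weight A has shape
    # (out_features, in_features) and bias b has shape (out_features,).
    flat = x.reshape(3, -1)  # same result as nn.Flatten() for this input
    manual = torch.relu(flat @ linear.weight.T + linear.bias)

    print(by_container.shape)                     # torch.Size([3, 20])
    print(seq[1] is linear)                       # True
    print(torch.allclose(by_container, by_hand))  # True
    print(torch.allclose(by_container, manual, atol=1e-5))

``nn.Sequential`` also supports indexing, as ``seq[1]`` shows, which is handy for inspecting a single layer of a larger stack.
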

.. code-block:: default

    seq_modules = nn.Sequential(
        flatten,
        layer1,
        nn.ReLU(),
        nn.Linear(20, 10)
    )
    input_image = torch.rand(3, 28, 28)
    logits = seq_modules(input_image)

.. GENERATED FROM PYTHON SOURCE LINES 159-165

nn.Softmax
^^^^^^^^^^^^^^^^^^^^^^
The last linear layer of the neural network returns *logits*, raw values in :math:`(-\infty, \infty)`, which are passed to the
`nn.Softmax `_ module. The logits are scaled to values in :math:`[0, 1]` representing the model's predicted probabilities for each class.
The ``dim`` parameter indicates the dimension along which the values must sum to 1; with ``dim=1``, the 10 class probabilities
in each row sum to 1.

.. GENERATED FROM PYTHON SOURCE LINES 165-170

.. code-block:: default

    softmax = nn.Softmax(dim=1)
    pred_probab = softmax(logits)

.. GENERATED FROM PYTHON SOURCE LINES 171-180

Model Parameters
-------------------------
Many layers inside a neural network are *parameterized*, i.e. they have associated weights and biases that are optimized during training.
Subclassing ``nn.Module`` automatically tracks every submodule and parameter assigned as an attribute of your model object, and makes
all parameters accessible using your model's ``parameters()`` or ``named_parameters()`` methods.

In this example, we iterate over each parameter and print its size and a preview of its values.

.. GENERATED FROM PYTHON SOURCE LINES 180-187

.. code-block:: default

    print(f"Model structure: {model}\n\n")

    for name, param in model.named_parameters():
        print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Model structure: NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )


    Layer: linear_relu_stack.0.weight | Size: torch.Size([512, 784]) | Values : tensor([[-0.0269, -0.0093, -0.0243,  ...,  0.0283,  0.0274,  0.0187],
            [ 0.0176, -0.0315, -0.0225,  ...,  0.0161,  0.0100,  0.0230]],
           device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.0.bias | Size: torch.Size([512]) | Values : tensor([-0.0129,  0.0176], device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.2.weight | Size: torch.Size([512, 512]) | Values : tensor([[ 0.0409, -0.0028,  0.0306,  ..., -0.0302,  0.0010,  0.0319],
            [ 0.0010,  0.0105, -0.0276,  ..., -0.0122,  0.0358, -0.0189]],
           device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.2.bias | Size: torch.Size([512]) | Values : tensor([-0.0146,  0.0077], device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.4.weight | Size: torch.Size([10, 512]) | Values : tensor([[ 0.0312,  0.0284, -0.0296,  ..., -0.0255, -0.0238, -0.0150],
            [-0.0365,  0.0128,  0.0315,  ...,  0.0274,  0.0073, -0.0227]],
           device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.4.bias | Size: torch.Size([10]) | Values : tensor([-0.0410,  0.0270], device='cuda:0', grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 188-190

--------------

.. GENERATED FROM PYTHON SOURCE LINES 192-195

Further Reading
-----------------
- `torch.nn API `_


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 1.015 seconds)


.. _sphx_glr_download_beginner_basics_buildmodel_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: buildmodel_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: buildmodel_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_