Maik Paixao

Creating an Image Classification Model with PyTorch

Welcome!

In this tutorial, I will walk you through creating a simple image classification model using PyTorch, a popular deep learning library in Python.

Prerequisites:

  • Python installed.
  • Familiarity with Python programming.
  • Basic understanding of neural networks.

1. Installation:

Before we begin, set up your environment by installing PyTorch, a widely used deep learning library, and torchvision, which provides utilities for computer vision:

  • PyTorch: The core library for implementing deep learning architectures such as neural networks.
  • Torchvision: A helper library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision.
pip install torch torchvision

2. Import Libraries:

Before you can use the features of a package, you need to import it.

torch: The main PyTorch module.

Torchvision: As mentioned, this helps with datasets and models specifically for computer vision tasks.

transforms: This provides common image transformations. In deep learning, input data often requires preprocessing to improve training efficiency and performance.

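nn: This module provides all the building blocks for neural networks.

F (torch.nn.functional): Provides the functional form of common operations such as relu, which the forward pass of the network calls directly.

optim: Contains common optimization algorithms for tuning model parameters.

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in the forward pass
import torch.optim as optim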

3. Read and Preprocess the Data:

We will use CIFAR-10, a well-known computer vision dataset of 60,000 32 x 32 color images in 10 classes, split into 50,000 training images and 10,000 test images.

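transforms.Compose(): Chains multiple image transformations. In this example, ToTensor() first converts each image to a tensor with values in [0, 1], and Normalize() with a mean and standard deviation of 0.5 per channel then rescales those values to the range [-1, 1].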

DataLoader: Helps in feeding data in batches, shuffling it and loading it in parallel, making the training process more efficient.

transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize the images
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
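
To see what a DataLoader hands back without downloading anything, here is a small sketch on synthetic data. The `TensorDataset` of random tensors is a hypothetical stand-in for CIFAR-10, and `num_workers=0` is used only to keep the demo in a single process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR-10: 10 random "images" with random labels.
images = torch.rand(10, 3, 32, 32)
labels = torch.randint(0, 10, (10,))
dataset = TensorDataset(images, labels)

# Same batch size as the tutorial; loading happens in the main process here.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Each iteration yields one (images, labels) batch.
batch_shapes = [tuple(batch_images.shape) for batch_images, _ in loader]
print(batch_shapes)  # [(4, 3, 32, 32), (4, 3, 32, 32), (2, 3, 32, 32)]
```

The last batch is smaller because 10 samples do not divide evenly by a batch size of 4.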

4. Define the Neural Network:

We define a simple Convolutional Neural Network (CNN). CNNs are the standard neural network type for image processing tasks.

nn.Conv2d(): Represents a convolutional layer. Its first three arguments are in_channels, out_channels, and kernel_size.

nn.MaxPool2d(): Represents maximum pooling, which reduces the spatial size of the representation, making calculations faster and extracting dominant features.

nn.Linear(): Represents a fully connected layer that connects each neuron from the previous layer to the next.

The forward function specifies how data flows through the network. PyTorch's autograd records these operations and derives the backward pass from them automatically.

class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
         self.conv1 = nn.Conv2d(3, 6, 5) # 3 input channels, 6 output channels, 5x5 kernel
         self.pool = nn.MaxPool2d(2, 2) # 2x2 max pooling
         self.conv2 = nn.Conv2d(6, 16, 5)
         self.fc1 = nn.Linear(16 * 5 * 5, 120)
         self.fc2 = nn.Linear(120, 84)
         self.fc3 = nn.Linear(84, 10) # 10 output classes

     def forward(self, x):
         x = self.pool(F.relu(self.conv1(x)))
         x = self.pool(F.relu(self.conv2(x)))
         x = x.view(-1, 16 * 5 * 5)
         x = F.relu(self.fc1(x))
         x = F.relu(self.fc2(x))
         x = self.fc3(x)
         return x

net = Net()
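
The `16 * 5 * 5` input size of the first fully connected layer follows from how each layer changes the tensor shape. The sketch below rebuilds only the convolution and pooling layers with the same arguments as the network above and records the shape after each step for one random 32 x 32 image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer arguments as in the Net class above.
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.rand(1, 3, 32, 32)  # one random image: (batch, channels, height, width)
shapes = [tuple(x.shape)]

x = F.relu(conv1(x))          # 5x5 kernel, no padding: 32 -> 28
shapes.append(tuple(x.shape))
x = pool(x)                   # 2x2 pooling halves each side: 28 -> 14
shapes.append(tuple(x.shape))
x = F.relu(conv2(x))          # 5x5 kernel, no padding: 14 -> 10
shapes.append(tuple(x.shape))
x = pool(x)                   # 2x2 pooling: 10 -> 5
shapes.append(tuple(x.shape))
x = x.view(-1, 16 * 5 * 5)    # flatten to 400 features per image
shapes.append(tuple(x.shape))

for shape in shapes:
    print(shape)
```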

5. Define the Loss Function and Optimizer:

Let's use Cross-Entropy loss and SGD optimizer.

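Cross-Entropy loss: Commonly used in classification tasks. It measures the difference between predicted probabilities and true class labels. nn.CrossEntropyLoss expects the raw, unnormalized scores (logits) from the network and applies log-softmax internally, which is why the network above has no softmax layer.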

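SGD (Stochastic Gradient Descent): An optimization method that minimizes the loss by updating the model weights in the direction opposite to the gradient. Here we use a learning rate of 0.001 and a momentum of 0.9.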
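
To make the cross-entropy description concrete, the sketch below computes the loss by hand for a batch of random scores and compares it with `nn.CrossEntropyLoss`. The scores and labels are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # made-up raw scores: 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1])    # made-up true class indices

builtin = nn.CrossEntropyLoss()(logits, labels)

# By hand: log-softmax over classes, pick the log-probability of the true
# class for each sample, negate it, and average over the batch.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), labels].mean()

print("same value:", torch.allclose(builtin, manual))
```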

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
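
What `momentum=0.9` does can be sketched by replaying the update rule by hand next to `optim.SGD`. The toy loss (sum of squared weights) and the names `w_manual` and `velocity` are illustrative assumptions; the rule shown is the one PyTorch documents for SGD with default dampening and no Nesterov momentum.

```python
import torch
import torch.optim as optim

lr, momentum = 0.001, 0.9

# Parameter updated by the real optimizer.
w = torch.tensor([1.0, -2.0], requires_grad=True)
optimizer = optim.SGD([w], lr=lr, momentum=momentum)

# Copy of the parameter updated by hand.
w_manual = w.detach().clone()
velocity = torch.zeros_like(w_manual)

for _ in range(3):
    # Optimizer path: toy loss sum(w^2), whose gradient is 2w.
    optimizer.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    optimizer.step()

    # Manual path: velocity accumulates gradients, weights move against it.
    grad_manual = 2 * w_manual
    velocity = momentum * velocity + grad_manual
    w_manual = w_manual - lr * velocity

print("same weights:", torch.allclose(w.detach(), w_manual))
```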

6. Training the Network:

Here, we will train the model for a few epochs.

The essence of deep learning is this iterative process of adjusting model weights to minimize loss:

Clear gradients: Since PyTorch accumulates gradients, you need to clear them before each step.

Forward Propagation: Pass input through the model to obtain predictions.

Calculate loss: Compare predictions with actual labels.

Backward Propagation: Backpropagate the loss throughout the network to calculate the gradient of the loss with respect to each weight.

Optimize: Adjust weights in the direction that minimizes loss.

The loop ensures that the model sees the data multiple times (epochs) and adjusts its weights.

for epoch in range(5): # Loop over the dataset multiple times

     running_loss = 0.0
     for i, data in enumerate(trainloader, 0):
         inputs, labels = data

         optimizer.zero_grad() # Zero the parameter gradients

         outputs = net(inputs) # Forward
         loss = criterion(outputs, labels) # Calculate loss
         loss.backward() # Backward
         optimizer.step() # Optimize

         running_loss += loss.item()
         if i % 2000 == 1999: # Print every 2000 mini-batches
             print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
             running_loss = 0.0

print('Finished Training')
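
The "clear gradients" step is easy to skip past, so here is a minimal sketch of what happens without it. It uses a single scalar weight and a toy loss, both made up for illustration.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

# First backward pass: d(w^2)/dw = 2w = 6.
(w ** 2).backward()
first = w.grad.item()

# Second backward pass without clearing: the new gradient is added to the old one.
(w ** 2).backward()
accumulated = w.grad.item()

# Clear the stored gradient, which is what optimizer.zero_grad() takes care of
# for every parameter the optimizer manages.
w.grad.zero_()
(w ** 2).backward()
after_reset = w.grad.item()

print(first, accumulated, after_reset)  # 6.0 12.0 6.0
```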

7. Testing the Network:

After training, it is crucial to evaluate the model's performance on unseen data:

torch.no_grad(): Disables gradient computation, which is not needed during evaluation, saving memory and computation.

Outputs: These are the raw scores (logits) for each class, not probabilities.

Prediction: By choosing the class with the highest score, we obtain the predicted class label.

Calculate Accuracy: Count how many predictions match the actual labels and calculate the percentage.
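
If actual probabilities are needed, softmax can be applied to the network outputs. The sketch below uses random scores as a stand-in for `net(images)` (an assumption) and shows that the predicted class is the same whether it is picked from the raw scores or from the softmax probabilities.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
outputs = torch.randn(4, 10)  # stand-in for net(images): 4 samples, 10 classes

# Softmax turns each row of scores into probabilities that sum to 1.
probabilities = F.softmax(outputs, dim=1)

# Softmax preserves the ordering of the scores, so the winning class is unchanged.
_, predicted_from_scores = outputs.max(1)
predicted_from_probs = probabilities.argmax(dim=1)

print("row sums:", probabilities.sum(dim=1))
print("same predictions:", torch.equal(predicted_from_scores, predicted_from_probs))
```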

At the end of this process, you will have a trained neural network model capable of classifying images from the CIFAR-10 dataset. Remember, this is a basic tutorial. For better accuracy and efficiency in real-world applications, more advanced techniques and fine-tuning are required.

correct = 0
total = 0
with torch.no_grad():
     for data in testloader:
         images, labels = data
         outputs = net(images)
         _, predicted = outputs.max(1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
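
A single accuracy number hides which classes the model struggles with. As a possible next step, the sketch below breaks accuracy down per class. The helper `per_class_accuracy` is a hypothetical function written for this example, and the labels and predictions are made up; in practice they would be collected from the test loop above.

```python
import torch

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def per_class_accuracy(predicted, labels, num_classes):
    """Count correct and total predictions for each class index."""
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)
    for class_index in range(num_classes):
        mask = labels == class_index                 # samples whose true class is class_index
        total[class_index] = mask.sum()
        correct[class_index] = (predicted[mask] == class_index).sum()
    return correct, total

# Made-up labels and predictions for illustration.
labels = torch.tensor([0, 0, 1, 1, 3, 3, 3])
predicted = torch.tensor([0, 1, 1, 1, 3, 5, 3])

correct, total = per_class_accuracy(predicted, labels, len(classes))
for name, c, t in zip(classes, correct.tolist(), total.tolist()):
    if t > 0:  # skip classes with no samples in this toy batch
        print('Accuracy of %5s: %3d %%' % (name, 100 * c / t))
```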

And that's it! In just 10 minutes, you learned how to create and train a simple image classification model using PyTorch. With more time and tweaking, you can improve this model further or dive deeper into advanced architectures and techniques!

Hi, I'm Maik. I hope you liked the article. If you have any questions or want to connect with me and access more content, follow my channels:

LinkedIn: https://www.linkedin.com/in/maikpaixao/
Twitter: https://twitter.com/maikpaixao
Youtube: https://www.youtube.com/@maikpaixao
Instagram: https://www.instagram.com/prof.maikpaixao/
Github: https://github.com/maikpaixao

Top comments (2)

Vanessa Telles

Congratulations on the tutorial, Maik. My suggestion would be to share a Google Colab with the code and leave the cells with their results, so that beginners can better follow the step-by-step process and even venture into tinkering with the code :)

Maik Paixao

Thanks for the suggestion, Vanessa! ;)