Pipeline for Training PyTorch Image Classification Model / Creating Dataset

Dataset Preparation: Step-by-step guide for training custom PyTorch Image Classification models with any dataset.

Training an image classification model involves a few key steps, and building a pipeline around these steps can save time.

In this article, I share my step-by-step guide for training image classification models with PyTorch. By following this pipeline, you can train image classification models on different datasets for your specific task.

Dataset Preparation for PyTorch Image Classification Model Training

What I recommend for training deep learning models (object detection, image segmentation, image classification, etc.) is to follow a pipeline, stick to it for a while, and revise it from time to time. Basically, create your own documentation by following other people’s pipelines and improving them.

For image classification model training, here are my 6 main steps:

  1. Creating the Dataset
  2. Visualizing Example Images
  3. Visualizing Class Distribution
  4. Creating Functions for Training the Model
  5. Creating the Model
  6. Training the Model

In this article, I cover the first 3 steps, which are all about dataset preparation. The next article is about model training; you can find it at the end of the page.

I also have a YouTube video about this article if you prefer to watch it.

Creating the Dataset

This step is probably more important than you think, because model training is all about learning from the dataset.
If your dataset is not big enough, or its quality is poor, you can’t expect a model that works well.
What I mean by quality:

  • Correctly labeled images
  • Balanced classes (for example, if label “A” has 2000 examples but label “B” has only 50, that is a problem)
  • Good image quality
  • Images from different environments

Okay, now let’s create a dataset. I recommend you first try to find a dataset from websites like Kaggle and Roboflow, and if you can’t find one, create your own dataset by scraping from the internet or taking your images with your phone and labeling them later. But most of the time, if you search the internet, you will find something to work on.

I downloaded a fish dataset from Kaggle, and I will use it throughout this article.

When you download a dataset from the internet, it is often not ready to use as-is (I’d say around 50% of the time). The author may have collected all images in one folder, or the dataset might be split into subfolders by class, or the labels may be stored in a .csv file. You need to adjust that data. PyTorch works well with the format below (one subfolder per class inside the train and validation folders), especially if you are going to use torchvision.datasets.ImageFolder.

Dataset Format for PyTorch Image Classification Model Training
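
ImageFolder expects one folder per class under a root folder, for example train/<class_name>/<image file>. If your download instead comes as a single image folder plus a label file, a small script can rearrange it. The sketch below is only an illustration under assumptions: the function name build_class_folders and the CSV column names "filename" and "label" are hypothetical, so adapt them to your own dataset.

```python
import csv
import shutil
from pathlib import Path


def build_class_folders(image_dir, labels_csv, output_dir):
    """Copy images listed in a CSV into output_dir/<label>/<filename>.

    Assumption: the CSV has a header row with the (hypothetical) columns
    "filename" and "label". Rename them to match your dataset.
    """
    image_dir, output_dir = Path(image_dir), Path(output_dir)
    copied = 0
    with open(labels_csv, newline="") as f:
        for row in csv.DictReader(f):
            source = image_dir / row["filename"]
            if not source.is_file():
                continue  # skip rows that point to a missing image
            class_dir = output_dir / row["label"]
            class_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source, class_dir / source.name)
            copied += 1
    return copied  # number of images actually copied
```

Copying (rather than moving) keeps the original download untouched, which makes it easy to redo the step if the labels turn out to be wrong.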

I am going to use torchvision.datasets.ImageFolder, so I have adjusted my data to the format mentioned above. This is a little bit off-topic, so you can check my GitHub repository for the details. It is a simple process: just copy the images into the right folders.
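
If your dataset is already organized into one folder per class but has no train/validation split, the copy step might look like the sketch below. This is not the script from my repository, just a minimal illustration; the function name split_train_validation, the 20% validation fraction, and the seed are all assumptions you can change.

```python
import random
import shutil
from pathlib import Path


def split_train_validation(source_dir, output_dir, validation_fraction=0.2, seed=42):
    """Copy source_dir/<class>/* into output_dir/train and output_dir/validation."""
    source_dir, output_dir = Path(source_dir), Path(output_dir)
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for class_dir in sorted(p for p in source_dir.iterdir() if p.is_dir()):
        files = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(files)
        n_val = int(len(files) * validation_fraction)
        splits = {"validation": files[:n_val], "train": files[n_val:]}
        for split_name, split_files in splits.items():
            target = output_dir / split_name / class_dir.name
            target.mkdir(parents=True, exist_ok=True)
            for file in split_files:
                shutil.copy2(file, target / file.name)
        counts[class_dir.name] = {"train": len(files) - n_val, "validation": n_val}
    return counts  # per-class number of images in each split
```

Splitting inside each class folder keeps the class proportions roughly the same in the train and validation sets.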

PyTorch provides two data primitives:

  • torch.utils.data.Dataset: It allows you to use pre-loaded datasets as well as your own data. It stores the samples and their corresponding labels.
  • torch.utils.data.DataLoader: It wraps an iterable around the torch.utils.data.Dataset to enable easy access to the samples.
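
To make the two primitives concrete, here is a toy sketch that is separate from the fish dataset: a custom Dataset that returns random tensors in place of real images, wrapped in a DataLoader. The class name RandomImageDataset and the sizes are made up for illustration only.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomImageDataset(Dataset):
    """Toy dataset: random tensors standing in for images, with integer labels."""

    def __init__(self, num_samples=10, num_classes=3):
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 3, 32, 32, generator=generator)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=generator)

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.labels)

    def __getitem__(self, index):
        # Return one (image, label) pair; DataLoader stacks these into batches
        return self.images[index], self.labels[index]


toy_data = RandomImageDataset()
toy_loader = DataLoader(toy_data, batch_size=4, shuffle=False)

batch_shapes = [tuple(images.shape) for images, labels in toy_loader]
print(batch_shapes)
```

With 10 samples and a batch size of 4, the loader yields two full batches and one smaller final batch.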

1.1) Create transformation objects

from torchvision import transforms

# With transforms you can resize and normalize images, or create augmented datasets.
# Here I only resize the images and convert them to torch.Tensor.

train_transform = transforms.Compose([
    # Resize the image to 180x180 pixels
    transforms.Resize(size=(180, 180)),
    # Convert the image to a torch.Tensor; pixel values are scaled from [0, 255] to [0.0, 1.0],
    # which already acts as a simple form of normalization
    transforms.ToTensor(),
])

validation_transform = transforms.Compose([
    transforms.Resize(size=(180, 180)),
    transforms.ToTensor(),
])

train_transform, validation_transform

1.2) Create Dataset From Folder (torchvision.datasets.ImageFolder)

Don’t forget to change:

  • train_dir
  • validation_dir

# Use ImageFolder to create the datasets
from torchvision import datasets

train_dir = "../Datasets/Fish_Dataset2/train"  # path to the train folder
validation_dir = "../Datasets/Fish_Dataset2/validation"  # path to the validation folder

train_data = datasets.ImageFolder(root=train_dir,
                                  transform=train_transform)

validation_data = datasets.ImageFolder(root=validation_dir,
                                       transform=validation_transform)

print(f"Train data:\n{train_data}\n\nValidation data:\n{validation_data}")

1.3) Create Iterable Dataset for Training (torch.utils.data.DataLoader)

Iterable dataset: A DataLoader wraps the dataset in an iterable, so you can loop over its elements in batches during training. It manages batch creation, shuffling, and parallel data loading based on the parameters you specify. When training a model, it is not a good approach to update the parameters after every single image. Instead, we group images into batches (for example 8, 16, or 32 images), and the parameters are updated after each batch is processed.
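
The batch size also determines how many parameter updates happen in one pass over the data. The arithmetic can be sketched as follows; the helper name batches_per_epoch and the sample count of 1000 are made up for illustration, and drop_last mirrors the DataLoader parameter of the same name.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields in one pass over the dataset."""
    if drop_last:
        # An incomplete final batch is discarded
        return num_samples // batch_size
    # An incomplete final batch is kept as a smaller batch
    return math.ceil(num_samples / batch_size)


for batch_size in (8, 16, 32):
    print(batch_size, batches_per_epoch(1000, batch_size))
# prints:
# 8 125
# 16 63
# 32 32
```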

from torch.utils.data import DataLoader

# train_data and validation_data were created above with datasets.ImageFolder

train_set = DataLoader(dataset=train_data,
                       batch_size=16,   # number of samples per batch
                       num_workers=1,   # number of subprocesses used for data loading
                       shuffle=True)    # reshuffle the training data every epoch

validation_set = DataLoader(dataset=validation_data,
                            batch_size=16,
                            num_workers=1,
                            shuffle=False)  # validation data usually doesn't need shuffling

train_set, validation_set

→ train_set and validation_set will be used for training

Visualizing Example Images

It is good practice to look at some example images from your dataset. Particularly when using an augmented dataset, it is helpful to observe augmented images because sometimes the augmentation might not give what you expect. By observing some examples, you can adjust the augmentation parameters. I am going to use Matplotlib for visualization.

import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image

# train_data and train_set were created above
# Map class indices back to class names
label_dict = {index: name for name, index in train_data.class_to_idx.items()}

# Display a batch of images in a 4x4 grid (fits a batch size of up to 16)
def show_images(images, labels):
    plt.figure(figsize=(12, 8))
    for i in range(len(images)):
        plt.subplot(4, 4, i + 1)
        image = to_pil_image(images[i])  # convert the tensor to a PIL image
        plt.imshow(image)
        plt.title(label_dict[labels[i].item()])  # convert the numerical label to its class name
        plt.axis('off')
    plt.show()

# Get the first batch and display it
images, labels = next(iter(train_set))
show_images(images, labels)
Example Images from Dataset for PyTorch Image Classification Model Training

Visualizing Class Distribution

Balancing a dataset is a crucial step because it helps prevent the model from becoming biased toward one class. If you have an imbalanced dataset, the model will probably perform better on specific classes and worse on others. Visualizing the distribution can be useful for determining whether the dataset is balanced or imbalanced.

import os

# train_dir and validation_dir were defined above when creating the datasets

# calculate distributions in train set and save them to dictionary
train_class_counts = {}
for class_folder in os.listdir(train_dir):
    class_path = os.path.join(train_dir, class_folder)
    if os.path.isdir(class_path):
        num_images = len(os.listdir(class_path))
        train_class_counts[class_folder] = num_images

# calculate distributions in validation set and save them to dictionary
validation_class_counts = {}
for class_folder in os.listdir(validation_dir):
    class_path = os.path.join(validation_dir, class_folder)
    if os.path.isdir(class_path):
        num_images = len(os.listdir(class_path))
        validation_class_counts[class_folder] = num_images

import matplotlib.pyplot as plt

plt.figure(figsize=(15, 6))

# plot for train set
plt.subplot(1, 2, 1)
plt.bar(train_class_counts.keys(), train_class_counts.values())
plt.title('Training set Distribution')
plt.xlabel('Classes')
plt.ylabel('Sample Numbers')
plt.xticks(rotation=45)

# plot for validation set
plt.subplot(1, 2, 2)
plt.bar(validation_class_counts.keys(), validation_class_counts.values())
plt.title('Validation set Distribution')
plt.xlabel('Classes')
plt.ylabel('Sample Numbers')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()
Distribution of Classes from Dataset for PyTorch Image Classification Model Training
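
Besides looking at the plot, you can put a number on the imbalance. The sketch below computes the ratio between the largest and smallest class, plus inverse-frequency class weights. This weighting is one common heuristic, not something this pipeline requires; the function name imbalance_report is hypothetical, and the counts reuse the “A has 2000, B has 50” example from earlier.

```python
def imbalance_report(class_counts):
    """Return (max/min count ratio, inverse-frequency weight per class)."""
    if not class_counts or min(class_counts.values()) <= 0:
        raise ValueError("every class needs at least one sample")
    total = sum(class_counts.values())
    num_classes = len(class_counts)
    ratio = max(class_counts.values()) / min(class_counts.values())
    # weight = total / (num_classes * count): rare classes get larger weights
    weights = {name: total / (num_classes * count)
               for name, count in class_counts.items()}
    return ratio, weights


ratio, weights = imbalance_report({"A": 2000, "B": 50})
print(ratio)    # 40.0
print(weights)  # {'A': 0.5125, 'B': 20.5}
```

In practice you would pass a dictionary such as train_class_counts instead of the hard-coded example. A ratio close to 1 means the classes are roughly balanced; a large ratio is a hint to collect more data for the rare classes or to compensate during training.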

Okay, that’s it. The second part is all about model training.