Fine-Tuning Image Classifiers with PyTorch and the timm library for Beginners
- Introduction
- Getting Started with the Code
- Setting Up Your Python Environment
- Importing the Required Dependencies
- Setting Up the Project
- Loading and Exploring the Dataset
- Selecting a Model
- Preparing the Data
- Fine-tuning the Model
- Making Predictions with the Model
- Exploring the In-Browser Demo
- Conclusion
Tutorial Links
- Fine-Tuning Image Classifiers with PyTorch and the timm library for Beginners: Learn how to fine-tune image classification models with PyTorch and the timm library by creating a hand gesture recognizer in this easy-to-follow guide for beginners.
- Exporting timm Image Classifiers from PyTorch to ONNX: Learn how to export timm image classification models from PyTorch to ONNX and perform inference using ONNX Runtime.
Introduction
Welcome to this hands-on guide to fine-tuning image classifiers with PyTorch and the timm library! Fine-tuning refers to taking a pre-trained model and adjusting its parameters using a new dataset to enhance its performance on a specific task. We can leverage pre-trained models to achieve high performance even when working with limited data and computational resources. The timm library further aids our goal with its wide range of pre-trained models, catering to diverse needs and use cases.
In this tutorial, we develop a hand gesture recognizer. Hand gesture recognition has many real-world applications, ranging from human-computer interaction and sign-language translation to creating immersive gaming experiences. By the end of this tutorial, you will have a practical hand gesture recognizer and a solid foundation to apply to other image classification tasks. You’ll also be able to interact with a model trained with this tutorial’s code through an in-browser demo that runs locally on your computer.
This guide is structured so that you don’t need a deep understanding of deep learning to complete it. If you follow the instructions, you can make it through! Yet, if you are eager to delve deeper into machine learning and deep learning, I recommend fast.ai’s Practical Deep Learning for Coders course. The course employs a hands-on approach that starts you off training models from the get-go and gradually digs deeper into the foundational concepts.
Let’s dive in and start training our hand gesture classifier!
Getting Started with the Code
The tutorial code is available as a Jupyter Notebook, which you can run locally or in a cloud-based environment like Google Colab. If you’re new to these platforms or need some guidance setting up, I’ve created dedicated tutorials to help you:
Getting Started with Google Colab: This tutorial introduces you to Google Colab, a free, cloud-based Jupyter Notebook service. You’ll learn to write, run, and share Python code directly in your browser.
Setting Up a Local Python Environment with Mamba for Machine Learning Projects on Windows: This tutorial guides you through installing the Mamba package manager on Windows, setting up a local Python environment, and installing PyTorch and Jupyter for machine learning projects.
No matter your choice of environment, you’ll be well-prepared to follow along with the rest of this tutorial. You can download the notebook from the tutorial’s GitHub repository or open the notebook directly in Google Colab using the links below.
Platform | Jupyter Notebook | Utility File |
---|---|---|
Google Colab | Open In Colab | |
Linux | GitHub Repository | |
Linux (Intel Arc) | GitHub Repository | |
Windows | GitHub Repository | windows_utils_hf.py |
Windows (Intel Arc) | GitHub Repository | windows_utils_hf.py |
Setting Up Your Python Environment
Before diving into the code, we’ll create a Python environment and install the necessary libraries. Creating a dedicated environment will ensure our project has all its dependencies in one place and does not interfere with other Python projects you may have.
Please note that this section is for readers setting up a local Python environment on their machines. If you’re following this tutorial on a cloud-based platform like Google Colab, the platform already provides an isolated environment with many Python libraries pre-installed. In that case, you may skip this section and directly proceed to the code sections. However, you may still need to install certain libraries specific to this tutorial using similar pip install commands within your notebook. The dedicated Colab Notebook contains the instructions for running it in Google Colab.
Creating a Python Environment
First, we’ll create a Python environment using Conda. Conda is a package manager that can create isolated Python environments. These environments are like sandboxed spaces where you can install Python libraries without affecting the rest of your system.
To create a new Python environment, open a terminal with Conda/Mamba installed and run the following commands:
# Conda: create a new Python 3.10 environment
conda create --name pytorch-env python=3.10 -y
# Conda: activate the environment
conda activate pytorch-env
# Mamba: create a new Python 3.10 environment
mamba create --name pytorch-env python=3.10 -y
# Mamba: activate the environment
mamba activate pytorch-env
The first command creates a new Python environment named pytorch-env using Python 3.10. The -y flag confirms that we want to proceed with the installation. After building the environment, the second command activates it, setting it as the active Python environment.
Installing PyTorch
PyTorch is a popular open-source machine learning framework that enables users to perform tensor computations, build dynamic computational graphs, and implement custom machine learning architectures. To install PyTorch with CUDA support (which allows PyTorch to leverage NVIDIA GPUs for faster training), we’ll use the following command:
# Linux and Windows: install PyTorch with CUDA support
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# macOS: MPS (Metal Performance Shaders) acceleration is available on macOS 12.3+
pip install torch torchvision torchaudio
Installation instructions for specific hardware and operating systems are available in the “Get Started” section of the PyTorch website.
Installing Additional Libraries
We also need to install some additional libraries for our project. If you’re new to Python or haven’t used some of these packages before, don’t worry.
Here’s a brief overview:
- datasets: A library for accessing and sharing datasets for audio, computer vision, and natural language processing (NLP) tasks.
- jupyter: An open-source web application that allows you to create and share documents that contain live code, equations, visualizations, and narrative text.
- matplotlib: This package provides a comprehensive collection of visualization tools to create high-quality plots, charts, and graphs for data exploration and presentation.
- pandas: This package provides fast, powerful, and flexible data analysis and manipulation tools.
- pillow: The Python Imaging Library adds image processing capabilities.
- timm: The timm library provides state-of-the-art (SOTA) computer vision models, layers, utilities, optimizers, schedulers, data loaders, augmentations, and training/evaluation scripts.
- torcheval: A library with simple and straightforward tooling for model evaluations.
- torchtnt: A library for PyTorch training tools and utilities.
- tqdm: A Python library that provides fast, extensible progress bars for loops and other iterable objects in Python.
- Jupyter Client: This package contains the reference implementation of the Jupyter protocol. It also provides client and kernel management APIs for working with kernels. We will install an older version than the one included with Jupyter (<8) to avoid an issue that causes the training notebook to freeze during training.
- PyZMQ: This package provides Python bindings for ZeroMQ, a lightweight and fast messaging implementation used by Jupyter Notebooks. We will install an older version than the one included with Jupyter (<25) to avoid an issue that causes the training notebook to freeze during training.
To install these additional libraries, we’ll use the following command:
# Install additional dependencies
pip install datasets jupyter matplotlib pandas pillow timm torcheval torchtnt tqdm
# Install older pyzmq and jupyter client versions
pip install --upgrade "jupyter_client<8" "pyzmq<25"
Installing Utility Packages
Finally, we will install some utility packages I made to help us handle images (cjm_pil_utils), interact with PyTorch (cjm_pytorch_utils), and work with pandas DataFrames (cjm_pandas_utils):
# Install utility packages
pip install cjm_pandas_utils cjm_pil_utils cjm_pytorch_utils
Now, our environment is all set up and ready to go! Remember, these libraries are just tools. If you don’t fully understand them yet, don’t worry. As we go through the tutorial, we’ll learn more about these tools and see them in action.
Launching Jupyter Notebook
Now that our environment is ready, it’s time to launch Jupyter Notebook. Jupyter Notebooks provide an interactive coding environment where we’ll work for the rest of this tutorial. To launch Jupyter Notebook, open a terminal with the pytorch-env environment active, navigate to the location where you stored the tutorial notebook (if you downloaded it), and type the following command:
jupyter notebook
This command will open a new tab in your default web browser, showing the Jupyter file browser. From the Jupyter file browser, you can open the tutorial notebook or create a new one to start the next section. Remember: If you close your terminal, the Jupyter Notebook server will stop. So, keep your terminal running while you’re working on the tutorial.
Importing the Required Dependencies
With our environment set up, it’s time to start the coding part of this tutorial. First, we will import the necessary Python packages into our Jupyter Notebook. Here’s a brief overview of how we’ll use these packages:
- HuggingFace Datasets dependencies: I host the dataset on HuggingFace Hub, and this package allows us to load our dataset with a single line of code.
- matplotlib: We use the matplotlib package to explore the dataset samples and class distribution.
- NumPy: We’ll use it to store PIL Images as arrays of pixel values.
- pandas: We use pandas DataFrame and Series objects to format data as tables.
- PIL (Pillow): We’ll use it for opening and working with image files.
- Python Standard Library dependencies: These are built-in modules that come with Python. We’ll use them for various tasks like handling file paths (pathlib.Path), manipulating JSON files (json), random number generation (random), multiprocessing (multiprocessing), mathematical operations (math), copying Python objects (copy), file matching patterns (glob), working with dates and times (datetime), and interacting with the operating system (os).
- PyTorch dependencies: We’ll use PyTorch’s various modules for building our model, processing data, and training.
- timm library: We’ll use the timm library to download and prepare a pre-trained model for fine-tuning.
- tqdm: We use the library to track the progress of longer processes like training.
- Utility functions: These are helper functions from the packages we installed earlier. They provide shortcuts for routine tasks and keep our code clean and readable.
# Import Python Standard Library dependencies
from copy import copy
import datetime
from glob import glob
import json
import math
import multiprocessing
import os
from pathlib import Path
import random
# Import utility functions
from cjm_pandas_utils.core import markdown_to_pandas
from cjm_pil_utils.core import resize_img
from cjm_pytorch_utils.core import set_seed, pil_to_tensor, tensor_to_pil, get_torch_device, denorm_img_tensor
# Import HuggingFace Datasets dependencies
from datasets import load_dataset
# Import matplotlib for creating plots
import matplotlib.pyplot as plt
# Import numpy
import numpy as np
# Import pandas module for data manipulation
import pandas as pd
# Set options for Pandas DataFrame display
pd.set_option('max_colwidth', None)  # Do not truncate the contents of cells in the DataFrame
pd.set_option('display.max_rows', None)  # Display all rows in the DataFrame
pd.set_option('display.max_columns', None)  # Display all columns in the DataFrame
# Import PIL for image manipulation
from PIL import Image
# Import timm library
import timm
# Import PyTorch dependencies
import torch
import torch.nn as nn
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchtnt.utils import get_module_summary
from torcheval.metrics import MulticlassAccuracy
# Import tqdm for progress bar
from tqdm.auto import tqdm
Having successfully imported the dependencies, we are ready to move to the next step.
The get_module_summary function moved from the torcheval package to torchtnt.
Setting Up the Project
In this section, we set up some basics for our project. First, we set a seed for generating random numbers using the set_seed function from the cjm_pytorch_utils package.
Setting a Random Number Seed
A fixed seed value is helpful when training deep-learning models for reproducibility, debugging, and comparison. Having reproducible results allows others to confirm your findings. Using a fixed seed can make it easier to find bugs as it ensures the same inputs produce the same outputs. Likewise, using fixed seed values lets you compare performance between models and training parameters. That said, it’s often a good idea to test different seed values to see how your model’s performance varies between them. Also, don’t use a fixed seed value when you deploy the final model.
# Set the seed for generating random numbers in PyTorch, NumPy, and Python's random module.
seed = 1234
set_seed(seed)
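The comment above says set_seed seeds PyTorch, NumPy, and Python’s random module. As a rough illustration of what that involves, here is a minimal sketch. seed_everything is a hypothetical helper, not the cjm_pytorch_utils implementation, and it simply skips NumPy or PyTorch if either is not installed.

```python
import random


def seed_everything(seed: int) -> None:
    """Seed Python's random module, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)  # legacy global NumPy generator
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds PyTorch's generators on all devices
    except ImportError:
        pass


seed_everything(1234)
print([round(random.random(), 4) for _ in range(3)])
seed_everything(1234)
print([round(random.random(), 4) for _ in range(3)])  # same values as the line above
```

Reseeding with the same value replays the same random sequence, which is what makes runs comparable. Note that seeding alone does not make every GPU operation deterministic, since some CUDA kernels are nondeterministic by design.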
Setting the Device and Data Type
Next, we determine the device to run our computations on using the get_torch_device function from the cjm_pytorch_utils package, and we set the data type for our tensors.
PyTorch can run on either a CPU or a GPU. The get_torch_device function will automatically check if a supported NVIDIA or Mac GPU is available and fall back to the CPU otherwise. We’ll use the device and dtype variables to ensure all tensors and model weights are on the same device and have the same data type, since mismatches typically cause runtime errors.
device = get_torch_device()
dtype = torch.float32
device, dtype
('cuda', torch.float32)
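The selection order described above (an NVIDIA GPU first, then an Apple GPU, then the CPU) can be sketched as follows. This assumes get_torch_device decides roughly this way; pick_device_name and detect_device_name are hypothetical helpers, not the package’s actual source. The decision lives in a pure function so it can be checked without a GPU, and detect_device_name only imports PyTorch when it is called.

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name: NVIDIA GPU, then Apple GPU, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device_name() -> str:
    """Ask PyTorch which accelerators are available (imports torch lazily)."""
    import torch

    # torch.backends.mps only exists in PyTorch 1.12 and later
    mps_backend = getattr(torch.backends, "mps", None)
    return pick_device_name(
        cuda_available=torch.cuda.is_available(),
        mps_available=mps_backend is not None and mps_backend.is_available(),
    )


print(pick_device_name(cuda_available=False, mps_available=True))
```

The returned string can be passed to tensor.to(...) or model.to(...) in the same way as the device variable above.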
Setting the Directory Paths
We’ll then set up a directory for our project to store our results and other related files. The code currently creates the folder in the current directory (./). Update the path if that is not suitable for you.
# The name for the project
project_name = "pytorch-timm-image-classifier"
# The path for the project folder
project_dir = Path(f"./{project_name}/")
# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)
print(project_dir)
pytorch-timm-image-classifier
We also need a place to store our datasets. We’re going to create a directory for this purpose. If running locally, select a suitable folder location to store the dataset. For a cloud service like Google Colab, you can set it to the current directory.
# Define path to store datasets
dataset_dir = Path("/mnt/980_1TB_1/Datasets/")
# Create the dataset directory if it does not exist
dataset_dir.mkdir(parents=True, exist_ok=True)
print(f"Dataset Directory: {dataset_dir}")
Dataset Directory: /mnt/980_1TB_1/Datasets
Double-check the project and dataset directories exist in the specified paths and that you can add files to them before continuing.
At this point, our environment is set up and ready to go. We’ve set our random seed, determined our computation device, and set up directories for our project and dataset. In the next section, we will download and explore the dataset.
Loading and Exploring the Dataset
Now that we have set up our project, we can start working with our dataset. The dataset we’ll use is a downscaled subset of HaGRID (HAnd Gesture Recognition Image Dataset) that I modified for image classification tasks. The dataset contains images for 18 distinct hand gestures and an additional no_gesture class for idle hands. The dataset is approximately 3.8 GB, but you will need about 7.6 GB to store the archive file and extracted dataset.
- HuggingFace Hub Dataset Repository: cj-mills/hagrid-classification-512p-no-gesture-150k-zip
The following steps demonstrate how to load the dataset from the HuggingFace Hub, inspect the dataset, and visualize some sample images.
Setting the Dataset Path
We’ll first set up the path for our dataset. We’ll construct the HuggingFace Hub dataset name by combining the username and the dataset name. We then define where to cache the dataset locally.
# Set the name of the dataset
dataset_name = 'hagrid-classification-512p-no-gesture-150k-zip'
# Construct the HuggingFace Hub dataset name by combining the username and dataset name
hf_dataset = f'cj-mills/{dataset_name}'
print(f"HuggingFace Dataset: {hf_dataset}")
# Create the path to the directory where the dataset will be cached
cache_dir = Path(f'{dataset_dir}/{dataset_name}')
print(f"Dataset Path: {cache_dir}")
HuggingFace Dataset: cj-mills/hagrid-classification-512p-no-gesture-150k-zip
Dataset Path: /mnt/980_1TB_1/Datasets/hagrid-classification-512p-no-gesture-150k-zip
Downloading the Dataset
We’ll now download the dataset from the HuggingFace Hub using the load_dataset function. We’ll set the number of worker processes for loading data to the number of CPU cores available on your machine.
If you are following the tutorial on a Windows machine, you might need to enable longer file path lengths for the load_dataset function to work:
1. Type “Registry Editor” into the Windows search bar and click Run as administrator.
2. In the Registry Editor, input the following location into the text box and press Enter: Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem
3. Double-click the entry named LongPathsEnabled.
4. Set the Value data field to 1.
5. You may now close the Registry Editor and continue with the tutorial.
# Set the number of worker processes for loading data. This should be the number of CPUs available.
num_workers = multiprocessing.cpu_count()
# Load the dataset from Hugging Face Hub
dataset = load_dataset(hf_dataset,
                       split='train',
                       cache_dir=cache_dir,
                       num_proc=num_workers)
# Print dataset summary
print(dataset)
Dataset({
features: ['image', 'label'],
num_rows: 153735
})
The dataset summary indicates there are 153,735 samples. Each sample has an image and a label. The label values for each dataset sample are index values corresponding to class names. For example, the label value 0 corresponds to the call gesture.
Deleting the Archive File
After loading the dataset, we can delete the downloaded archive file to free up some space. We define a helper function, delete_files_in_directory, which deletes all files in a specified directory.
def delete_files_in_directory(directory: Path):
    """
    Delete all files in a given directory.
    Args:
        directory (Path): The directory in which to delete files.
    Note:
        Subdirectories are skipped. If a file cannot be deleted, the error is
        printed and the function continues with the remaining files.
    """
    if not directory.exists():
        print(f"Directory {directory} does not exist.")
        return
    for item in directory.glob('*'):
        # Skip subdirectories and anything else that is not a regular file
        if not item.is_file():
            continue
        try:
            item.unlink()
        except Exception as e:
            print(f"Unable to delete file {item}. Error: {e}")
We’ll then use this function to delete the archive file.
download_directory = Path(os.path.join(cache_dir, "downloads"))
delete_files_in_directory(download_directory)
Inspecting the Class Distribution
Next, we get the names of all the classes in our dataset and inspect the distribution of images among these classes. A balanced dataset (where each class has approximately the same number of instances) is ideal for training a machine-learning model.
# Get image classes
class_names = dataset.features['label'].names
pd.DataFrame(class_names)
Index | Class name |
---|---|
0 | call |
1 | dislike |
2 | fist |
3 | four |
4 | like |
5 | mute |
6 | no_gesture |
7 | ok |
8 | one |
9 | palm |
10 | peace |
11 | peace_inverted |
12 | rock |
13 | stop |
14 | stop_inverted |
15 | three |
16 | three2 |
17 | two_up |
18 | two_up_inverted |
# Visualize the class distribution
class_counts = pd.DataFrame(dataset['label']).value_counts().sort_index()
# Plot the distribution
class_counts.plot(kind='bar')
plt.title('Class distribution')
plt.ylabel('Count')
plt.xlabel('Classes')
plt.xticks(range(len(class_counts.index)), class_names)  # Set the x-axis tick labels
plt.xticks(rotation=75)  # Rotate x-axis labels
plt.show()
Each class, excluding the no_gesture class, has roughly the same number of samples. The no_gesture class contains approximately four times as many images because of the immense variety of non-matching hand positions.
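To put numbers on “roughly the same” and “approximately four times as many”, one option is to compare each class count with the median class count. The sketch below uses only the standard library; imbalance_report is a hypothetical helper, and the toy labels are made up for illustration rather than taken from HaGRID. With the real data, you would pass dataset['label'] and class_names instead.

```python
from collections import Counter
from statistics import median
from typing import Dict, Sequence


def imbalance_report(labels: Sequence[int], class_names: Sequence[str]) -> Dict[str, float]:
    """Return each class's sample count divided by the median class count."""
    counts = Counter(labels)
    # Take the median over every class index, so empty classes are included
    typical = median(counts[i] for i in range(len(class_names)))
    return {name: counts[i] / typical for i, name in enumerate(class_names)}


# Made-up toy labels: three gesture classes plus an over-represented "no_gesture" class
toy_names = ["call", "dislike", "fist", "no_gesture"]
toy_labels = [0] * 100 + [1] * 95 + [2] * 105 + [3] * 400

for name, ratio in imbalance_report(toy_labels, toy_names).items():
    print(f"{name:>12}: {ratio:.2f}x the median class")
```

Ratios near 1.0 indicate a class of typical size, while a ratio around 4 for no_gesture would match the distribution described above.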
Visualizing Sample Images
Lastly, we will visualize the first sample image of each class in our dataset. Visualizing the samples helps us get a feel for the kind of images we’re working with and whether they’re suitable for the task at hand.
# Get indices for the first sample in the dataset for each class
all_labels = dataset['label']  # Fetch the label column once instead of once per lookup
indices = [all_labels.index(value) for value in range(len(class_names)) if value in all_labels]
# Calculate the number of rows and columns
grid_size = math.floor(math.sqrt(len(indices)))
n_rows = grid_size+(1 if grid_size**2 < len(indices) else 0)
n_cols = grid_size
# Get the first image found for each class and its label
first_samples = dataset[indices]
images = first_samples['image']
labels = first_samples['label']
# Create a figure for the grid
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10,10))
for i, ax in enumerate(axs.flatten()):
    # If we have an image for this subplot
    if i < len(images) and images[i]:
        # Add the image to the subplot
        ax.imshow(np.array(images[i]))
        # Set the title to the corresponding class name
        ax.set_title(class_names[labels[i]])
        # Remove the axis
        ax.axis('off')
    else:
        # If no image, hide the subplot
        ax.axis('off')
# Display the grid
plt.tight_layout()
plt.show()
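The row and column arithmetic above works for this dataset’s 19 classes (5 rows × 4 columns = 20 cells), but adding at most one extra row to a floor(√n) square does not always leave enough cells. For 7 classes it gives 3 × 2 = 6 cells, so the plotting loop would silently skip the last image. The sketch below compares that calculation with a ceiling-based alternative; original_grid and safe_grid are hypothetical names used only for this comparison.

```python
import math
from typing import Tuple


def original_grid(n: int) -> Tuple[int, int]:
    """(rows, cols) computed the same way as the visualization code above."""
    grid_size = math.floor(math.sqrt(n))
    n_rows = grid_size + (1 if grid_size**2 < n else 0)
    return n_rows, grid_size


def safe_grid(n: int) -> Tuple[int, int]:
    """(rows, cols) for a near-square grid that always has at least n cells."""
    n_cols = math.ceil(math.sqrt(n))
    n_rows = math.ceil(n / n_cols)
    return n_rows, n_cols


for n in (7, 19):
    rows, cols = original_grid(n)
    print(f"n={n}: original {rows}x{cols} fits={rows * cols >= n}, safe {safe_grid(n)}")
```

Using safe_grid’s result in plt.subplots(n_rows, n_cols, ...) leaves the rest of the plotting loop unchanged, because unused cells are already hidden by the else branch.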
We have loaded the dataset, inspected its class distribution, and visualized some sample images. In the next section, we will select and load our model.
Selecting a Model
Choosing a suitable model for your task is crucial for the success of your machine learning project. The model you select will depend on several factors, including the size and nature of your dataset, the problem you’re trying to solve, and the computational resources you have at your disposal.
Exploring Available Models
You can explore the pretrained models available in the timm library using the timm.list_models() function. The library has hundreds of models, so we’ll narrow our search to the ResNet18 family of models. ResNet18 models are popular for image classification tasks due to their balance of accuracy and speed.
pd.DataFrame(timm.list_models('resnet18*', pretrained=True))
Index | Model name |
---|---|
0 | resnet18.a1_in1k |
1 | resnet18.a2_in1k |
2 | resnet18.a3_in1k |
3 | resnet18.fb_ssl_yfcc100m_ft_in1k |
4 | resnet18.fb_swsl_ig1b_ft_in1k |
5 | resnet18.gluon_in1k |
6 | resnet18.tv_in1k |
7 | resnet18d.ra2_in1k |
Choosing the ResNet18-D Model
For this tutorial, I went with the pretrained ResNet18-D model. This model’s balance of accuracy and speed makes it suitable for real-time applications, such as hand gesture recognition. While this model is a good all-rounder, others may work better for specific applications. For example, some models are designed to run on mobile devices and may sacrifice some accuracy for improved performance. Whatever your requirements are, the timm library likely has a suitable model for your needs. Feel free to try different models and see how they compare.
Inspecting the ResNet18-D Model Configuration
Next, we will inspect the configuration of our chosen model. The model config gives us information about the pretraining process for the model.
# Import the resnet module
from timm.models import resnet
# Define the ResNet model variant to use
resnet_model = 'resnet18d'
# Get the default configuration of the chosen model
model_cfg = resnet.default_cfgs[resnet_model].default.to_dict()
# Show the default configuration values
pd.DataFrame.from_dict(model_cfg, orient='index')
Key | Value |
---|---|
url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth |
hf_hub_id | timm/ |
custom_load | False |
input_size | (3, 224, 224) |
test_input_size | (3, 288, 288) |
fixed_input_size | False |
interpolation | bicubic |
crop_pct | 0.875 |
test_crop_pct | 0.95 |
crop_mode | center |
mean | (0.485, 0.456, 0.406) |
std | (0.229, 0.224, 0.225) |
num_classes | 1000 |
pool_size | (7, 7) |
first_conv | conv1.0 |
classifier | fc |
origin_url | https://github.com/huggingface/pytorch-image-models |
Retrieving Normalization Statistics
Before we can use the ResNet18-D model, we need to normalize our dataset. Normalization rescales pixel values so their distribution matches the data the model saw during pretraining, which helps the network converge faster during training. For each color channel, we subtract the channel mean and divide by the channel standard deviation, using pixel values scaled to the [0, 1] range. The per-channel mean and standard deviation of the pretraining dataset are called normalization statistics. For our model, these are the ImageNet statistics stored in the model config, so we will retrieve the mean and std values from it.
# Retrieve normalization statistics (mean and std) specific to the pretrained model
mean, std = model_cfg['mean'], model_cfg['std']
norm_stats = (mean, std)
norm_stats
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
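As a worked example of the normalization described above, the sketch below applies these per-channel statistics to a single RGB pixel and then reverses the operation. It assumes 8-bit pixel values are first scaled to [0, 1], which is the range these ImageNet statistics were computed on and what torchvision’s transforms.ToTensor() produces before transforms.Normalize(mean, std) is applied. normalize_pixel and denormalize_pixel are hypothetical helpers for illustration.

```python
from typing import Sequence, Tuple

# ImageNet normalization statistics, as listed in the model config above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Sequence[int], mean=MEAN, std=STD) -> Tuple[float, ...]:
    """Scale 8-bit RGB values to [0, 1], then standardize each channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))


def denormalize_pixel(values: Sequence[float], mean=MEAN, std=STD) -> Tuple[int, ...]:
    """Undo normalize_pixel and round back to 8-bit values."""
    return tuple(round((v * s + m) * 255.0) for v, m, s in zip(values, mean, std))


pixel = (255, 128, 0)
normalized = normalize_pixel(pixel)
print("normalized:", [round(v, 4) for v in normalized])
print("restored:  ", denormalize_pixel(normalized))
```

A full-intensity red channel maps to about 2.25 and a zero blue channel to about -1.80, so the network sees values centered near zero rather than in [0, 255].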
Loading the ResNet18-D Model
We can now load our model. We’ll set the number of output classes equal to the number of image classes in our dataset. We’ll also specify the device and data type for the model.
# Create a pretrained ResNet model with the number of output classes equal to the number of class names
# 'timm.create_model' function automatically downloads and initializes the pretrained weights
resnet18 = timm.create_model(resnet_model, pretrained=True, num_classes=len(class_names))
# Set the device and data type for the model
resnet18 = resnet18.to(device=device, dtype=dtype)
# Add attributes to store the device and model name for later reference
resnet18.device = device
resnet18.name = resnet_model
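Passing num_classes to timm.create_model gives the model a classification layer sized for our 19 classes instead of ImageNet’s 1,000. Because its shape no longer matches the pretrained checkpoint, that layer cannot reuse the ImageNet weights and starts from a fresh initialization, while the rest of the network keeps its pretrained values. The arithmetic below sizes both heads, assuming the classifier (fc, per the config above) is a single linear layer with a bias that receives ResNet18-D’s 512-channel pooled features (the final Conv2d layers in the model summary output 512 channels). linear_params is a hypothetical helper.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of weights (plus biases) in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


FEATURES = 512  # channels in ResNet18-D's final feature map, pooled to a vector

imagenet_head = linear_params(FEATURES, 1000)  # original 1,000-class classifier
gesture_head = linear_params(FEATURES, 19)     # classifier for our 19 gesture classes

print(f"ImageNet head: {imagenet_head:,} parameters")
print(f"Gesture head:  {gesture_head:,} parameters")
```

The new head holds under 0.1% of the roughly 11.2 M parameters reported in the model summary, so nearly all of the network begins fine-tuning from its pretrained weights.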
Selecting the Model
With our model loaded, we can now select it for training.
model = resnet18
Summarizing the Model
Finally, let’s generate a summary of our model. The summary gives us an overview of its structure and performance characteristics.
# Define the input to the model
test_inp = torch.randn(1, 3, 256, 256).to(device)
# Get a summary of the model as a Pandas DataFrame
summary_df = markdown_to_pandas(f"{get_module_summary(model, [test_inp])}")
# Filter the summary to only contain Conv2d layers and the model
summary_df = summary_df[(summary_df.index == 0) | (summary_df['Type'] == 'Conv2d')]
# Remove the column "Contains Uninitialized Parameters?"
summary_df.drop('Contains Uninitialized Parameters?', axis=1)
Index | Type | # Parameters | # Trainable Parameters | Size (bytes) | Forward FLOPs | Backward FLOPs | In size | Out size |
---|---|---|---|---|---|---|---|---|
0 | ResNet | 11.2 M | 11.2 M | 44.9 M | 2.7 G | 5.3 G | [1, 3, 256, 256] | [1, 19] |
2 | Conv2d | 864 | 864 | 3.5 K | 14.2 M | 14.2 M | [1, 3, 256, 256] | [1, 32, 128, 128] |
5 | Conv2d | 9.2 K | 9.2 K | 36.9 K | 150 M | 301 M | [1, 32, 128, 128] | [1, 32, 128, 128] |
8 | Conv2d | 18.4 K | 18.4 K | 73.7 K | 301 M | 603 M | [1, 32, 128, 128] | [1, 64, 128, 128] |
14 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
19 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
23 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
28 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
33 | Conv2d | 73.7 K | 73.7 K | 294 K | 75.5 M | 150 M | [1, 64, 64, 64] | [1, 128, 32, 32] |
38 | Conv2d | 147 K | 147 K | 589 K | 150 M | 301 M | [1, 128, 32, 32] | [1, 128, 32, 32] |
43 | Conv2d | 8.2 K | 8.2 K | 32.8 K | 8.4 M | 16.8 M | [1, 64, 32, 32] | [1, 128, 32, 32] |
46 | Conv2d | 147 K | 147 K | 589 K | 150 M | 301 M | [1, 128, 32, 32] | [1, 128, 32, 32] |
51 | Conv2d | 147 K | 147 K | 589 K | 150 M | 301 M | [1, 128, 32, 32] | [1, 128, 32, 32] |
56 | Conv2d | 294 K | 294 K | 1.2 M | 75.5 M | 150 M | [1, 128, 32, 32] | [1, 256, 16, 16] |
61 | Conv2d | 589 K | 589 K | 2.4 M | 150 M | 301 M | [1, 256, 16, 16] | [1, 256, 16, 16] |
66 | Conv2d | 32.8 K | 32.8 K | 131 K | 8.4 M | 16.8 M | [1, 128, 16, 16] | [1, 256, 16, 16] |
69 | Conv2d | 589 K | 589 K | 2.4 M | 150 M | 301 M | [1, 256, 16, 16] | [1, 256, 16, 16] |
74 | Conv2d | 589 K | 589 K | 2.4 M | 150 M | 301 M | [1, 256, 16, 16] | [1, 256, 16, 16] |
79 | Conv2d | 1.2 M | 1.2 M | 4.7 M | 75.5 M | 150 M | [1, 256, 16, 16] | [1, 512, 8, 8] |
84 | Conv2d | 2.4 M | 2.4 M | 9.4 M | 150 M | 301 M | [1, 512, 8, 8] | [1, 512, 8, 8] |
89 | Conv2d | 131 K | 131 K | 524 K | 8.4 M | 16.8 M | [1, 256, 8, 8] | [1, 512, 8, 8] |
92 | Conv2d | 2.4 M | 2.4 M | 9.4 M | 150 M | 301 M | [1, 512, 8, 8] | [1, 512, 8, 8] |
97 | Conv2d | 2.4 M | 2.4 M | 9.4 M | 150 M | 301 M | [1, 512, 8, 8] | [1, 512, 8, 8] |
We can see from the summary that the ResNet18-D model is about 45 MB in size and needs to perform about 2.7 billion floating-point operations to process a single 256x256 input image. For context, the larger mid-size ResNet50-D model is about 95 MB and performs about 5.7 billion floating-point ops for the same image. On the other end, the tiniest variant of the mobile-optimized MobileNetV3 model is 2.4 MB and only takes about 30 million floating-point operations.
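Some of the summary numbers can be checked by hand. The sketch below recomputes the first Conv2d row of the summary (3 input channels, 32 output channels, [1, 3, 256, 256] to [1, 32, 128, 128]), assuming a 3 × 3 kernel with no bias and no grouping, which is consistent with its 864 parameters. Size (bytes) is the parameter count times 4 bytes per float32 value, and the Forward FLOPs figure for this layer (14.2 M) appears to match one multiply-accumulate per weight per output position; that reading is inferred from the numbers, not taken from the torchtnt documentation. conv2d_params and conv2d_macs are hypothetical helpers.

```python
def conv2d_params(in_channels: int, out_channels: int, kernel_size: int, bias: bool = False) -> int:
    """Parameters in a standard (ungrouped) Conv2d layer with a square kernel."""
    weights = in_channels * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def conv2d_macs(params: int, out_height: int, out_width: int) -> int:
    """Multiply-accumulates when every (bias-free) weight is applied once per output position."""
    return params * out_height * out_width


BYTES_PER_FLOAT32 = 4

# First stem convolution from the summary: 3 -> 32 channels, 128x128 output
stem_params = conv2d_params(3, 32, 3)
print("parameters:   ", stem_params)
print("size (bytes): ", stem_params * BYTES_PER_FLOAT32)
print("forward MACs: ", conv2d_macs(stem_params, 128, 128))
```

The result, 14,155,776, rounds to the 14.2 M in the table, and the same calculation for the second stem convolution (32 to 32 channels) gives about 151 M, close to the 150 M listed. Applying the 4-byte rule to the whole model, 11.2 M parameters come to roughly 45 MB.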
That is valuable information when considering how we will deploy the fine-tuned model. For example, the in-browser demo I mentioned will first download the model to your local machine. The larger the model, the longer it will take for the demo to start. Likewise, the number of floating-point operations will influence what hardware can run the model smoothly. For real-time applications, even a few milliseconds can matter for inference speed, that is, how quickly the trained model makes predictions on new, unseen data.
The model architecture also influences inference speed beyond the raw number of floating-point operations. The MobileNetV3 architecture is tuned for mobile phone CPUs, while the ResNet architectures can better leverage larger GPUs.
That completes the model selection and setup. In the next section, we will prepare our dataset for training.
Preparing the Data
Next, we will prepare our data for the model training process. The data preparation involves several steps, such as applying data augmentation techniques, setting up the train-validation split for the dataset, resizing and padding the images, defining the training dataset class, and initializing DataLoaders to feed data to the model.
Selecting a Sample Image
Let’s begin by selecting a random image from the dataset to visualize the data preparation steps.
# Select a random item from the dataset
item = random.choice(dataset)
label = class_names[item['label']]
sample_img = item['image']
print(f"Image Label: {label}")
# Display the image
sample_img
Image Label: stop
Data Augmentation
Next, we’ll define what data augmentations to apply to images during training. Data augmentation is a technique that effectively expands the size and diversity of a dataset by creating variations of existing samples. It helps the model learn general features instead of memorizing specific examples.
We’ll use trivial augmentation, which applies a single, random transform to each image. This simple method can be highly effective for data augmentation.
However, we’ll need to create a custom version of the TrivialAugmentWide class from PyTorch’s transforms module, as some of the default parameters are not ideal for this dataset. This custom class defines a dictionary of operations for augmenting the images, and we can customize each operation’s parameters.
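Here is a plain-Python sketch of the “single, random transform” idea. As far as I understand torchvision’s implementation, TrivialAugmentWide does the following for each image: it chooses one operation uniformly at random, then one of that operation’s magnitude bins uniformly, and flips the sign with 50% probability for signed operations. The sketch only samples the (operation, magnitude) pair and never touches pixels. `AUG_SPACE` is a trimmed, illustrative subset of the custom dictionary defined below, not the full augmentation space.

```python
import random


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Evenly spaced values from start to stop, inclusive (like torch.linspace)."""
    return [start + (stop - start) * i / (num - 1) for i in range(num)]


NUM_BINS = 31  # TrivialAugmentWide's default number of magnitude bins

# Trimmed, illustrative augmentation space: name -> (candidate magnitudes, signed?)
AUG_SPACE = {
    "Identity": ([0.0], False),
    "Rotate": (linspace(0.0, 45.0, NUM_BINS), True),
    "Brightness": (linspace(0.0, 0.75, NUM_BINS), True),
    "Solarize": (linspace(255.0, 0.0, NUM_BINS), False),
}


def sample_trivial_augment(space: dict, rng: random.Random) -> tuple[str, float]:
    """Pick one operation uniformly, then one magnitude bin uniformly, then a random sign if signed."""
    op_name = rng.choice(list(space))
    magnitudes, signed = space[op_name]
    magnitude = rng.choice(magnitudes) if len(magnitudes) > 1 else 0.0
    if signed and rng.random() < 0.5:
        magnitude = -magnitude
    return op_name, magnitude


rng = random.Random(0)
samples = [sample_trivial_augment(AUG_SPACE, rng) for _ in range(1000)]
print(samples[:3])
```

Because each image draws its own operation and magnitude, the same training image looks different from epoch to epoch. The real transform then applies that one operation to the pixels.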
Trivial Augmentation
from torch import Tensor
from typing import Dict, Tuple, List, Optional
# This class extends the TrivialAugmentWide class provided by torchvision's transforms module.
# TrivialAugmentWide is an augmentation policy that randomly applies a single augmentation to each image.
class CustomTrivialAugmentWide(transforms.TrivialAugmentWide):
    # The _augmentation_space method defines a custom augmentation space for the augmentation policy.
    # This method returns a dictionary where each key is the name of an augmentation operation and
    # the corresponding value is a tuple of a tensor and a boolean value.
    # The tensor holds the candidate magnitudes of the operation, and the boolean defines
    # whether to perform the operation in both the positive and negative directions (True)
    # or only in the positive direction (False).
    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        # Define custom augmentation space
        custom_augmentation_space = {
            # Identity operation doesn't change the image
            "Identity": (torch.tensor(0.0), False),
            # Distort the image along the x or y axis, respectively.
            "ShearX": (torch.linspace(0.0, 0.25, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.25, num_bins), True),
            # Move the image along the x or y axis, respectively.
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            # Rotate operation: rotates the image.
            "Rotate": (torch.linspace(0.0, 45.0, num_bins), True),
            # Adjust brightness, color, contrast, and sharpness respectively.
            "Brightness": (torch.linspace(0.0, 0.75, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            # Reduce the number of bits used to express the color in each channel of the image.
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            # Invert all pixel values above a threshold.
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            # Maximize the image contrast by setting the darkest color to black and the lightest to white.
            "AutoContrast": (torch.tensor(0.0), False),
            # Equalize the image histogram to improve its contrast.
            "Equalize": (torch.tensor(0.0), False),
        }
        # The function returns the dictionary of operations.
        return custom_augmentation_space
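You can test the custom augmentation method by applying it to the sample image and displaying the result. The output will usually look different each time you run the code, since both the operation and its magnitude are chosen at random (the Identity operation leaves the image unchanged).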
Test the transform
# Create a CustomTrivialAugmentWide object
trivial_aug = CustomTrivialAugmentWide()
# Convert the image to a tensor
img_tensor = transforms.PILToTensor()(sample_img)[None]
print(img_tensor.shape)
# Apply the custom augmentation to the tensor
tensor_to_pil(trivial_aug(img_tensor))
torch.Size([1, 3, 512, 512])
Training-Validation Split
Next, we’ll split the dataset into training and validation sets. The model will use the training set to update its parameters, and we will use the validation set to evaluate the model’s performance on data it has not seen before. Validation sets are needed when training models because we want to verify the model can generalize well to new data before we deploy it.
Get training and validation sets
# Define the fraction of the images that should be used for validation
val_pct = 0.1
# Split the dataset into training and validation sets
train_split, val_split = dataset.train_test_split(test_size=val_pct).values()
# Print the number of images in the training and validation sets
len(train_split), len(val_split)
(138361, 15374)
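The split sizes above are consistent with rounding the validation count up: 10% of the 153,735 images is 15,373.5, which becomes 15,374. The sketch below reproduces that arithmetic and a shuffle-then-slice split in plain Python. It illustrates the idea only and is not the Hugging Face `datasets` implementation, which offers more options.

```python
import math
import random


def split_sizes(n_samples: int, test_size: float) -> tuple[int, int]:
    """Train/test sizes when the test fraction is rounded up (scikit-learn style)."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def random_split_indices(n_samples: int, test_size: float, seed=None):
    """Shuffle all indices once, then slice them into disjoint train and test lists."""
    n_train, _ = split_sizes(n_samples, test_size)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    return indices[:n_train], indices[n_train:]


# 138,361 + 15,374 = 153,735 images in total
print(split_sizes(153_735, 0.1))  # (138361, 15374)
train_idx, val_idx = random_split_indices(200, 0.1, seed=42)
```

Shuffling before slicing ensures that neither set is biased toward whatever order the dataset was stored in.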
Image Resizing and Padding
Next, we define a class called ResizePad to resize and pad images, making them a uniform size. Making all the input images the same size allows us to feed samples to the model in batches, letting us leverage our GPU more efficiently.
Resize Pad Transform
class ResizePad(nn.Module):
    def __init__(self, max_sz=256, padding_mode='edge'):
        """
        A PyTorch module that resizes an image tensor and adds padding to make it a square tensor.
        Args:
        max_sz (int, optional): The size of the square tensor.
        padding_mode (str, optional): The padding mode used when adding padding to the tensor.
        """
        super().__init__()
        self.max_sz = max_sz
        self.padding_mode = padding_mode
    def forward(self, x):
        # Get the width and height of the image tensor
        w, h = TF.get_image_size(x)
        # Resize the image tensor so that its largest dimension is (approximately) equal to `max_sz`
        size = int(min(w, h) / (max(w, h) / self.max_sz))
        x = TF.resize(x, size=size, antialias=True)
        # Add padding to make the image tensor a square
        w, h = TF.get_image_size(x)
        offset = (self.max_sz - min(w, h)) // 2
        padding = [0, offset] if h < w else [offset, 0]
        x = TF.pad(x, padding=padding, padding_mode=self.padding_mode)
        # Resize once more to absorb any off-by-one rounding differences
        x = TF.resize(x, size=[self.max_sz] * 2, antialias=True)
        return x
For training, we’ll resize and pad the images to 288x288. That should be large enough for the model to get adequate detail while keeping training time low.
Set training image size
train_sz = (288,288)
You can test the ResizePad transform on the sample image to see how it resizes and pads the sample. The following code crops the sample image to make the padding more apparent.
Test the transform
print(f"Source image: {sample_img.size}")
# Crop the source image
w, h = sample_img.size
cropped_img = sample_img.crop([0, h//4, w, h-h//4])
print(f"Cropped image: {cropped_img.size}")
# Create a `ResizePad` object
resize_pad = ResizePad(max_sz=max(train_sz))
# Convert the cropped image to a tensor
img_tensor = transforms.PILToTensor()(cropped_img)[None]
print(f"Cropped tensor: {img_tensor.shape}")
# Resize and pad the tensor
resized_tensor = resize_pad(img_tensor)
print(f"Padded tensor: {resized_tensor.shape}")
# Display the updated image
tensor_to_pil(resized_tensor)
Source image: (512, 512)
Cropped image: (512, 256)
Cropped tensor: torch.Size([1, 3, 256, 512])
Padded tensor: torch.Size([1, 3, 288, 288])
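The output above follows from a little arithmetic. The 512x256 crop is scaled so its longer side becomes 288, which leaves a 288x144 image, and 72 pixels of edge padding on the top and bottom bring it to 288x288. The plain-Python sketch below traces those numbers. It is a simplified reading of ResizePad that uses integer division; the class itself uses floating-point division and relies on its final resize to absorb any off-by-one differences.

```python
def resize_pad_plan(w: int, h: int, max_sz: int):
    """Return the resized (width, height) and the [left/right, top/bottom] padding."""
    long_side, short_side = max(w, h), min(w, h)
    # Scale so the longer side becomes max_sz; integer math sidesteps float rounding
    new_short = short_side * max_sz // long_side
    new_w, new_h = (max_sz, new_short) if w >= h else (new_short, max_sz)
    # Split the leftover space evenly between the two sides of the shorter dimension
    offset = (max_sz - new_short) // 2
    padding = [0, offset] if new_h < new_w else [offset, 0]
    return (new_w, new_h), padding


print(resize_pad_plan(512, 256, 288))  # ((288, 144), [0, 72])
```

The two-element padding list follows the same [left/right, top/bottom] convention that `TF.pad` uses in the class above.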
Training Dataset Class
Next, we define a custom PyTorch Dataset class that a DataLoader will use to create batches. This class fetches a sample from the dataset at a given index and returns the transformed image and its corresponding label index.
class ImageDataset(Dataset):
    """A PyTorch Dataset class to be used in a DataLoader to create batches.
    Attributes:
        dataset: A list of dictionaries containing 'label' and 'image' keys.
        classes: A list of class names.
        tfms: A torchvision.transforms.Compose object combining all the desired transformations.
    """
    def __init__(self, dataset, classes, tfms):
        self.dataset = dataset
        self.classes = classes
        self.tfms = tfms
    def __len__(self):
        """Returns the total number of samples in this dataset."""
        return len(self.dataset)
    def __getitem__(self, idx):
        """Fetches a sample from the dataset at the given index.
        Args:
            idx: The index to fetch the sample from.
        Returns:
            A tuple of the transformed image and its corresponding label index.
        """
        sample = self.dataset[idx]
        image, label = sample['image'], sample['label']
        return self.tfms(image), label
Image Transforms
We’ll then define the transformations for the training and validation datasets. Note that we only apply data augmentation to the training dataset. Both datasets will have their images resized and padded and the pixel values normalized.
# Define the transformations for training and validation datasets
# Note: Data augmentation is performed only on the training dataset
train_tfms = transforms.Compose([
    ResizePad(max_sz=max(train_sz)),
    trivial_aug,
    transforms.ToTensor(),
    transforms.Normalize(*norm_stats),
])
valid_tfms = transforms.Compose([
    ResizePad(max_sz=max(train_sz)),
    transforms.ToTensor(),
    transforms.Normalize(*norm_stats),
])
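`transforms.ToTensor()` scales pixel values from the 0-255 range to 0-1. `transforms.Normalize()` then subtracts a per-channel mean and divides by a per-channel standard deviation. Displaying a sample later (via `denorm_img_tensor`) requires reversing that step. The sketch below applies the same arithmetic to a single RGB pixel in plain Python. The ImageNet mean and standard deviation are an assumption for illustration; the tutorial’s own `norm_stats` were defined earlier and may differ.

```python
# Assumed per-channel statistics (the ImageNet values many pretrained models use);
# the tutorial's own `norm_stats` were defined earlier and may differ.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def to_unit_range(rgb_uint8):
    """ToTensor-style scaling of 0-255 channel values to the 0-1 range."""
    return tuple(c / 255 for c in rgb_uint8)


def normalize(rgb, mean=MEAN, std=STD):
    """Normalize-style per-channel standardization: (value - mean) / std."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))


def denormalize(rgb, mean=MEAN, std=STD):
    """Invert normalize(), e.g. to display an image tensor again."""
    return tuple(c * s + m for c, m, s in zip(rgb, mean, std))


pixel = to_unit_range((255, 128, 0))
normed = normalize(pixel)
restored = denormalize(normed)
print(normed)
```

Normalizing with the same statistics used during pretraining keeps the fine-tuning inputs in the value range the pretrained weights expect.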
Initialize Datasets
We instantiate the PyTorch datasets using the dataset splits, class names, and defined transformations.
# Instantiate the datasets using the defined transformations
train_dataset = ImageDataset(dataset=train_split, classes=class_names, tfms=train_tfms)
valid_dataset = ImageDataset(dataset=val_split, classes=class_names, tfms=valid_tfms)
# Print the number of samples in the training and validation datasets
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(valid_dataset)}')
Training dataset size: 138361
Validation dataset size: 15374
Inspect Samples
Let’s inspect a sample from the training and validation sets to verify that the data preparation steps are applied correctly.
Inspect training set sample
# Get the label for the first image in the training set
print(f"Label: {train_dataset.classes[train_dataset[0][1]]}")
# Get the first image in the training set
tensor_to_pil(denorm_img_tensor(train_dataset[0][0], *norm_stats))
Label: dislike
Inspect validation set sample
# Get the label for the first image in the validation set
print(f"Label: {valid_dataset.classes[valid_dataset[0][1]]}")
# Get the first image in the validation set
tensor_to_pil(denorm_img_tensor(valid_dataset[0][0], *norm_stats))
Label: no_gesture
We then define the batch size for training and initialize the DataLoaders, which are used to efficiently create batches of data for the model to process during training.
Training Batch Size
Next, we set the batch size for training. This number indicates how many sample images get fed to the model at once. The larger the batch size, the more GPU memory we need. A batch size of 32 should be fine for most modern GPUs. If you get an out-of-memory error, try lowering the batch size to 8, then restart the Jupyter Notebook.
bs = 32
Initialize DataLoaders
We initialize the DataLoaders for the training and validation datasets. We’ll set the number of worker processes for loading data to the number of available CPUs.
# Set the number of worker processes for loading data. This should be the number of CPUs available.
num_workers = multiprocessing.cpu_count()
# Define parameters for DataLoader
data_loader_params = {
    'batch_size': bs,  # Batch size for data loading
    'num_workers': num_workers,  # Number of subprocesses to use for data loading
    'persistent_workers': True,  # If True, the data loader will not shut down the worker processes after a dataset has been consumed once, keeping the worker dataset instances alive.
    'pin_memory': True,  # If True, the data loader will copy Tensors into CUDA pinned memory before returning them. Useful when using GPU.
    'pin_memory_device': device,  # The device to pin memory for when pin_memory is True (commonly the GPU).
}
# Create DataLoader for training data. Data is shuffled for every epoch.
train_dataloader = DataLoader(train_dataset, **data_loader_params, shuffle=True)
# Create DataLoader for validation data. Shuffling is not necessary for validation data.
valid_dataloader = DataLoader(valid_dataset, **data_loader_params)
# Print the number of batches in the training and validation DataLoaders
print(f'Number of batches in train DataLoader: {len(train_dataloader)}')
print(f'Number of batches in validation DataLoader: {len(valid_dataloader)}')
Number of batches in train DataLoader: 4324
Number of batches in validation DataLoader: 481
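The batch counts above follow from the dataset sizes. With the DataLoader’s default `drop_last=False`, the final partial batch is kept, so the count is the number of samples divided by the batch size, rounded up. The small sketch below checks that arithmetic. It models only the counting, not the DataLoader itself.

```python
import math


def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


def last_batch_size(n_samples: int, batch_size: int) -> int:
    """Size of the final, possibly partial, batch when drop_last=False."""
    return n_samples % batch_size or batch_size


print(num_batches(138_361, 32), last_batch_size(138_361, 32))  # 4324 25
print(num_batches(15_374, 32), last_batch_size(15_374, 32))    # 481 14
```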
That completes the data preparation. Now we can finally train our hand gesture recognizer.
Fine-tuning the Model
In this section, we will implement the training code and fine-tune our model. The training process revolves around the concept of an ‘epoch’. Each epoch represents one complete pass through the entire training and validation datasets. To help with this, we will define a function called run_epoch to handle a single training/validation epoch and one called train_loop to execute the main training loop.
Define the Training Loop
Let’s start by defining the run_epoch function. This function runs a single training or validation epoch and calculates the loss and performance metric for the given dataset. The term ‘loss’ refers to a number representing how far our model’s predictions are from the true labels. The goal of training is to minimize this value. We use the autocast context manager to perform mixed-precision training. Mixed-precision training performs some operations in 16-bit precision to speed up training and reduce memory requirements. Modern GPUs tend to have specialized hardware to accelerate these lower-precision operations, and autocast lets us take advantage of it.
# Function to run a single training/validation epoch
def run_epoch(model, dataloader, optimizer, metric, lr_scheduler, device, scaler, is_training):
    # Set model to training mode if 'is_training' is True, else set to evaluation mode
    model.train() if is_training else model.eval()
    # Reset the performance metric
    metric.reset()
    # Initialize the average loss for the current epoch
    epoch_loss = 0
    # Initialize progress bar with total number of batches in the dataloader
    progress_bar = tqdm(total=len(dataloader), desc="Train" if is_training else "Eval")
    # Iterate over data batches
    for batch_id, (inputs, targets) in enumerate(dataloader):
        # Move inputs and targets to the specified device (e.g., GPU)
        inputs, targets = inputs.to(device), targets.to(device)
        # Enables gradient calculation if 'is_training' is True
        with torch.set_grad_enabled(is_training):
            # Automatic Mixed Precision (AMP) context manager for improved performance
            with autocast(device):
                outputs = model(inputs)  # Forward pass
                loss = torch.nn.functional.cross_entropy(outputs, targets)  # Compute loss
        # Update the performance metric
        metric.update(outputs.detach().cpu(), targets.detach().cpu())
        # If in training mode
        if is_training:
            if scaler is not None:  # If using AMP
                # Scale the loss and backward propagation
                scaler.scale(loss).backward()
                # Make an optimizer step
                scaler.step(optimizer)
                # Update the scaler
                scaler.update()
            else:
                # Backward propagation
                loss.backward()
                # Make an optimizer step
                optimizer.step()
            # Clear the gradients
            optimizer.zero_grad()
            # Update learning rate
            lr_scheduler.step()
        loss_item = loss.item()
        epoch_loss += loss_item
        # Update progress bar
        progress_bar.set_postfix(accuracy=metric.compute().item(),
                                 loss=loss_item,
                                 avg_loss=epoch_loss/(batch_id+1),
                                 lr=lr_scheduler.get_last_lr()[0] if is_training else "")
        progress_bar.update()
        # If loss is NaN or infinity, stop training
        # (the epoch number is not available inside this function, so only the batch is reported)
        if math.isnan(loss_item) or math.isinf(loss_item):
            print(f"Loss is NaN or infinite at batch {batch_id}. Stopping training.")
            break
    progress_bar.close()
    return epoch_loss / (batch_id + 1)
This function performs one pass through the given dataset. It first sets the model to training or evaluation mode depending on whether we are training or validating. Then, for each batch of data, it performs a forward pass (calculating the predictions of the model), computes the loss, and then, if in training mode, performs a backward pass to adjust the model’s parameters.
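Mixed precision is the reason run_epoch routes the loss through a GradScaler when AMP is enabled. Half-precision (16-bit) floats cannot represent values much smaller than about 6e-8, so tiny gradients can round to zero. Multiplying the loss (and therefore every gradient) by a large factor before the backward pass keeps those values representable, and the scaler divides the factor back out before the optimizer step. The sketch below demonstrates both the underflow and the fix on a single made-up gradient value, using Python’s `struct` module, whose `'e'` format packs IEEE 754 half-precision floats. The 2**16 factor matches GradScaler’s default initial scale, which the real scaler then adjusts during training.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and convert it back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8        # hypothetical tiny gradient value
scale = 2.0 ** 16  # GradScaler's default initial scale factor

plain = to_fp16(grad)           # too small for float16: rounds to 0.0
scaled = to_fp16(grad * scale)  # about 6.55e-4, comfortably representable
recovered = scaled / scale      # unscale in full precision before the update

print(plain, scaled, recovered)
```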
Next, we’ll define the train_loop function, which executes the main training loop. It iterates over each epoch, runs through the training and validation sets, and saves the best model based on the validation loss.
# Main training loop
def train_loop(model, train_dataloader, valid_dataloader, optimizer, metric, lr_scheduler, device, epochs, use_amp, checkpoint_path):
    # Initialize GradScaler for Automatic Mixed Precision (AMP) if 'use_amp' is True
    scaler = GradScaler() if use_amp else None
    best_loss = float('inf')
    # Iterate over each epoch
    for epoch in tqdm(range(epochs), desc="Epochs"):
        # Run training epoch and compute training loss
        train_loss = run_epoch(model, train_dataloader, optimizer, metric, lr_scheduler, device, scaler, is_training=True)
        with torch.no_grad():
            # Run validation epoch and compute validation loss
            valid_loss = run_epoch(model, valid_dataloader, None, metric, None, device, scaler, is_training=False)
        # If current validation loss is lower than the best one so far, save model and update best loss
        if valid_loss < best_loss:
            best_loss = valid_loss
            metric_value = metric.compute().item()
            torch.save(model.state_dict(), checkpoint_path)
            training_metadata = {
                'epoch': epoch,
                'train_loss': train_loss,
                'valid_loss': valid_loss,
                'metric_value': metric_value,
                'learning_rate': lr_scheduler.get_last_lr()[0],
                'model_architecture': model.name
            }
            # Save best_loss and metric_value in a JSON file
            with open(Path(checkpoint_path.parent/'training_metadata.json'), 'w') as f:
                json.dump(training_metadata, f)
        # If loss is NaN or infinity, stop training
        if any(math.isnan(loss) or math.isinf(loss) for loss in [train_loss, valid_loss]):
            print(f"Loss is NaN or infinite at epoch {epoch}. Stopping training.")
            break
    # If using AMP, clean up the unused memory in GPU
    if use_amp:
        torch.cuda.empty_cache()
This function coordinates the training process. It runs the previously defined run_epoch function for each epoch, calculating the training and validation losses. Whenever the validation loss improves on the best value so far, it saves the model state as a checkpoint, along with a JSON file containing the current epoch, loss values, metric value, learning rate, and model name. If the run_epoch function returns NaN (Not a Number) or infinite loss values, it halts the training process, since this typically indicates an issue with the training.
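The checkpointing rule in train_loop amounts to keeping a running minimum of the validation loss. The plain-Python sketch below replays that logic over a list of per-epoch validation losses and reports the epochs at which a checkpoint would be written. It is a simplification. The real function also checks the training loss for NaN or infinity and writes files instead of recording epoch numbers.

```python
import math


def checkpoint_epochs(valid_losses):
    """Return the epochs that would trigger a checkpoint save and the best loss seen."""
    best_loss = float("inf")
    saved = []
    for epoch, loss in enumerate(valid_losses):
        # Save only when the validation loss improves on the best value so far
        if loss < best_loss:
            best_loss = loss
            saved.append(epoch)
        # Stop early on a non-finite loss, as train_loop does
        if math.isnan(loss) or math.isinf(loss):
            break
    return saved, best_loss


# Average validation losses from the training run shown later in this tutorial
print(checkpoint_epochs([0.081, 0.0354, 0.0173]))  # ([0, 1, 2], 0.0173)
```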
Set the Model Checkpoint Path
Before we proceed with training, let’s generate a timestamp for the training session and create a directory to store the checkpoints. The training loop saves the model state to this directory whenever the validation loss improves. That enables us to load the model checkpoint later to resume training, export the model to a different format, or perform inference directly.
# Generate timestamp for the training session (Year-Month-Day_Hour-Minute-Second)
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Define the checkpoint directory for this training session
checkpoint_dir = Path(project_dir/f"{timestamp}")
# Create the checkpoint directory if it does not already exist
checkpoint_dir.mkdir(parents=True, exist_ok=True)
# The model checkpoint path
checkpoint_path = checkpoint_dir/f"{model.name}.pth"
print(checkpoint_path)
pytorch-timm-image-classifier/2023-05-22_16-35-03/resnet18d.pth
Configure the Training Parameters
Now, let’s configure the parameters for training. We’ll define the learning rate, number of training epochs, optimizer, learning rate scheduler, and performance metric, and check whether a CUDA-capable GPU is available. The learning rate determines how much we adjust the model in response to the estimated error each time the weights are updated. Choosing an appropriate learning rate is essential for good model performance.
We’re using AdamW as our optimizer, which includes weight decay for regularization, and the OneCycleLR scheduler to adjust the learning rate during training. The one-cycle learning rate policy starts with a low learning rate, gradually increases it to a maximum, then decreases it again over the course of the entire training run (here, epochs * len(train_dataloader) steps), aiming to converge faster and yield better performance.
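The shape of the one-cycle schedule is easiest to see numerically. The plain-Python sketch below is intended to mirror OneCycleLR’s default behavior (pct_start=0.3, div_factor=25, final_div_factor=1e4, cosine annealing, two phases). Treat it as an illustration, not a drop-in replacement for the scheduler. With max_lr=1e-3 and the 3 x 4,324 training steps used here, the learning rate warms up from 4e-5 to 1e-3 over roughly the first 30% of steps and then decays to about 4e-9. That end value is consistent with the lr=4.03e-9 in the final training readout below.

```python
import math


def one_cycle_lr(step: int, total_steps: int, max_lr: float,
                 pct_start: float = 0.3, div_factor: float = 25.0,
                 final_div_factor: float = 1e4) -> float:
    """Learning rate at `step` for a two-phase, cosine-annealed one-cycle schedule."""
    initial_lr = max_lr / div_factor          # starting learning rate
    min_lr = initial_lr / final_div_factor    # learning rate at the very end
    warmup_end = pct_start * total_steps - 1  # step at which the peak is reached

    def cos_anneal(start: float, end: float, pct: float) -> float:
        # Cosine interpolation from `start` (pct=0) to `end` (pct=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (total_steps - 1 - warmup_end))


total = 3 * 4324  # epochs * batches per epoch, as in this tutorial
lrs = [one_cycle_lr(s, total, max_lr=1e-3) for s in range(total)]
print(f"start={lrs[0]:.2e}, peak={max(lrs):.2e}, end={lrs[-1]:.2e}")
```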
# Learning rate for the model
lr = 1e-3
# Number of training epochs
epochs = 3
# AdamW optimizer; includes weight decay for regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-5)
# Learning rate scheduler; adjusts the learning rate during training
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                   max_lr=lr,
                                                   total_steps=epochs*len(train_dataloader))
# Performance metric: Multiclass Accuracy
metric = MulticlassAccuracy()
# Check for CUDA-capable GPU availability
use_amp = torch.cuda.is_available()
We’ll use Multiclass Accuracy for our performance metric as this is a multiclass classification problem where each image falls into one of many classes.
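Accuracy here is simply the share of images whose highest-scoring class matches the label. Assuming the metric’s default micro averaging (every sample counts equally), it reduces to the plain-Python computation below. The library version works on tensors and accumulates counts across batches through `update()` and `compute()`.

```python
def micro_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class index equals the target index."""
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == target
        for row, target in zip(logits, targets)
    )
    return correct / len(targets)


# Three made-up samples with scores for three classes
logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 0.1], [0.0, 0.0, 5.0]]
targets = [0, 2, 2]
print(micro_accuracy(logits, targets))  # 0.6666666666666666
```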
Train the Model
Finally, we can train the model using the train_loop function. Training time will depend on the available hardware. Feel free to take a break if the progress bar indicates it will take a while.
Training usually takes around 1 hour and 20 minutes on the free GPU tier of Google Colab.
train_loop(model, train_dataloader, valid_dataloader, optimizer, metric, lr_scheduler, device, epochs, use_amp, checkpoint_path)
Epochs: 100%|█████████| 3/3 [11:15<00:00, 224.96s/it]
Train: 100%|██████████| 4324/4324 [03:29<00:00, 21.75it/s, accuracy=0.894, avg_loss=0.374, loss=0.0984, lr=0.000994]
Eval: 100%|██████████| 481/481 [00:17<00:00, 50.42it/s, accuracy=0.975, avg_loss=0.081, loss=0.214, lr=]
Train: 100%|██████████| 4324/4324 [03:28<00:00, 22.39it/s, accuracy=0.968, avg_loss=0.105, loss=0.0717, lr=0.000462]
Eval: 100%|██████████| 481/481 [00:16<00:00, 55.14it/s, accuracy=0.988, avg_loss=0.0354, loss=0.02, lr=]
Train: 100%|██████████| 4324/4324 [03:28<00:00, 21.94it/s, accuracy=0.99, avg_loss=0.0315, loss=0.00148, lr=4.03e-9]
Eval: 100%|██████████| 481/481 [00:16<00:00, 53.87it/s, accuracy=0.995, avg_loss=0.0173, loss=0.000331, lr=]
At last, we have our hand gesture recognizer. The readout for the final validation run indicates the model achieved approximately 99.5% accuracy, meaning it misclassified fewer than 100 of the 15,374 samples in the validation set (0.5% of 15,374 is roughly 77). To wrap up the tutorial, we’ll test our fine-tuned model by performing inference on individual images.
Making Predictions with the Model
In this final part of the tutorial, you will learn how to make predictions with the fine-tuned model on individual images, allowing us to see the model in action and assess its performance. Understanding how to apply trained models is crucial for implementing them in real-world applications.
Let’s start by setting the minimum input dimension for inference. We’ll use the same size as the input images we used during training but allow non-square input.
# Set the minimum input dimension for inference
infer_sz = max(train_sz)
Next, we will randomly select an image from our dataset and resize it to the inference size. The resize_img function will scale the image so the smallest dimension is the specified inference size while maintaining the original aspect ratio.
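The target size that resize_img aims for can be computed directly. The sketch below is a hypothetical helper, not resize_img itself (whose exact rounding may differ). It scales the shorter side to the inference size and the longer side by the same factor. Unlike ResizePad, it adds no padding, so the result is generally not square.

```python
def shortest_side_size(w: int, h: int, target: int) -> tuple[int, int]:
    """Hypothetical stand-in for resize_img's size math: shorter side -> target, aspect ratio kept."""
    if w <= h:
        return target, round(h * target / w)
    return round(w * target / h), target


print(shortest_side_size(512, 512, 288))  # (288, 288)
print(shortest_side_size(400, 600, 288))  # (288, 432)
```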
# Choose a random item from the dataset
item = random.choice(dataset)
# Get the image and resize it
sample_img = item['image']
inp_img = resize_img(sample_img.copy(), infer_sz)
We then convert the image to a normalized tensor using the pil_to_tensor function and move it to the device where our model resides (CPU or GPU).
# Convert the image to a normalized tensor and move it to the device
img_tensor = pil_to_tensor(inp_img, *norm_stats).to(device=device)
With our image prepared, we can now use our model to make a prediction. The following code block runs our model in a no-gradient context using torch.no_grad(). That informs PyTorch that we do not need to keep track of gradients in this operation, saving memory.
# Make a prediction with the model
with torch.no_grad():
    pred = model(img_tensor)
After obtaining the raw prediction, we apply the Softmax function to convert these values into probabilities that sum up to 1.
# Scale the model predictions to add up to 1
pred_scores = torch.softmax(pred, dim=1)
Then, we retrieve the highest confidence score and its corresponding class index. The class index is converted into the actual class name using the train_dataset.classes lookup table.
# Get the highest confidence score
confidence_score = pred_scores.max()
# Get the class index with the highest confidence score and convert it to the class name
pred_class = train_dataset.classes[torch.argmax(pred_scores)]
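Here is what `torch.softmax` and the max/argmax lookup compute for a single image’s raw scores, written in plain Python. The scores and the three-class subset are made up for illustration; the real model outputs one score per class in the dataset.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in math.exp
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, classes):
    """Return the most likely class name and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return classes[best], probs[best]


# Made-up raw scores for three of the dataset's classes
print(top_prediction([1.0, 3.0, 0.5], ["dislike", "like", "stop"]))
```

Subtracting the largest score before exponentiating does not change the probabilities, but it keeps `math.exp` from overflowing on large scores.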
We then format the prediction data, including the target class, the predicted class, and the confidence score of the prediction, as a Pandas Series and print it.
# Store the prediction data in a Pandas Series for easy formatting
pred_data = pd.Series({
    "Target Class:": class_names[item['label']],
    "Predicted Class:": pred_class,
    "Confidence Score:": f"{confidence_score*100:.2f}%"
})
# Print the prediction data
print(pred_data.to_string(header=False))
Target Class: like
Predicted Class: like
Confidence Score: 100.00%
Finally, we display the sample image for visual verification.
# Display the image
sample_img
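For this sample, the model’s confidence rounds to 100%. Because the image was drawn from the full dataset, which includes the training images, the model may have seen it during training. It will likely be less sure about images it has not seen before.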
Testing the Model on New Data
Let’s try an image from the free stock photo site, Pexels.
import urllib.request
test_img_url = "https://huggingface.co/datasets/cj-mills/pexel-hand-gesture-test-images/resolve/main/pexels-2769554-man-doing-rock-and-roll-sign.jpg"
test_img_path = Path("./pexels-man-doing-rock-and-roll-sign.jpg")
if test_img_path.is_file():
    print("Image already exists.")
else:
    urllib.request.urlretrieve(test_img_url, test_img_path)
    print("Image downloaded.")
sample_img = Image.open(test_img_path)
sample_img
Image already exists.
This image is a bit tricky. The man in the photo makes a rock gesture with his left hand, but his right hand is also visible and strongly resembles samples from the no_gesture class. Let’s see what the model predicts using the same approach as above.
target_cls = "rock"
# Set the minimum input dimension for inference
infer_sz = max(train_sz)
inp_img = resize_img(sample_img.copy(), infer_sz)
# Convert the image to a normalized tensor and move it to the device
img_tensor = pil_to_tensor(inp_img, *norm_stats).to(device=device)
# Make a prediction with the model
with torch.no_grad():
    pred = model(img_tensor)
# Scale the model predictions to add up to 1
pred_scores = torch.softmax(pred, dim=1)
# Get the highest confidence score
confidence_score = pred_scores.max()
# Get the class index with the highest confidence score and convert it to the class name
pred_class = train_dataset.classes[torch.argmax(pred_scores)]
# Store the prediction data in a Pandas Series for easy formatting
pred_data = pd.Series({
    "Input Size:": inp_img.size,
    "Target Class:": target_cls,
    "Predicted Class:": pred_class,
    "Confidence Score:": f"{confidence_score*100:.2f}%"
})
# Print the prediction data
print(pred_data.to_string(header=False))
# Display the image
sample_img
Input Size: (288, 416)
Target Class: rock
Predicted Class: rock
Confidence Score: 99.70%
Even though this image is a different shape than the training data and has an idle hand, the model confidently predicts rock as the most likely gesture class.
Saving the Class Labels
Let’s save the dataset class labels in a dedicated JSON file so we don’t need to load the whole dataset to make predictions with the model in the future. I’ll cover how to load the model checkpoint we saved earlier and use it for inference in a future tutorial.
# Save class labels
class_labels = {"classes": list(train_dataset.classes)}
# Set file path
class_labels_path = checkpoint_dir/f"{dataset_name}-classes.json"
# Save class labels in JSON format
with open(class_labels_path, "w") as write_file:
    json.dump(class_labels, write_file)
print(class_labels_path)
pytorch-timm-image-classifier/2023-05-22_16-35-03/hagrid-classification-512p-no-gesture-150k-zip-classes.json
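Loading the labels back later requires only the `json` module. The sketch below writes a small, made-up subset of class names in the same `{"classes": [...]}` layout and reads them back. The file name is a placeholder; a real script would point at the JSON file saved in the checkpoint directory.

```python
import json
import tempfile
from pathlib import Path


def save_class_labels(classes, path: Path) -> None:
    """Write class names in a {"classes": [...]} layout."""
    path.write_text(json.dumps({"classes": list(classes)}))


def load_class_labels(path: Path) -> list:
    """Read the class names back without loading the dataset."""
    return json.loads(path.read_text())["classes"]


# Round trip through a temporary directory; the class names are a small sample
with tempfile.TemporaryDirectory() as tmp_dir:
    labels_path = Path(tmp_dir) / "example-classes.json"
    save_class_labels(["dislike", "like", "no_gesture", "rock", "stop"], labels_path)
    round_trip = load_class_labels(labels_path)

print(round_trip)
```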
We now have a functioning hand-gesture recognizer and know how to make predictions with it on individual images. Before we wrap up this tutorial, let’s check out the in-browser demo I mentioned at the beginning of the post.
- Don’t forget to download the model checkpoint and class labels from the Colab Environment’s file browser. (tutorial link)
- Once you finish training and download the files, turn off hardware acceleration for the Colab Notebook to save GPU time. (tutorial link)
Exploring the In-Browser Demo
You’ve gotten your hands dirty with the code. Now let’s see our fine-tuned model in action! I’ve set up an online demo that allows you to interact with a hand gesture recognizer trained with this tutorial’s code in your web browser. No downloads or installations are required.
The demo includes sample images that you can use to test the model. Try these images first to see how the model interprets different hand gestures. Once you’re ready, you can switch on your webcam to provide live input to the model.
Online demos are a great way to see and share the fruits of your labor and explore ways to apply your hand gesture recognizer in real-world scenarios.
I invite you to share any interesting results or experiences with the demo in the comments below. Whether it’s a tricky input image the model handles or a surprising failure case, I’d love to hear about it!
Check out the demo below, and have fun exploring!
Conclusion
Congratulations on completing this tutorial on fine-tuning image classifiers with PyTorch and the timm library! You’ve taken significant strides in your machine learning journey by creating a practical hand gesture recognizer.
Throughout this tutorial, we’ve covered many topics, including setting up your Python environment, importing necessary dependencies, project initialization, dataset loading and exploration, model selection, data preparation, and model fine-tuning. Finally, we made predictions with our fine-tuned model on individual images and tested the model with an interactive, in-browser demo.
This hands-on tutorial underscored the practical applications of fine-tuning image classification models, especially when working with limited data and computational resources. The hand gesture recognizer you’ve built has many real-world applications, and you now have a solid foundation to tackle other image classification tasks.
If you’re intrigued by the underlying concepts leveraged in this tutorial and wish to deepen your understanding, I recommend fast.ai’s Practical Deep Learning for Coders course. By the end, you’ll thoroughly understand the model and training code and have the know-how to implement them from scratch.
While our tutorial concludes here, your journey in deep learning is far from over. In the upcoming tutorials, we’ll explore topics such as incorporating preprocessing and post-processing steps into the model, exporting the model to different formats for deployment, using the fine-tuned model to identify flawed training samples in our dataset, and building interactive in-browser demo projects similar to the one featured in this tutorial.
Once again, congratulations on your achievement, and keep learning!
If you found this guide helpful, consider sharing it with others and exploring some of my other tutorials linked below.
Recommended Tutorials
- Exporting timm Image Classifiers from Pytorch to ONNX: Learn how to export timm image classification models from PyTorch to ONNX and perform inference using ONNX Runtime.
- Training YOLOX Models for Real-Time Object Detection in Pytorch: Learn how to train YOLOX models for real-time object detection in PyTorch by creating a hand gesture detection model.