Exploring the Impact of Different Image Augmentations on Hand Gesture Recognition
Introduction
Computer vision models can learn to recognize complex patterns and objects in images, but they often struggle to generalize to new inputs. One solution is data augmentation, a technique that expands the size and diversity of a dataset by creating variations of existing images. The model never sees the exact same image twice, which helps it learn general features rather than memorize specific examples.
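To make the idea concrete, here is a minimal sketch (not part of the experiments in this post) that uses stock torchvision transforms to generate several distinct variations of a single image; the file path is a placeholder:
from PIL import Image
from torchvision import transforms

# A generic augmentation pipeline: each call produces a slightly different image
augment = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomRotation(degrees=10),
])

img = Image.open("sample.jpg")  # hypothetical image path
variants = [augment(img) for _ in range(4)]  # four different training samples from one image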
In this post, I will explore how different augmentations impact model performance using a hand gesture dataset. Hand gesture recognition has potential applications in human-computer interaction, robotics, and sign language interpretation. It is also challenging due to the variability of hand poses, shapes, and lighting conditions, which makes it a good candidate for image augmentation. The results of these experiments should provide insights into best practices for training deep learning models to handle diverse input.
You can find links to view the training code and run it on Google Colab and Kaggle below.
Jupyter Notebook | Colab | Kaggle |
---|---|---|
GitHub Repository | Open In Colab | Kaggle |
Environment Setup
The experiment code depends on the miniai library developed for the Practical Deep Learning 2022 Part 2 course. First, we’ll check if we have the miniai package installed. If not, we clone the course repository from GitHub and install the miniai package.
Install dependencies
# %%capture
# try:
#     import miniai
# except ImportError:
#     !git clone https://github.com/fastai/course22p2.git
#     !pip install -e ./course22p2
Next, we need to install several pip packages.
# %%capture
# !pip install torch torchvision torchaudio torcheval
# !pip install numpy pandas pillow wandb
# !pip install cjm_pil_utils cjm_kaggle_utils cjm_pytorch_utils cjm_pandas_utils
After installing the required dependencies, we can import them into our Jupyter notebook.
Import dependencies
from pathlib import Path
import hashlib
import json
import random
import multiprocessing
import math
import os
from functools import partial
from copy import copy
from glob import glob
from PIL import Image
import numpy as np
import timm
from tqdm.auto import tqdm
# Import pandas module for data manipulation
import pandas as pd
# Set options for Pandas DataFrame display
pd.set_option('max_colwidth', None)         # Do not truncate the contents of cells in the DataFrame
pd.set_option('display.max_rows', None)     # Display all rows in the DataFrame
pd.set_option('display.max_columns', None)  # Display all columns in the DataFrame
# Import PyTorch dependencies
import torch
import torch.nn as nn
from torchvision import transforms, models
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torcheval.tools import get_module_summary
from torcheval.metrics import MulticlassAccuracy
# Import miniai dependencies
import fastcore.all as fc
from miniai.datasets import DataLoaders
from miniai.conv import def_device
from miniai.learner import *
from miniai.activations import Hook, ActivationStats
from miniai.accel import MixedPrecision, AccelerateCB
from miniai.init import BatchTransformCB, GeneralRelu, init_weights, set_seed, conv
from miniai.sgd import BatchSchedCB
from miniai.resnet import ResBlock
# Import utility functions
from cjm_kaggle_utils.core import save_kaggle_creds, dl_kaggle
from cjm_pil_utils.core import resize_img, get_img_files, stack_imgs
from cjm_pytorch_utils.core import pil_to_tensor, tensor_to_pil, tensor_stats_df, get_torch_device, denorm_img_tensor
from cjm_pandas_utils.core import markdown_to_pandas
Dataset
The dataset contains images from HaGRID (HAnd Gesture Recognition Image Dataset) modified for image classification. It has 125,912 samples for 18 distinct hand gestures and 27,823 images containing no gestures. The dataset includes a wide range of people, environmental settings, illumination, and camera placement. You can access the dataset on Kaggle at the link below.
Reference Images
(Reference image table: one sample image per gesture class, covering call, dislike, fist, four, like, mute, ok, one, palm, peace, peace_inverted, rock, stop, stop_inverted, three, three2, two_up, and two_up_inverted.)
First, we set our Kaggle credentials by providing our Kaggle username and API token. We save the credentials to a file, which allows us to authenticate and download the dataset from Kaggle.
Set Kaggle username and API token
= ""
username = "" key
Save Kaggle credentials to file
save_kaggle_creds(username, key, overwrite=False)
Empty username.
Next, we define the directories to store the archive files and datasets and create the directories if they don’t already exist.
Set directory paths
# Define path to store datasets
dataset_dir = Path("/mnt/980_1TB_2/Datasets/")

# Create the dataset directory if it does not exist
dataset_dir.mkdir(parents=True, exist_ok=True)
print(f"Dataset Directory: {dataset_dir}")

# Define path to store archive files
archive_dir = dataset_dir/'../Archive'

# Create the archive directory if it does not exist
archive_dir.mkdir(parents=True, exist_ok=True)
print(f"Archive Directory: {archive_dir}")
Dataset Directory: /mnt/980_1TB_2/Datasets
Archive Directory: /mnt/980_1TB_2/Datasets/../Archive
We use the Kaggle dataset name to construct the save paths for the archive file and the extracted dataset.
Set Kaggle dataset
# Set the name of the dataset
dataset_name = 'hagrid-classification-512p-no-gesture-150k'

# Construct the Kaggle dataset name by combining the username and dataset name
kaggle_dataset = f'innominate817/{dataset_name}'

# Create the path to the zip file that contains the dataset
archive_path = Path(f'{archive_dir}/{dataset_name}.zip')
print(f"Archive Path: {archive_path}")

# Create the path to the directory where the dataset will be extracted
dataset_path = Path(f'{dataset_dir}/{dataset_name}')
print(f"Dataset Path: {dataset_path}")
Archive Path: /mnt/980_1TB_2/Datasets/../Archive/hagrid-classification-512p-no-gesture-150k.zip
Dataset Path: /mnt/980_1TB_2/Datasets/hagrid-classification-512p-no-gesture-150k
Finally, we download the Kaggle dataset and extract it to the specified directory.
Download Kaggle dataset
dl_kaggle(kaggle_dataset, archive_path, dataset_path)
Dataset already downloaded
The dataset organizes the images for different gesture types into separate subfolders.
Get image classes
# Get all subfolders within the dataset_path
folders = [folder for folder in dataset_path.glob('*/') if folder.is_dir()]

# Store the folder names
class_names = [f.name for f in folders]
# Print the list of class names
pd.DataFrame(class_names)
0 | |
---|---|
0 | call |
1 | dislike |
2 | fist |
3 | four |
4 | like |
5 | mute |
6 | no_gesture |
7 | ok |
8 | one |
9 | palm |
10 | peace |
11 | peace_inverted |
12 | rock |
13 | stop |
14 | stop_inverted |
15 | three |
16 | three2 |
17 | two_up |
18 | two_up_inverted |
After obtaining the class names for the dataset, we need to get the paths for all the images in the dataset.
Get image paths
# Get a list of all image file paths from the given folders
img_paths = [get_img_files(folder) for folder in folders]

# Flatten the list of lists into a single list
img_paths = [path for class_paths in img_paths for path in class_paths]
# Get the number of image file paths
len(img_paths)
153735
Methodology
I tracked the experiments using Weights & Biases and initialized every training run within an experiment with the same random seed to keep the results consistent and reproducible. I then repeated the full set of runs with a few different seed values. You can view the Weights & Biases projects at the links below.
- miniai_data_augmentation_seed-1
- miniai_data_augmentation_seed-42
- miniai_data_augmentation_seed-100
- miniai_data_augmentation_seed-1234
You can view the projects without an account, but you will need one to create new projects.
Log in to Weights & Biases
import wandb
wandb.login()
True
Set random number seed
# seed = 1
# seed = 42
# seed = 100
seed = 1234
set_seed(seed)
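The set_seed helper from miniai handles the seeding. As a rough sketch of what such a helper typically covers (an assumption for illustration, not miniai's exact implementation), it seeds the Python, NumPy, and PyTorch random number generators:
def set_seed_sketch(seed: int, deterministic: bool = False):
    # Seed the Python, NumPy, and PyTorch RNGs (torch.manual_seed also covers the CUDA RNGs)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if deterministic:
        # Trade some speed for reproducible cuDNN kernels
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False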
Set device and data type
device = get_torch_device()
dtype = torch.float32
device, dtype
('cuda', torch.float32)
Image Augmentations
For this experiment, I tested five different image augmentations. We can visualize each one using a sample image from the dataset.
Display sample image
# Select a random image path from the list of image paths
img_path = random.choice(img_paths)

# Print the name of the class of the image, which is the name of the parent folder
print(f"Class: {img_path.parent.name}")

# Open the image using the path
sample_img = Image.open(img_path)
# Display the image
sample_img
Class: stop
The first augmentation copies a square section from a random location on the image and pastes it in a random spot. I use random patch sizes ranging from 10-20% of the image dimensions to reduce the chance of obscuring the hand gestures.
Random Square Copy Transform
def rand_square_copy(img_tensor:torch.Tensor, pct:float):
    """
    Copy data from a randomly selected square region to another randomly selected square region of an image tensor.

    Args:
        img_tensor (torch.Tensor): The input image tensor.
        pct (float): The percentage of the image tensor's size to be used as the side length of the square regions.

    Returns:
        torch.Tensor: The modified input image tensor.
    """
    # Get the shape of the tensor
    h, w = img_tensor.shape[-2:]

    # Calculate the size of the square
    szx = int(pct * h)
    szy = int(pct * w)

    # Calculate the top-left coordinate of the first square
    stx1 = int(random.random() * (1 - pct) * h)
    sty1 = int(random.random() * (1 - pct) * w)

    # Calculate the top-left coordinate of the second square
    stx2 = int(random.random() * (1 - pct) * h)
    sty2 = int(random.random() * (1 - pct) * w)

    # Copy the data from the second square to the first square
    img_tensor[..., stx1:stx1 + szx, sty1:sty1 + szy] = img_tensor[..., stx2:stx2 + szx, sty2:sty2 + szy]

    # Return the modified input img_tensor
    return img_tensor
class RandSquareCopy(nn.Module):
    def __init__(self, pct=0.2, min_num=0, max_num=4):
        """
        A PyTorch module that applies the `rand_square_copy` function to an input tensor multiple times.

        Args:
            pct (float, optional): The percentage of the tensor's size to be used as the side length of
                                   the square regions.
            min_num (int, optional): The minimum number of times to apply the `rand_square_copy` function.
            max_num (int, optional): The maximum number of times to apply the `rand_square_copy` function.
        """
        super().__init__()
        self.pct = pct
        self.min_num = min_num
        self.max_num = max_num

    def forward(self, x):
        # Apply the `rand_square_copy` function to the input tensor multiple times
        for i in range(random.randint(self.min_num, self.max_num)):
            x = rand_square_copy(x, max(0.1, random.random()*self.pct))
        return x
Test the transform
rand_square_copy_tf = RandSquareCopy(pct=0.2, min_num=1, max_num=1)

# Convert the PIL image to a tensor
img_tensor = transforms.PILToTensor()(sample_img)
print(img_tensor.shape)
# Apply the `rand_square_copy_tf` transform to the image tensor
tensor_to_pil(rand_square_copy_tf(img_tensor))
torch.Size([3, 512, 512])
The second augmentation copies random pixels and pastes them at other random locations in the image. I use random amounts ranging from 0.25-1.5% of the pixels.
Random Pixel Copy Transform
def rand_pixel_copy(img_tensor:torch.Tensor, pct:float=0.2):
    """
    Copy data from a randomly selected set of pixels to another randomly selected set of pixels of an image tensor.

    Args:
        img_tensor (torch.Tensor): The input image tensor.
        pct (float, optional): The percentage of the total number of pixels to be selected as the source and
                               target sets of pixels.

    Returns:
        torch.Tensor: The modified input image tensor.
    """
    src_dim = img_tensor.dim()

    img_tensor = img_tensor.unsqueeze(0) if src_dim == 3 else img_tensor

    # Get the shape of the img_tensor
    b, c, h, w = img_tensor.shape

    # Calculate the number of pixels to be selected
    num_pixels = int(img_tensor[-1:].numel() * pct)

    # Select the source pixel indices
    source_indices = torch.LongTensor(num_pixels, 2).random_(0, h * w)
    source_indices[:, 0] = source_indices[:, 0] // w
    source_indices[:, 1] = source_indices[:, 1] % w

    # Select the target pixel indices
    target_indices = torch.LongTensor(num_pixels, 2).random_(0, h * w)
    target_indices[:, 0] = target_indices[:, 0] // w
    target_indices[:, 1] = target_indices[:, 1] % w

    # Get the indices of the channels
    c_indices = torch.arange(c).repeat(num_pixels, 1).t()

    # Copy the pixels
    source_pixels = img_tensor[:, c_indices, source_indices[:, 0], source_indices[:, 1]]
    img_tensor[:, c_indices, target_indices[:, 0], target_indices[:, 1]] = source_pixels

    return img_tensor.squeeze(0) if src_dim == 3 else img_tensor
class RandPixelCopy(nn.Module):
    def __init__(self, pct=0.1):
        """
        A PyTorch module that applies the `rand_pixel_copy` function to an input tensor.

        Args:
            pct (float, optional): The maximum percentage of the tensor's pixels to be copied.
        """
        super().__init__()
        self.pct = pct

    def forward(self, x):
        # Apply the `rand_pixel_copy` function to the input tensor
        return rand_pixel_copy(x, max(0.0025, random.random() * self.pct))
Test the transform
rand_pixel_copy_tf = RandPixelCopy(pct=0.015)

# Convert the PIL image to a tensor
img_tensor = transforms.PILToTensor()(sample_img)
print(img_tensor.shape)
# Apply the `rand_pixel_copy_tf` transform to the image tensor
tensor_to_pil(rand_pixel_copy_tf(img_tensor))
torch.Size([3, 512, 512])
The third combines the first two approaches by copying random squares and then applying the pixel copy transform on those squares. I use the same patch size range as the first augmentation but copy up to 100% of the pixels.
Random Square Pixel Copy Transform
def rand_square_pixel_copy(img_tensor:torch.Tensor, square_pct:float, pixel_pct:float):
    """
    Copy data from a randomly selected set of pixels of a randomly selected square region to another randomly
    selected square region of a tensor.

    Args:
        img_tensor (torch.Tensor): The input tensor.
        square_pct (float): The percentage of the tensor's size to be used as the side length of the square regions.
        pixel_pct (float): The percentage of the pixels of the source square region to be copied.

    Returns:
        torch.Tensor: The modified input tensor.
    """
    # Get the shape of the tensor
    h, w = img_tensor.shape[-2:]

    # Calculate the size of the square
    szx = int(square_pct * h)
    szy = int(square_pct * w)

    # Calculate the top-left coordinate of the first square
    stx1 = int(random.random() * (1 - square_pct) * h)
    sty1 = int(random.random() * (1 - square_pct) * w)

    # Calculate the top-left coordinate of the second square
    stx2 = int(random.random() * (1 - square_pct) * h)
    sty2 = int(random.random() * (1 - square_pct) * w)

    # Copy the data from the second square to the first square
    source_pixels = rand_pixel_copy(img_tensor[..., stx2:stx2 + szx, sty2:sty2 + szy].clone(), pixel_pct)
    img_tensor[..., stx1:stx1 + szx, sty1:sty1 + szy] = source_pixels

    # Return the modified input tensor
    return img_tensor
class RandSquarePixelCopy(nn.Module):
    def __init__(self, square_pct=0.2, pixel_pct=1.0, min_num=0, max_num=4):
        """
        A PyTorch module that applies the `rand_square_pixel_copy` function to an input tensor.

        Args:
            square_pct (float, optional): The percentage of the tensor's size to be used as the side length of
                                          the square regions.
            pixel_pct (float, optional): The maximum percentage of the pixels of the source square region to be copied.
            min_num (int, optional): The minimum number of times the `rand_square_pixel_copy` function is applied to
                                     the input tensor.
            max_num (int, optional): The maximum number of times the `rand_square_pixel_copy` function is applied to
                                     the input tensor.
        """
        super().__init__()
        self.square_pct = square_pct
        self.pixel_pct = pixel_pct
        self.min_num = min_num
        self.max_num = max_num

    def forward(self, x):
        # Apply the `rand_square_pixel_copy` function to the input tensor multiple times
        for i in range(random.randint(self.min_num, self.max_num)):
            x = rand_square_pixel_copy(x,
                                       max(0.1, random.random()*self.square_pct),
                                       max(0.0025, random.random()*self.pixel_pct))
        return x
Test the transform
rand_square_pixel_copy_tf = RandSquarePixelCopy(square_pct=0.2, pixel_pct=1.0, min_num=1, max_num=1)

img_tensor = transforms.PILToTensor()(sample_img)
print(img_tensor.shape)
tensor_to_pil(rand_square_pixel_copy_tf(img_tensor))
torch.Size([3, 512, 512])
I also tested the TrivialAugment augmentation from the paper linked below. TrivialAugment applies a single randomly selected transform to each image. I stuck with the default parameters.
Trivial Augmentation
# Create a TrivialAugmentWide object
trivial_aug = transforms.TrivialAugmentWide()

# Convert the image to a tensor
img_tensor = transforms.PILToTensor()(sample_img)[None]
print(img_tensor.shape)
# Apply the TrivialAugmentWide to the tensor
tensor_to_pil(trivial_aug(img_tensor))
torch.Size([1, 3, 512, 512])
Finally, inspired by the TrivialAugment transform, I created a custom trivial augmentation that randomly applies the random square copy, random rand pixel copy, random square pixel copy, or no transform.
Custom Trivial Augmentation Transform
class CustomTrivAug(nn.Module):
    def __init__(self, aug_list):
        """
        A PyTorch module that applies a random image augmentation.

        Args:
            aug_list (list): List of functions that implement the desired augmentations.
        """
        super().__init__()
        self.aug_list = aug_list

    def forward(self, x):
        return random.choice(self.aug_list)(x)
Test the transform
aug_list = [
    # Identity transformation (no augmentation applied)
    lambda x: x,
    # Random square copy transformation
    rand_square_copy_tf,
    # Random pixel copy transformation
    rand_pixel_copy_tf,
    # Random square pixel copy transformation
    rand_square_pixel_copy_tf
]

# Create an instance of the CustomTrivAug module
custom_triv_aug_tf = CustomTrivAug(aug_list=aug_list)

# Convert a sample image to a tensor
img_tensor = transforms.PILToTensor()(sample_img)
print(img_tensor.shape)
# Apply a random augmentation and convert the tensor back to an image
tensor_to_pil(custom_triv_aug_tf(img_tensor))
torch.Size([3, 512, 512])
Model
I went with the pretrained ResNet 18-D model from the timm library for its balance of accuracy and speed.
List available ResNet18 models
pd.DataFrame(timm.list_models('resnet18*', pretrained=True))
0 | |
---|---|
0 | resnet18 |
1 | resnet18d |
Inspect config for ResNet18 model
# Import the resnet module
from timm.models import resnet
# Choose the resnet model
resnet_model = 'resnet18d'

# Get the default configuration of the chosen model as a Pandas DataFrame
pd.DataFrame.from_dict(resnet.default_cfgs[resnet_model], orient='index')
0 | |
---|---|
url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth |
num_classes | 1000 |
input_size | (3, 224, 224) |
pool_size | (7, 7) |
crop_pct | 0.875 |
interpolation | bicubic |
mean | (0.485, 0.456, 0.406) |
std | (0.229, 0.224, 0.225) |
first_conv | conv1.0 |
classifier | fc |
We first get the normalization stats used to pre-train the model. We get the mean and standard deviation values for the red, green, and blue color channels from the default_cfgs
dictionary.
Get normalization stats
norm_stats = resnet.default_cfgs[resnet_model]['mean'], resnet.default_cfgs[resnet_model]['std']
norm_stats
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
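These statistics feed into transforms.Normalize later in the post, and the denorm_img_tensor helper used for visualization presumably reverses that step. A minimal sketch of the inverse operation (an illustration, not the library's actual code):
def denorm_sketch(img_tensor, mean, std):
    # Undo channel-wise normalization: x * std + mean
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)
    return img_tensor * std + mean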
Next, we load the pretrained model and set the number of output classes for the model head. We’ll also update the device and data type of the model to the ones specified earlier.
Load ResNet18 model
# Load the resnet model
resnet18 = timm.create_model(resnet_model, pretrained=True, num_classes=len(class_names))

# Set the device, data type, and name
resnet18 = resnet18.to(device=device, dtype=dtype)
resnet18.device = device
resnet18.name = resnet_model
Select model
model = resnet18
Get model summary
# Define the input to the model
test_inp = torch.randn(1, 3, 256, 256).to(device)

# Get a summary of the model as a Pandas DataFrame
summary_df = markdown_to_pandas(f"{get_module_summary(model, [test_inp])}")

# Filter the summary to only contain Conv2d layers and the model
summary_df = summary_df[(summary_df.index == 0) | (summary_df['Type'] == 'Conv2d')]

# Remove the column "Contains Uninitialized Parameters?"
summary_df.drop('Contains Uninitialized Parameters?', axis=1)
Type | # Parameters | # Trainable Parameters | Size (bytes) | Forward FLOPs | Backward FLOPs | In size | Out size | |
---|---|---|---|---|---|---|---|---|
0 | ResNet | 11.2 M | 11.2 M | 44.9 M | 2.7 G | 5.3 G | [1, 3, 256, 256] | [1, 19] |
2 | Conv2d | 864 | 864 | 3.5 K | 14.2 M | 14.2 M | [1, 3, 256, 256] | [1, 32, 128, 128] |
5 | Conv2d | 9.2 K | 9.2 K | 36.9 K | 150 M | 301 M | [1, 32, 128, 128] | [1, 32, 128, 128] |
8 | Conv2d | 18.4 K | 18.4 K | 73.7 K | 301 M | 603 M | [1, 32, 128, 128] | [1, 64, 128, 128] |
14 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
19 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
23 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
28 | Conv2d | 36.9 K | 36.9 K | 147 K | 150 M | 301 M | [1, 64, 64, 64] | [1, 64, 64, 64] |
33 | Conv2d | 73.7 K | 73.7 K | 294 K | 75.5 M | 150 M | [1, 64, 64, 64] | [1, 128, 32, 32] |
38 | Conv2d | 147 K | 147 K | 589 K | 150 M | 301 M | [1, 128, 32, 32] | [1, 128, 32, 32] |
43 | Conv2d | 8.2 K | 8.2 K | 32.8 K | 8.4 M | 16.8 M | [1, 64, 32, 32] | [1, 128, 32, 32] |
46 | Conv2d | 147 K | 147 K | 589 K | 150 M | 301 M | [1, 128, 32, 32] | [1, 128, 32, 32] |
51 | Conv2d | 147 K | 147 K | 589 K | 150 M | 301 M | [1, 128, 32, 32] | [1, 128, 32, 32] |
56 | Conv2d | 294 K | 294 K | 1.2 M | 75.5 M | 150 M | [1, 128, 32, 32] | [1, 256, 16, 16] |
61 | Conv2d | 589 K | 589 K | 2.4 M | 150 M | 301 M | [1, 256, 16, 16] | [1, 256, 16, 16] |
66 | Conv2d | 32.8 K | 32.8 K | 131 K | 8.4 M | 16.8 M | [1, 128, 16, 16] | [1, 256, 16, 16] |
69 | Conv2d | 589 K | 589 K | 2.4 M | 150 M | 301 M | [1, 256, 16, 16] | [1, 256, 16, 16] |
74 | Conv2d | 589 K | 589 K | 2.4 M | 150 M | 301 M | [1, 256, 16, 16] | [1, 256, 16, 16] |
79 | Conv2d | 1.2 M | 1.2 M | 4.7 M | 75.5 M | 150 M | [1, 256, 16, 16] | [1, 512, 8, 8] |
84 | Conv2d | 2.4 M | 2.4 M | 9.4 M | 150 M | 301 M | [1, 512, 8, 8] | [1, 512, 8, 8] |
89 | Conv2d | 131 K | 131 K | 524 K | 8.4 M | 16.8 M | [1, 256, 8, 8] | [1, 512, 8, 8] |
92 | Conv2d | 2.4 M | 2.4 M | 9.4 M | 150 M | 301 M | [1, 512, 8, 8] | [1, 512, 8, 8] |
97 | Conv2d | 2.4 M | 2.4 M | 9.4 M | 150 M | 301 M | [1, 512, 8, 8] | [1, 512, 8, 8] |
Dataset Preparation
I split the dataset randomly into 90% for training and 10% for validation.
Get training and validation sets
# Shuffle the image paths
random.shuffle(img_paths)
# Define the percentage of the images that should be used for training
train_pct = 0.9
val_pct = 0.1

# Calculate the index at which to split the subset of image paths into training and validation sets
train_split = int(len(img_paths)*train_pct)
val_split = int(len(img_paths)*(train_pct+val_pct))

# Split the subset of image paths into training and validation sets
train_paths = img_paths[:train_split]
val_paths = img_paths[train_split:]
# Print the number of images in the training and validation sets
len(train_paths), len(val_paths)
(138361, 15374)
The hand gestures in some images are hard to see at low resolutions, so I resized and padded the samples to 288x288
for training to balance computational efficiency and model performance.
Set training image size
train_sz = (288, 288)
Resize Pad Transform
class ResizePad(nn.Module):
    def __init__(self, max_sz=256, padding_mode='edge'):
        """
        A PyTorch module that resizes an image tensor and adds padding to make it a square tensor.

        Args:
            max_sz (int, optional): The size of the square tensor.
            padding_mode (str, optional): The padding mode used when adding padding to the tensor.
        """
        super().__init__()
        self.max_sz = max_sz
        self.padding_mode = padding_mode

    def forward(self, x):
        # Get the width and height of the image tensor
        w, h = TF.get_image_size(x)

        # Resize the image tensor so that its minimum dimension is equal to `max_sz`
        size = int(min(w, h) / (max(w, h) / self.max_sz))
        x = TF.resize(x, size=size)

        # Add padding to make the image tensor a square
        w, h = TF.get_image_size(x)
        offset = (self.max_sz - min(w, h)) // 2
        padding = [0, offset] if h < w else [offset, 0]
        x = TF.pad(x, padding=padding, padding_mode=self.padding_mode)
        x = TF.resize(x, size=[self.max_sz] * 2)

        return x
Test the transform
print(f"Source image: {sample_img.size}")
# Crop the source image
w, h = sample_img.size
cropped_img = sample_img.crop([0, h//4, w, h-h//4])
print(f"Cropped image: {cropped_img.size}")

# Create a `ResizePad` object
resize_pad = ResizePad(max_sz=max(train_sz))

# Convert the cropped image to a tensor
img_tensor = transforms.PILToTensor()(cropped_img)[None]
print(f"Cropped tensor: {img_tensor.shape}")

# Resize and pad the tensor
resized_tensor = resize_pad(img_tensor)
print(f"Padded tensor: {resized_tensor.shape}")

# Display the resized and padded image
tensor_to_pil(resized_tensor)
Source image: (512, 512)
Cropped image: (512, 256)
Cropped tensor: torch.Size([1, 3, 256, 512])
Padded tensor: torch.Size([1, 3, 288, 288])
The TrivialAugment transform requires PIL Images, so I applied all transforms per item rather than per batch for consistency.
Set image transforms
# Set transforms for training set
train_tfms = [
    ResizePad(max_sz=max(train_sz)),
    # transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize(*norm_stats),
    # rand_square_copy_tf,
    # rand_pixel_copy_tf,
    # rand_square_pixel_copy_tf
    custom_triv_aug_tf,
]

# Set transforms for validation set
valid_tfms = [
    ResizePad(max_sz=max(train_sz)),
    transforms.ToTensor(),
    transforms.Normalize(*norm_stats),
]
We pass the list of transforms into a custom Dataset class along with the image paths and class names.
Define training dataset class
class ImageDataset(Dataset):
    def __init__(self, img_paths, classes, tfms):
        # Store the arguments as instance variables
        self.img_paths = img_paths
        self.classes = classes
        self.tfms = tfms

        # Create a mapping from class names to class indices
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

    def __len__(self):
        # Return the number of images in the dataset
        return len(self.img_paths)

    def __getitem__(self, index):
        # Get the path of the image at the given index
        img_path = self.img_paths[index]

        # Get the label of the image
        label = self.class_to_idx[img_path.parent.name]

        # Open the image
        image = Image.open(img_path).convert('RGB')

        return self.tfms(image), label
Initialize datasets
# create the training dataset using the composed transformations
train_dataset = ImageDataset(img_paths=train_paths,
                             classes=class_names,
                             tfms=transforms.Compose(train_tfms))

# create the validation dataset
valid_dataset = ImageDataset(img_paths=val_paths,
                             classes=class_names,
                             tfms=transforms.Compose(valid_tfms))
# print the length of the training and validation datasets
len(train_dataset), len(valid_dataset)
(138361, 15374)
Inspect training set sample
tensor_to_pil(denorm_img_tensor(train_dataset[0][0], *norm_stats))
Inspect validation set sample
tensor_to_pil(denorm_img_tensor(valid_dataset[0][0], *norm_stats))
Set training batch size
bs = 32
Initialize dataloaders
train_dataloader = DataLoader(train_dataset,
                              batch_size=bs,
                              shuffle=True,
                              num_workers=multiprocessing.cpu_count())

valid_dataloader = DataLoader(valid_dataset, batch_size=bs)

dls = DataLoaders(train_dataloader, valid_dataloader)
# Print the number of batches in the training and validation dataloaders
len(dls.train), len(dls.valid)
(4324, 481)
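The batch counts follow directly from the dataset sizes and the batch size; a quick sanity check using the math module imported earlier:
# 138,361 / 32 rounds up to 4324 training batches; 15,374 / 32 rounds up to 481 validation batches
math.ceil(len(train_dataset) / bs), math.ceil(len(valid_dataset) / bs)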
Training
We can define a custom callback to integrate W&B with the miniai library. The one below tracks the loss and accuracy values for the training and validation sets.
Define training callbacks
class WandBCB(MetricsCB):
    """
    A Callback class that integrates with the Weights and Biases (W&B) platform to log training
    and validation metrics, as well as sample figures during training.
    """
    order = 100

    def __init__(self, config, *ms, project:str, name:str=None, **metrics):
        """
        Initialize the WandBCB class.

        Parameters:
            config (dict): Configuration parameters for the W&B run.
            ms (list): List of metrics to be logged.
            project (str): The name of the W&B project.
            name (str, optional): The name of the W&B run. Defaults to None.
            metrics (dict, optional): Additional metrics to be logged.
        """
        fc.store_attr()
        self.train = False
        self.run = None
        super().__init__(*ms, **metrics)

    def before_fit(self, learn):
        """
        Initialize the W&B run.

        Parameters:
            learn (miniai.learner.Learner): The miniai Learner object.
        """
        # If a name for the run is provided, use it, otherwise use the default name
        if self.name:
            self.run = wandb.init(project=self.project, name=self.name, config=self.config)
        else:
            self.run = wandb.init(project=self.project, config=self.config)

    def after_fit(self, learn):
        """
        Finalize the W&B run.

        Parameters:
            learn (miniai.learner.Learner): The miniai Learner object.
        """
        wandb.finish()

    def _log(self, d):
        """
        Log the metrics to W&B.

        Parameters:
            d (dict): Dictionary of metrics.
        """
        # Log the metrics with a prefix of "train_" if in training mode, otherwise "val_"
        if self.train:
            wandb.log({'train_'+m: float(d[m]) for m in self.all_metrics})
        else:
            wandb.log({'val_'+m: float(d[m]) for m in self.all_metrics})

    def after_batch(self, learn):
        """
        Log the loss after a batch.
        """
        super().after_batch(learn)
        # Log the batch loss to W&B
        wandb.log({'loss': learn.loss})

    def after_epoch(self, learn):
        """
        Update the training mode flag and call the parent class's `after_epoch` method.

        Parameters:
            learn (miniai.learner.Learner): The miniai Learner object.
        """
        # Update the training mode flag
        self.train = learn.training
        super().after_epoch(learn)
class OutOfBoundsCB(Callback):
    """
    A callback that cancels the training session if the loss becomes NaN or infinite.
    """
    def __init__(self, msg="Loss out of bounds"):
        fc.store_attr()

    def after_loss(self, learn):
        """
        Check if the loss is NaN or infinite.
        """
        if learn.loss.detach().isnan() or learn.loss.detach().isinf():
            print(self.msg)
            raise CancelFitException()
Initialize callbacks
metrics = MetricsCB(accuracy=MulticlassAccuracy())

cbs = [MixedPrecision(),
       OutOfBoundsCB(),
       DeviceCB(),
       metrics,
       ProgressCB(plot=False),
      ]
Set optimizer
opt_func = partial(torch.optim.AdamW, eps=1e-5)
Set learning rate scheduler
lr = 1e-3
epochs = 3

tmax = epochs * len(dls.train)
sched = partial(torch.optim.lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs.append(BatchSchedCB(sched))
Configure Weights & Biases project
= f"miniai_data_augmentation_seed-{seed}"
project_name # run_name = "baseline"
# run_name = "rand-square-cp-1"
# run_name = "rand-pixel-cp-1"
# run_name = "rand-square-pixel-cp-1"
# run_name = "trivial-aug-item-1"
= "custom-trivial-aug-1"
run_name
= {'lr':lr, 'epochs':epochs}
project_config
= WandBCB(config=project_config,
wandbcb =project_name,
project=run_name,
name=MulticlassAccuracy()
accuracy
) cbs.append(wandbcb)
Initialize learner
learn = TrainLearner(model, dls, torch.nn.functional.cross_entropy, lr=lr, cbs=cbs, opt_func=opt_func)
Train model
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.905 | 0.343 | 0 | train |
0.966 | 0.113 | 0 | eval |
0.976 | 0.077 | 1 | train |
0.989 | 0.038 | 1 | eval |
0.994 | 0.021 | 2 | train |
0.994 | 0.021 | 2 | eval |
Set checkpoint directory
# Create a directory to store the checkpoints if it does not already exist
= Path(f"./{project_name}/")
checkpoint_dir =True, exist_ok=True)
checkpoint_dir.mkdir(parents
# Print the checkpoint path
checkpoint_dir
Path('miniai_data_augmentation_seed-1234')
Save model checkpoint
/f"{model.name}-{run_name}.pth") torch.save(learn.model.state_dict(), checkpoint_dir
Test model
infer_sz = max(train_sz)

# Choose a random image from the list of image paths
img_path = random.choice(img_paths)

# Print the actual class of the chosen image
print(f"Class: {img_path.parent.name}")

# Open the image and resize it
sample_img = Image.open(img_path)
inp_img = resize_img(sample_img.copy(), infer_sz)

# Convert the image to a tensor and move it to the device
img_tensor = pil_to_tensor(inp_img, *norm_stats).to(device=device)

# Make a prediction with the model
with torch.no_grad():
    pred = learn.model(img_tensor)

# Get the predicted class index and convert it to the class name
pred_class = train_dataset.classes[torch.argmax(torch.softmax(pred, dim=1))]
# Print the predicted class
print(f"Predicted Class: {pred_class}")
# Display the image
sample_img
Class: rock
Predicted Class: rock
Results
All training runs, including the baseline no augmentation run, ended with over 99% accuracy across all random seeds.
Validation Set
The TrivialAugmentWide
transform had the highest validation accuracy across three of four random seeds, but the scores were too close to name a definitive winner.
Seed: 1
Seed: 42
Seed: 100
Seed: 1234
Test Set
I also ended each run by testing the model, without augmentation, on the entire dataset. I used a minimum resolution of 288p, as I did for training, but did not enforce square aspect ratios. The results were quite different from the validation sets.
Perform inference on dataset
# Initialize a list to store the file paths of the images that were incorrectly classified
wrong_imgs = []

# Iterate through the test image paths
for path in tqdm(img_paths):

    # Get the actual class of the image
    target_cls = path.parent.name

    # Open and resize the image
    sample_img = Image.open(path)
    sample_img = resize_img(sample_img, infer_sz)

    # Convert the image to a tensor and move it to the device
    img_tensor = pil_to_tensor(sample_img, *norm_stats).to(device=device)

    # Make a prediction with the model
    with torch.no_grad():
        pred = learn.model(img_tensor)

    # Get the predicted class index and convert it to the class name
    pred_cls = train_dataset.classes[torch.argmax(torch.softmax(pred, dim=1))]

    # If the prediction is incorrect, add the file path to the list of wrong images
    if pred_cls != target_cls:
        wrong_imgs.append(path)
len(wrong_imgs)
251
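For context, that corresponds to an error rate of well under one percent of the full dataset; a quick check using the variables from the cell above:
# 251 wrong predictions out of 153,735 images is roughly 0.16%
print(f"Error rate: {len(wrong_imgs) / len(img_paths):.2%}")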
Inspect the number of wrong predictions per class
# Create a DataFrame from the list of incorrectly classified images
wrong_imgs_df = pd.DataFrame(wrong_imgs)

# Add a column to the DataFrame with the actual class of each image
wrong_imgs_df['class'] = wrong_imgs_df.apply(lambda row: Path(row[0]).parent.stem, axis=1)

# Create a DataFrame with the class distribution of the incorrectly classified images
class_dist_df = wrong_imgs_df['class'].value_counts().to_frame().rename(columns={"class": run_name})

# Set the index name to "class"
class_dist_df.rename_axis("class", inplace=True)
# Display the class distribution DataFrame
class_dist_df
custom-trivial-aug-1 | |
---|---|
class | |
peace | 27 |
one | 23 |
three | 22 |
no_gesture | 22 |
stop | 19 |
four | 17 |
two_up | 15 |
call | 15 |
stop_inverted | 13 |
like | 12 |
rock | 12 |
fist | 10 |
palm | 9 |
ok | 8 |
peace_inverted | 7 |
two_up_inverted | 6 |
mute | 5 |
dislike | 5 |
three2 | 4 |
Save DataFrame
= checkpoint_dir/f"class_dist_df-{run_name}.json"
class_dist_df_path class_dist_df.to_json(class_dist_df_path)
With this seed, the TrivialAugmentWide model had the second-highest number of incorrect predictions during inference, even though it had the highest validation accuracy.
Load and concatenate DataFrames
def get_class_dist_agg_df(folder, ext='json'):
    # Get a list of paths to `ext` files in the folder
    df_paths = glob(os.path.join(folder, f"*.{ext}"))

    # Create a generator that reads the class distribution data from each file in df_paths
    class_dist_dfs = (pd.read_json(f).rename_axis('class') for f in df_paths)

    # Concatenate the class distribution data into a single DataFrame
    return pd.concat(class_dist_dfs, axis=1, sort=False).fillna(0).astype(int)

# Concatenate the class distribution data for this seed into a single DataFrame
class_dist_agg_df = get_class_dist_agg_df(checkpoint_dir)

# Compute the sum of the class distribution data for each run
sums = class_dist_agg_df.sum(axis=0)

# Add a row to the DataFrame with the totals
class_dist_agg_df.loc['Total'] = sums

# Sort the runs by their total number of incorrect predictions
class_dist_agg_df.sort_values(by='Total', axis=1)
rand-square-pixel-cp-1 | baseline | rand-square-cp-1 | custom-trivial-aug-1 | trivial-aug-item-1 | rand-pixel-cp-1 | |
---|---|---|---|---|---|---|
class | ||||||
call | 6 | 15 | 9 | 15 | 16 | 13 |
dislike | 8 | 5 | 5 | 5 | 6 | 10 |
fist | 11 | 8 | 9 | 10 | 13 | 21 |
four | 25 | 19 | 15 | 17 | 22 | 31 |
like | 18 | 11 | 15 | 12 | 19 | 14 |
mute | 5 | 3 | 6 | 5 | 4 | 10 |
no_gesture | 18 | 23 | 23 | 22 | 44 | 25 |
ok | 5 | 15 | 11 | 8 | 8 | 7 |
one | 14 | 22 | 28 | 23 | 21 | 19 |
palm | 11 | 10 | 7 | 9 | 23 | 12 |
peace | 20 | 17 | 26 | 27 | 24 | 45 |
peace_inverted | 1 | 9 | 7 | 7 | 6 | 7 |
rock | 8 | 8 | 6 | 12 | 12 | 6 |
stop | 14 | 18 | 11 | 19 | 15 | 23 |
stop_inverted | 6 | 10 | 10 | 13 | 12 | 21 |
three | 20 | 22 | 31 | 22 | 37 | 20 |
three2 | 5 | 11 | 7 | 4 | 8 | 4 |
two_up | 10 | 11 | 12 | 15 | 16 | 18 |
two_up_inverted | 8 | 6 | 8 | 6 | 6 | 11 |
Total | 213 | 243 | 246 | 251 | 312 | 317 |
Even though the TrivialAugmentWide transform seemed like the winner across the validation sets, it performed the worst overall during inference. The custom trivial augmentation transform performed the best overall, with the random square copy transform a close second.
Compare performance across random seeds
# Get project folders
folder_prefix = project_name.split('-')[0]
project_folders = [f for f in os.listdir(".") if f.startswith(folder_prefix) and os.path.isdir(f)]

# Get the class distribution data for each random seed
df_list = [get_class_dist_agg_df(f).sum(axis=0).rename(int(f.split('-')[-1])) for f in project_folders]

# Concatenate the class distribution data across seeds into a single DataFrame
missing_imgs_df = pd.concat(df_list, axis=1, sort=False).transpose().sort_index()

# Compute the mean of the class distribution data for each run
mean_vals = missing_imgs_df.mean(axis=0).rename('Mean')

# Fill any missing values with 0 and convert the values to integers
mean_vals_df = mean_vals.to_frame().transpose().fillna(0).astype(int)

# Append the mean values to the DataFrame and sort the runs by their mean
pd.concat([missing_imgs_df, mean_vals_df], axis=0).sort_values(by='Mean', axis=1)
custom-trivial-aug-1 | rand-square-cp-1 | rand-square-pixel-cp-1 | baseline | rand-pixel-cp-1 | trivial-aug-item-1 | |
---|---|---|---|---|---|---|
1 | 231 | 215 | 261 | 257 | 272 | 321 |
42 | 231 | 238 | 290 | 284 | 309 | 312 |
100 | 229 | 255 | 224 | 238 | 352 | 307 |
1234 | 251 | 246 | 213 | 243 | 317 | 312 |
Mean | 235 | 238 | 247 | 255 | 312 | 313 |
Sort runs for each random seed
# Define a lambda function to sort the column names based on their values in each row and return a new DataFrame
sort_cols = lambda row: pd.DataFrame({row.name: row.sort_values().index.tolist()}).transpose()

# Apply the lambda function to each row in the DataFrame and concatenate the results into a new DataFrame
sorted_cols_df = pd.concat(missing_imgs_df.apply(sort_cols, axis=1).tolist())
sorted_cols_df
0 | 1 | 2 | 3 | 4 | 5 | |
---|---|---|---|---|---|---|
1 | rand-square-cp-1 | custom-trivial-aug-1 | baseline | rand-square-pixel-cp-1 | rand-pixel-cp-1 | trivial-aug-item-1 |
42 | custom-trivial-aug-1 | rand-square-cp-1 | baseline | rand-square-pixel-cp-1 | rand-pixel-cp-1 | trivial-aug-item-1 |
100 | rand-square-pixel-cp-1 | custom-trivial-aug-1 | baseline | rand-square-cp-1 | trivial-aug-item-1 | rand-pixel-cp-1 |
1234 | rand-square-pixel-cp-1 | baseline | rand-square-cp-1 | custom-trivial-aug-1 | trivial-aug-item-1 | rand-pixel-cp-1 |
Discussion
The random square copy augmentation performed consistently well throughout my initial testing and final set of experiments. I thought this one might not work well with the hand gesture dataset as it could copy or cover parts of a hand. I assumed that would be an issue if the result were too close to another gesture class.
The random pixel copy transform did much worse than I expected on the validation sets and during inference. I thought it would be one of the better-performing augmentations since it was unlikely to obscure a hand. I am curious if some of the training images were already so noisy that adding more was too much.
Given how poorly the pixel copy transform performed relative to the others, I was surprised my custom trivial augmentation transform did the best on average.
Something I found interesting is that the trivial augmentation models initially performed far better during inference than the others. However, they lost that lead as I cleaned up the dataset. I wonder whether the more extreme augmentations included with that transform help more when the dataset is dirty. With the cleaned-up dataset, I believe some of the included transforms altered the images too much.
Looking to the future, recent improvements in fine-grained control for generative models like ControlNet for Stable Diffusion may make synthetic datasets more comparable to real-world images. I plan to explore this approach in future experiments to see if it can improve the performance of deep learning models trained on visual data.
Conclusion
My findings suggest that different datasets may require different augmentation strategies, and it’s important to test multiple augmentations to find the best one for a particular dataset. In addition, factors such as the complexity and quality of the dataset and the type of deep learning model may influence the performance of image augmentations.
However, the square-copy transform delivered the best bang for the buck in these experiments. It is nearly tied with the custom trivial augmentation for overall performance and is simple to implement.
Of course, experiments with a known dataset are no substitute for real-world testing. Therefore, I encourage you to try the in-browser demos I’ve provided to test the models trained with different augmentations using your webcam.
Live Demos
Data Augmentation | Demo |
---|---|
Baseline (no augmentation) | https://cj-mills.github.io/miniai-resnet18d-baseline-demo/ |
Random Square Copy | https://cj-mills.github.io/miniai-resnet18d-rand-square-cp-demo/ |
Random Pixel Copy | https://cj-mills.github.io/miniai-resnet18d-rand-pixel-cp-demo/ |
Random Square Pixel Copy | https://cj-mills.github.io/miniai-resnet18d-rand-square-pixel-cp-demo/ |
TrivialAugmentWide | https://cj-mills.github.io/miniai-resnet18d-trivial-aug-demo/ |
Custom Trivial Augment | https://cj-mills.github.io/miniai-resnet18d-custom-trivial-aug-demo/ |
The models run locally in your browser using the Barracuda inference library for Unity. I selected the best-performing runs for each image augmentation to give them their best shot. Let me know which one performs best for you in the comments below.
- I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.
- I help clients leverage cutting-edge AI technologies to solve real-world problems.
- Learn more about me or reach out via email at [email protected] to discuss your project.