Training Keypoint R-CNN Models with PyTorch
- Introduction
- Getting Started with the Code
- Setting Up Your Python Environment
- Importing the Required Dependencies
- Setting Up the Project
- Loading and Exploring the Dataset
- Loading the Keypoint R-CNN Model
- Preparing the Data
- Fine-tuning the Model
- Making Predictions with the Model
- Conclusion
Introduction
Welcome to this hands-on guide to training Keypoint R-CNN models in PyTorch. Keypoint estimation models predict the locations of points on a given object or person, allowing us to recognize and interpret poses, gestures, or significant parts of objects.
The tutorial walks through setting up a Python environment, loading the raw keypoint annotations, annotating and augmenting images, creating a custom Dataset class to feed samples to a model, finetuning a Keypoint R-CNN model, and performing inference.
This guide is suitable for beginners and experienced practitioners, providing the code, explanations, and resources needed to understand and implement each step. Upon completion, you will have a solid foundation for training custom key point estimation models for other projects.
Getting Started with the Code
The tutorial code is available as a Jupyter Notebook, which you can run locally or in a cloud-based environment like Google Colab. I have dedicated tutorials for those new to these platforms or who need guidance setting up:
Platform | Jupyter Notebook |
---|---|
Google Colab | Open In Colab |
Linux | GitHub Repository |
Windows | GitHub Repository |
The code in this tutorial targets Linux platforms, but most of it should also work on macOS and Windows. However, Python multiprocessing works differently on those platforms, requiring some changes to leverage multiprocessing for the DataLoader objects.
I’ve made a dedicated version of the tutorial code to run on Windows. The included changes should also work on macOS, but I don’t have a Mac to verify.
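If you adapt the code for Windows or macOS yourself, the key change is guarding the training entry point, since DataLoader workers on those platforms are started with "spawn" rather than "fork". A minimal sketch (assuming you move the training code into a main function, which is not how the notebook itself is organized):

# Minimal sketch for spawn-based platforms (Windows/macOS): keep the training code
# behind an importable entry point so DataLoader worker processes can start cleanly.
import multiprocessing

def main():
    # build the datasets, DataLoaders (num_workers > 0), and run the training loop here
    ...

if __name__ == '__main__':
    multiprocessing.freeze_support()  # harmless outside frozen Windows executables
    main()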
Setting Up Your Python Environment
Before diving into the code, we’ll cover the steps to create a local Python environment and install the necessary dependencies.
Creating a Python Environment
First, we’ll create a Python environment using Conda/Mamba. Open a terminal with Conda/Mamba installed and run the following commands:
# Create a new Python 3.11 environment
conda create --name pytorch-env python=3.11 -y
# Activate the environment
conda activate pytorch-env
# Create a new Python 3.11 environment
mamba create --name pytorch-env python=3.11 -y
# Activate the environment
mamba activate pytorch-env
Installing PyTorch
Next, we’ll install PyTorch. Run the appropriate command for your hardware and operating system.
# Install PyTorch with CUDA
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# MPS (Metal Performance Shaders) acceleration is available on MacOS 12.3+
pip install torch torchvision torchaudio
# Install PyTorch for CPU only
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# Install PyTorch for CPU only (default PyPI wheels, e.g., on Windows)
pip install torch torchvision torchaudio
Installing Additional Libraries
We also need to install some additional libraries for our project.
Package | Description |
---|---|
jupyter | An open-source web application that allows you to create and share documents that contain live code, equations, visualizations, and narrative text. (link) |
matplotlib | This package provides a comprehensive collection of visualization tools to create high-quality plots, charts, and graphs for data exploration and presentation. (link) |
pandas | This package provides fast, powerful, and flexible data analysis and manipulation tools. (link) |
pillow | The Python Imaging Library adds image processing capabilities. (link) |
torchtnt | A lightweight library for PyTorch training tools and utilities. (link) |
tabulate | Pretty-print tabular data in Python. (link) |
tqdm | A Python library that provides fast, extensible progress bars for loops and other iterable objects in Python. (link) |
distinctipy | A lightweight python package providing functions to generate colours that are visually distinct from one another. (link) |
Run the following command to install these additional libraries:
# Install additional dependencies
pip install distinctipy jupyter matplotlib pandas pillow torchtnt==0.2.0 tabulate tqdm
Installing Utility Packages
We will also install some utility packages I made, which provide shortcuts for routine tasks.
Package | Description |
---|---|
cjm_pandas_utils | Some utility functions for working with Pandas. (link) |
cjm_pil_utils | Some PIL utility functions I frequently use. (link) |
cjm_psl_utils | Some utility functions using the Python Standard Library. (link) |
cjm_pytorch_utils | Some utility functions for working with PyTorch. (link) |
cjm_torchvision_tfms | Some custom Torchvision transforms. (link) |
Run the following command to install the utility packages:
# Install additional utility packages
pip install cjm_pandas_utils cjm_pil_utils cjm_psl_utils cjm_pytorch_utils cjm_torchvision_tfms
With our environment set up, we can open our Jupyter Notebook and dive into the code.
Importing the Required Dependencies
First, we will import the necessary Python modules into our Jupyter Notebook.
# Import Python Standard Library dependencies
from contextlib import contextmanager
import datetime
from functools import partial
from glob import glob
import json
import math
import multiprocessing
import os
from pathlib import Path
import random
# Import utility functions
from cjm_pandas_utils.core import markdown_to_pandas
from cjm_pil_utils.core import resize_img, get_img_files, stack_imgs
from cjm_psl_utils.core import download_file, file_extract
from cjm_pytorch_utils.core import set_seed, pil_to_tensor, tensor_to_pil, get_torch_device, denorm_img_tensor, move_data_to_device
from cjm_torchvision_tfms.core import ResizeMax, PadSquare, CustomRandomIoUCrop, RandomPixelCopy
# Import the distinctipy module
from distinctipy import distinctipy
# Import matplotlib for creating plots
import matplotlib.pyplot as plt
# Import numpy
import numpy as np
# Import the pandas package
import pandas as pd
# Do not truncate the contents of cells and display all rows and columns
pd.set_option('max_colwidth', None, 'display.max_rows', None, 'display.max_columns', None)
# Import PIL for image manipulation
from PIL import Image
# Import PyTorch dependencies
import torch
import torch.nn as nn
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.data import Dataset, DataLoader
from torchtnt.utils import get_module_summary
# Import torchvision dependencies
import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision.tv_tensors import BoundingBoxes
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.v2 as transforms
# Import Keypoint R-CNN
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection import keypointrcnn_resnet50_fpn
# Import tqdm for progress bar
from tqdm.auto import tqdm
Torchvision provides dedicated torch.Tensor subclasses for different annotation types called TVTensors. Torchvision's V2 transforms use these subclasses to update the annotations based on the applied image augmentations. While there is currently no dedicated TVTensor class for keypoint annotations, we can use the one for bounding boxes instead. Torchvision does include a draw_keypoints function, but we might as well stick with the draw_bounding_boxes function to annotate images.
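As a quick illustration of that idea, the snippet below (with made-up coordinates and image size) wraps a single keypoint in a small box so the V2 transforms can track it; the tutorial uses this same trick with a 4-pixel box dimension later on.

# Illustrative sketch (made-up values): store an (x, y) keypoint as a tiny box so
# torchvision's V2 transforms can update it along with the image.
import torch
import torchvision
from torchvision.tv_tensors import BoundingBoxes

x, y = 100.0, 150.0  # hypothetical keypoint coordinates
box_cxcywh = torch.tensor([[x, y, 4.0, 4.0]])  # 4-pixel box centered on the keypoint
box_xyxy = torchvision.ops.box_convert(box_cxcywh, 'cxcywh', 'xyxy')
keypoint_box = BoundingBoxes(box_xyxy, format='xyxy', canvas_size=(480, 640))  # (height, width)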
Setting Up the Project
In this section, we set up some basics for our project, such as initializing random number generators, setting the PyTorch device to run the model, and preparing the folders for our project and datasets.
Setting a Random Number Seed
First, we set the seed for generating random numbers using the set_seed function from the cjm_pytorch_utils
package.
# Set the seed for generating random numbers in PyTorch, NumPy, and Python's random module.
seed = 123
set_seed(seed)
Setting the Device and Data Type
Next, we determine the device to use for training using the get_torch_device function from the cjm_pytorch_utils
package and set the data type of our tensors.
device = get_torch_device()
dtype = torch.float32
device, dtype
('cuda', torch.float32)
Setting the Directory Paths
We can then set up a directory for our project to store our results and other related files. We also need a place to store our dataset. The following code creates the folders in the current directory (./
). Update the path if that is not suitable for you.
# The name for the project
project_name = f"pytorch-keypoint-r-cnn"

# The path for the project folder
project_dir = Path(f"./{project_name}/")

# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)

# Define path to store datasets
dataset_dir = Path("./Datasets/")

# Create the dataset directory if it does not exist
dataset_dir.mkdir(parents=True, exist_ok=True)

# Define path to store archive files
archive_dir = dataset_dir/'../Archive'

# Create the archive directory if it does not exist
archive_dir.mkdir(parents=True, exist_ok=True)

# Creating a Series with the paths and converting it to a DataFrame for display
pd.Series({"Project Directory:": project_dir,
           "Dataset Directory:": dataset_dir,
           "Archive Directory:": archive_dir
          }).to_frame().style.hide(axis='columns')
Project Directory: | pytorch-keypoint-r-cnn |
---|---|
Dataset Directory: | Datasets |
Archive Directory: | Datasets/../Archive |
Double-check the project and dataset directories exist in the specified paths and that you can add files to them before continuing. At this point, our project is set up and ready to go. In the next section, we will download and explore the dataset.
Loading and Exploring the Dataset
I annotated a small dataset with key points for this tutorial using images from the free stock photo site Pexels. The dataset is available on HuggingFace Hub at the link below:
- Dataset Repository: labelme-keypoint-eyes-noses-dataset
The dataset contains 2D coordinates for eyes and noses on human faces.
The keypoints for this dataset use the LabelMe annotation format. You can learn more about this format and how to work with such annotations in my dedicated tutorial on LabelMe annotations.
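For reference, each image has a matching JSON file. An abridged example (field values rounded from one of the annotation files in this dataset) looks like the following:

{
  "version": "5.3.1",
  "flags": {},
  "shapes": [
    {"label": "left-eye",  "points": [[329.2, 252.6]], "group_id": null, "description": "", "shape_type": "point", "flags": {}},
    {"label": "nose",      "points": [[323.7, 291.0]], "group_id": null, "description": "", "shape_type": "point", "flags": {}},
    {"label": "right-eye", "points": [[260.3, 234.9]], "group_id": null, "description": "", "shape_type": "point", "flags": {}}
  ],
  "imagePath": "denim-jacket-fashion-fashion-model-1848570.jpg",
  "imageData": null,
  "imageHeight": 768,
  "imageWidth": 512
}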
Setting the Dataset Path
First, we construct the name for the Hugging Face Hub dataset and set where to download and extract the dataset.
# Set the name of the dataset
dataset_name = 'labelme-keypoint-eyes-noses-dataset'

# Construct the HuggingFace Hub dataset name by combining the username and dataset name
hf_dataset = f'cj-mills/{dataset_name}'

# Create the path to the zip file that contains the dataset
archive_path = Path(f'{archive_dir}/{dataset_name}.zip')

# Create the path to the directory where the dataset will be extracted
dataset_path = Path(f'{dataset_dir}/{dataset_name}')

# Creating a Series with the dataset name and paths and converting it to a DataFrame for display
pd.Series({"HuggingFace Dataset:": hf_dataset,
           "Archive Path:": archive_path,
           "Dataset Path:": dataset_path
          }).to_frame().style.hide(axis='columns')
HuggingFace Dataset: | cj-mills/labelme-keypoint-eyes-noses-dataset |
---|---|
Archive Path: | Datasets/../Archive/labelme-keypoint-eyes-noses-dataset.zip |
Dataset Path: | Datasets/labelme-keypoint-eyes-noses-dataset |
Downloading the Dataset
We can now download the archive file and extract the dataset using the download_file
and file_extract
functions from the cjm_psl_utils
package. We can delete the archive afterward to save space.
# Construct the HuggingFace Hub dataset URL
dataset_url = f"https://huggingface.co/datasets/{hf_dataset}/resolve/main/{dataset_name}.zip"
print(f"HuggingFace Dataset URL: {dataset_url}")

# Set whether to delete the archive file after extracting the dataset
delete_archive = True

# Download the dataset if not present
if dataset_path.is_dir():
    print("Dataset folder already exists")
else:
    print("Downloading dataset...")
    download_file(dataset_url, archive_dir)

    print("Extracting dataset...")
    file_extract(fname=archive_path, dest=dataset_dir)

    # Delete the archive if specified
    if delete_archive: archive_path.unlink()
Get Image File Paths
Next, we will make a dictionary that maps each image’s unique name to its file path, allowing us to retrieve the file path for a given image more efficiently.
# Get a list of image files in the dataset
= get_img_files(dataset_path)
img_file_paths
# Create a dictionary that maps file names to file paths
= {file.stem : file for file in (img_file_paths)}
img_dict
# Print the number of image files
print(f"Number of Images: {len(img_dict)}")
# Display the first five entries from the dictionary using a Pandas DataFrame
='index').head() pd.DataFrame.from_dict(img_dict, orient
Number of Images: 200
0 | |
---|---|
denim-jacket-fashion-fashion-model-1848570 | Datasets/labelme-keypoint-eyes-noses-dataset/denim-jacket-fashion-fashion-model-1848570.jpg |
dried-dry-face-2965690 | Datasets/labelme-keypoint-eyes-noses-dataset/dried-dry-face-2965690.jpg |
elderly-face-old-person-2856346 | Datasets/labelme-keypoint-eyes-noses-dataset/elderly-face-old-person-2856346.jpg |
elderly-hair-man-1319289 | Datasets/labelme-keypoint-eyes-noses-dataset/elderly-hair-man-1319289.jpg |
face-facial-expression-fashion-2592000 | Datasets/labelme-keypoint-eyes-noses-dataset/face-facial-expression-fashion-2592000.jpg |
Get Image Annotations
We will then read the content of the JSON annotation file associated with each image into a single Pandas DataFrame so we can easily query the annotations.
# Get a list of JSON files in the dataset
annotation_file_paths = list(dataset_path.glob('*.json'))

# Create a generator that yields Pandas DataFrames containing the data from each JSON file
cls_dataframes = (pd.read_json(f, orient='index').transpose() for f in tqdm(annotation_file_paths))

# Concatenate the DataFrames into a single DataFrame
annotation_df = pd.concat(cls_dataframes, ignore_index=False)

# Assign the image file name as the index for each row
annotation_df['index'] = annotation_df.apply(lambda row: row['imagePath'].split('.')[0], axis=1)
annotation_df = annotation_df.set_index('index')

# Keep only the rows that correspond to the filenames in the 'img_dict' dictionary
annotation_df = annotation_df.loc[list(img_dict.keys())]

# Print the first 5 rows of the DataFrame
annotation_df.head()
version | flags | shapes | imagePath | imageData | imageHeight | imageWidth | |
---|---|---|---|---|---|---|---|
index | |||||||
denim-jacket-fashion-fashion-model-1848570 | 5.3.1 | {} | [{‘label’: ‘left-eye’, ‘points’: [[329.17073170731703, 252.59756097560972]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘nose’, ‘points’: [[323.68292682926835, 291.0121951219512]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘right-eye’, ‘points’: [[260.2682926829268, 234.91463414634143]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}] | denim-jacket-fashion-fashion-model-1848570.jpg | None | 768 | 512 |
dried-dry-face-2965690 | 5.3.1 | {} | [{‘label’: ‘right-eye’, ‘points’: [[201.7317073170732, 351.9878048780488]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘left-eye’, ‘points’: [[333.43902439024396, 342.23170731707313]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘nose’, ‘points’: [[271.2439024390244, 436.1341463414634]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}] | dried-dry-face-2965690.jpg | None | 768 | 512 |
elderly-face-old-person-2856346 | 5.3.1 | {} | [{‘label’: ‘left-eye’, ‘points’: [[302.3414634146342, 286.1341463414634]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘nose’, ‘points’: [[243.80487804878055, 339.79268292682923]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘right-eye’, ‘points’: [[196.2439024390244, 286.7439024390244]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}] | elderly-face-old-person-2856346.jpg | None | 768 | 512 |
elderly-hair-man-1319289 | 5.3.1 | {} | [{‘label’: ‘right-eye’, ‘points’: [[490.910569105691, 175.71544715447155]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘left-eye’, ‘points’: [[548.6341463414634, 167.58536585365852]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘nose’, ‘points’: [[526.6829268292682, 201.73170731707316]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}] | elderly-hair-man-1319289.jpg | None | 512 | 768 |
face-facial-expression-fashion-2592000 | 5.3.1 | {} | [{‘label’: ‘left-eye’, ‘points’: [[301.45454545454544, 106.85561497326205]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘right-eye’, ‘points’: [[250.65240641711233, 115.94652406417114]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}, {‘label’: ‘nose’, ‘points’: [[272.0427807486631, 121.29411764705884]], ‘group_id’: None, ‘description’: ’‘, ’shape_type’: ‘point’, ‘flags’: {}}] | face-facial-expression-fashion-2592000.jpg | None | 672 | 512 |
Inspecting the Class Distribution
Now that we have the annotation data, we can extract the unique class names and inspect the class distribution. A balanced dataset (where each class has approximately the same number of instances) is ideal for training a machine-learning model.
Get image classes
# Explode the 'shapes' column in the annotation_df dataframe
# Apply the pandas Series function to the 'shapes' column of the dataframe
shapes_df = annotation_df['shapes'].explode().to_frame().shapes.apply(pd.Series)

# Get a list of unique labels in the 'annotation_df' DataFrame
class_names = shapes_df['label'].unique().tolist()

# Display labels using a Pandas DataFrame
pd.DataFrame(class_names)
0 | |
---|---|
0 | left-eye |
1 | nose |
2 | right-eye |
Visualize the class distribution
# Get the number of samples for each object class
class_counts = shapes_df['label'].value_counts()

# Plot the distribution
class_counts.plot(kind='bar')
plt.title('Class distribution')
plt.ylabel('Count')
plt.xlabel('Classes')
plt.xticks(range(len(class_counts.index)), class_counts.index, rotation=75)  # Set the x-axis tick labels
plt.show()
Visualizing Image Annotations
In this section, we will annotate a single image with its bounding boxes using torchvision’s BoundingBoxes
class and draw_bounding_boxes
function.
Generate a color map
First, we will generate a color map for the object classes.
# Generate a list of colors with a length equal to the number of labels
colors = distinctipy.get_colors(len(class_names))

# Make a copy of the color map in integer format
int_colors = [tuple(int(c*255) for c in color) for color in colors]

# Generate a color swatch to visualize the color map
distinctipy.color_swatch(colors)
Download a font file
The draw_bounding_boxes
function included with torchvision uses a pretty small font size. We can increase the font size if we use a custom font. Font files are available on sites like Google Fonts, or we can use one included with the operating system.
# Set the name of the font file
font_file = 'KFOlCnqEu92Fr1MmEU9vAw.ttf'

# Download the font file
download_file(f"https://fonts.gstatic.com/s/roboto/v30/{font_file}", "./")
Define the bounding box annotation function
We can make a partial function using draw_bounding_boxes
since we’ll use the same box thickness and font each time we visualize bounding boxes.
draw_bboxes = partial(draw_bounding_boxes, fill=True, width=4, font=font_file, font_size=25)
Annotate sample image
Finally, we will open a sample image and annotate it with its associated bounding boxes.
# Get the file ID of the first image file
file_id = list(img_dict.keys())[0]

# Open the associated image file as a RGB image
sample_img = Image.open(img_dict[file_id]).convert('RGB')

# Extract the labels and bounding box annotations for the sample image
labels = [shape['label'] for shape in annotation_df.loc[file_id]['shapes']]
keypoints = torch.tensor(np.array([shape['points'] for shape in annotation_df.loc[file_id]['shapes']])).reshape(-1,2)
BBOX_DIM = 4
keypoints_bboxes = torch.cat((keypoints, torch.ones(len(keypoints), 2)*BBOX_DIM), dim=1)

# Annotate the sample image with labels and bounding boxes
annotated_tensor = draw_bboxes(
    image=transforms.PILToTensor()(sample_img),
    boxes=torchvision.ops.box_convert(torch.Tensor(keypoints_bboxes), 'cxcywh', 'xyxy'),
    labels=labels,
    colors=[int_colors[i] for i in [class_names.index(label) for label in labels]]
)

tensor_to_pil(annotated_tensor)
Loading the Keypoint R-CNN Model
TorchVision provides checkpoints for the Keypoint R-CNN model trained on the COCO (Common Objects in Context) dataset. We can initialize a model with these pretrained weights using the keypointrcnn_resnet50_fpn
function. We must then replace the keypoint predictor for the pretrained model with a new one for our dataset.
# Load a pre-trained model
model = keypointrcnn_resnet50_fpn(weights='DEFAULT')

# Replace the classifier head with the number of keypoints
in_features = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(in_channels=in_features, num_keypoints=len(class_names))

# Set the model's device and data type
model.to(device=device, dtype=dtype);

# Add attributes to store the device and model name for later reference
model.device = device
model.name = 'keypointrcnn_resnet50_fpn'
The model internally normalizes input using the mean and standard deviation values used during the pretraining process, so we do not need to keep track of them separately.
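If you want to confirm those values, torchvision's detection models keep them on an internal transform object; a quick check might look like this (attribute names per the torchvision detection API):

# Optional sanity check: the model's internal transform stores the normalization
# statistics and resize limits applied to every input image.
print(model.transform.image_mean)  # ImageNet channel means, e.g., [0.485, 0.456, 0.406]
print(model.transform.image_std)   # ImageNet channel standard deviations
print(model.transform.min_size, model.transform.max_size)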
Summarizing the Model
Before moving on, let’s generate a summary of our model to get an overview of its performance characteristics. We can use this to gauge the computational requirements for deploying the model.
# Define the input to the model
test_inp = torch.randn(1, 3, 256, 256).to(device)

# Get a summary of the model as a Pandas DataFrame
summary_df = markdown_to_pandas(f"{get_module_summary(model.eval(), [test_inp])}")

# Filter the summary to only the model
summary_df = summary_df[summary_df.index == 0]

# Remove the 'In size', 'Out size', and 'Contains Uninitialized Parameters?' columns
summary_df.drop(['In size', 'Out size', 'Contains Uninitialized Parameters?'], axis=1)
Type | # Parameters | # Trainable Parameters | Size (bytes) | Forward FLOPs | |
---|---|---|---|---|---|
0 | KeypointRCNN | 59.0 M | 58.8 M | 236 M | 144 G |
The above table shows the model has approximately 58.8 million trainable parameters. It takes up 236 Megabytes and performs around 144 billion floating point operations for a single 256x256 RGB image. This model internally resizes input images and executes the same number of floating point operations for different input resolutions.
That completes the model setup. In the next section, we will prepare our dataset for training.
Preparing the Data
The data preparation involves several steps, such as applying data augmentation techniques, setting up the train-validation split for the dataset, resizing and padding the images, defining the training dataset class, and initializing DataLoaders to feed data to the model.
Training-Validation Split
Let’s begin by defining the training-validation split. We’ll randomly select 90% of the available samples for the training set and use the remaining 10% for the validation set.
# Get the list of image IDs
img_keys = list(img_dict.keys())

# Shuffle the image IDs
random.shuffle(img_keys)

# Define the percentage of the images that should be used for training
train_pct = 0.9
val_pct = 0.1

# Calculate the index at which to split the subset of image paths into training and validation sets
train_split = int(len(img_keys)*train_pct)
val_split = int(len(img_keys)*(train_pct+val_pct))

# Split the subset of image paths into training and validation sets
train_keys = img_keys[:train_split]
val_keys = img_keys[train_split:]

# Print the number of images in the training and validation sets
pd.Series({"Training Samples:": len(train_keys),
           "Validation Samples:": len(val_keys)
          }).to_frame().style.hide(axis='columns')
Training Samples: | 180 |
---|---|
Validation Samples: | 20 |
Data Augmentation
Here, we will define some data augmentations to apply to images during training. I created a few custom image transforms to help streamline the code.
The first extends the RandomIoUCrop
transform included with torchvision to give the user more control over how much it crops into bounding box areas. The second resizes images based on their largest dimension rather than their smallest. The third applies square padding and allows the padding to be applied equally on both sides or randomly split between the two sides.
All three are available through the cjm-torchvision-tfms
package.
Set training image size
First, we will specify the image size to use during training.
# Set training image size
train_sz = 512
Initialize custom transforms
Next, we can initialize the transform objects.
# Create a RandomIoUCrop object
iou_crop = CustomRandomIoUCrop(min_scale=0.3,
                               max_scale=1.0,
                               min_aspect_ratio=0.5,
                               max_aspect_ratio=2.0,
                               sampler_options=[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0],
                               trials=400,
                               jitter_factor=0.25)

# Create a `ResizeMax` object
resize_max = ResizeMax(max_sz=train_sz)

# Create a `PadSquare` object
pad_square = PadSquare(shift=True)
Test the transforms
Torchvision’s V2 image transforms take an image and a targets
dictionary. The targets
dictionary contains the annotations and labels for the image.
We will pass input through the CustomRandomIoUCrop
transform first and then through ResizeMax
and PadSquare
. We can pass the result through a final resize operation to ensure both sides match the train_sz
value.
Always use the SanitizeBoundingBoxes
transform to clean up annotations after using data augmentations that alter bounding boxes (e.g., cropping, warping, etc.).
# Extract the labels for the sample
labels = [shape['label'] for shape in annotation_df.loc[file_id]['shapes']]

# Prepare bounding box targets
targets = {'boxes': BoundingBoxes(torchvision.ops.box_convert(keypoints_bboxes, 'cxcywh', 'xyxy'),
                                  format='xyxy',
                                  canvas_size=sample_img.size[::-1]),
           'labels': torch.Tensor([class_names.index(label) for label in labels])}

# Crop the image
cropped_img, targets = iou_crop(sample_img, targets)

# Resize the image
resized_img, targets = resize_max(cropped_img, targets)

# Pad the image
padded_img, targets = pad_square(resized_img, targets)

# Ensure the padded image is the target size
resize = transforms.Resize([train_sz] * 2, antialias=True)
resized_padded_img, targets = resize(padded_img, targets)
sanitized_img, targets = transforms.SanitizeBoundingBoxes()(resized_padded_img, targets)

# Get colors for dataset sample
sample_colors = [int_colors[i] for i in [class_names.index(label) for label in labels]]

# Annotate the augmented image with updated labels and bounding boxes
annotated_tensor = draw_bboxes(
    image=transforms.PILToTensor()(sanitized_img),
    boxes=targets['boxes'],
    labels=[class_names[int(label.item())] for label in targets['labels']],
    colors=sample_colors,
)

# Display the annotated image
display(tensor_to_pil(annotated_tensor))

pd.Series({"Source Image:": sample_img.size,
           "Cropped Image:": cropped_img.size,
           "Resized Image:": resized_img.size,
           "Padded Image:": padded_img.size,
           "Resized Padded Image:": resized_padded_img.size,
          }).to_frame().style.hide(axis='columns')
Source Image: | (512, 768) |
---|---|
Cropped Image: | (512, 768) |
Resized Image: | (341, 511) |
Padded Image: | (511, 511) |
Resized Padded Image: | (512, 512) |
Now that we know how to apply data augmentations, we can put all the steps we’ve covered into a custom Dataset class.
Training Dataset Class
The following custom Dataset class is responsible for loading a single image, preparing the associated annotations, applying any image transforms, and returning the final image
tensor and its target
dictionary during training.
We will be applying the SanitizeBoundingBoxes
transform here as well. This transform can remove key points if a previous transform moves them outside the image dimensions. The Keypoint R-CNN model still expects values for key points even when not visible, so we will fill the target annotations with dummy values as needed.
class LabelMeKeypointDataset(Dataset):
"""
A PyTorch Dataset class for handling LabelMe image keypoints.
This class extends PyTorch's Dataset and is designed to work with image data and
associated keypoints annotations. It supports loading images and corresponding
keypoints annotations, and applying transformations.
Attributes:
img_keys (list): List of image keys.
annotation_df (DataFrame): DataFrame containing annotations for each image.
img_dict (dict): Dictionary mapping image keys to their file paths.
class_to_idx (dict): Dictionary mapping class names to class indices.
transforms (callable, optional): Transformations to be applied to the images and targets.
"""
def __init__(self, img_keys, annotation_df, img_dict, class_to_idx, transforms=None):
"""
Initializes the LabelMeKeypointDataset with image keys, annotations, and other relevant information.
Args:
img_keys (list): List of image keys.
annotation_df (DataFrame): DataFrame containing annotations for each image.
img_dict (dict): Dictionary mapping image keys to their file paths.
class_to_idx (dict): Dictionary mapping class names to class indices.
transforms (callable, optional): Transformations to be applied to the images and targets.
"""
super(Dataset, self).__init__()
self._img_keys = img_keys
self._annotation_df = annotation_df
self._img_dict = img_dict
self._class_to_idx = class_to_idx
self._transforms = transforms
self.sanitize_bboxes = torchvision.transforms.v2.SanitizeBoundingBoxes()
self.BBOX_DIM = 4
self.DUMMY_VALUE = -1
def __len__(self):
"""
Returns the number of items in the dataset.
Returns:
int: Number of items in the dataset.
"""
return len(self._img_keys)
def __getitem__(self, index):
"""
Retrieves an item from the dataset at the specified index.
Args:
index (int): Index of the item to retrieve.
Returns:
tuple: A tuple containing the image and its corresponding target (annotations).
"""
        img_key = self._img_keys[index]
        annotation = self._annotation_df.loc[img_key]
        image, target = self._load_image_and_target(annotation)

        # Applying transformations if specified
        if self._transforms:
            image, target = self._transforms(image, target)

        # Fill any missing keypoints with dummy values
        target = self._fill_and_order_target(target)
        return image, target
def order_points_by_labels(self, data, label_order):
"""
Extracts and orders points from a list of dictionaries based on a given order of labels.
:param data: List of dictionaries containing labels and points.
:param label_order: List of labels in the desired order.
:return: List of points in the specified label order.
"""
        ordered_points = []
        label_to_points = {item['label']: item['points'] for item in data}

        for label in label_order:
            points = label_to_points.get(label)
            if points is not None:
                ordered_points.extend(points)
        return ordered_points
def _load_image_and_target(self, annotation):
"""
Loads an image and its corresponding target (annotations) based on the provided annotation.
Args:
annotation (DataFrame row): Annotation data for a specific image.
Returns:
tuple: A tuple containing the loaded image and its corresponding target data.
"""
        # Load the image from the file path specified in the annotations
        filepath = self._img_dict[annotation.name]
        image = Image.open(filepath).convert('RGB')

        # Extracting keypoints from the annotation and converting them to a tensor
        keypoints = self.order_points_by_labels(annotation['shapes'], self._class_to_idx.keys())
        keypoints = torch.tensor(np.array(keypoints, dtype=np.float32)).reshape(-1, 2)

        # Adding an offset to create bounding boxes around keypoints
        keypoints_bboxes = torch.cat((keypoints, torch.ones(len(keypoints), 2) * self.BBOX_DIM), dim=1)

        # Convert bounding box format and create a BoundingBoxes object
        bbox_tensor = torchvision.ops.box_convert(keypoints_bboxes, 'cxcywh', 'xyxy')
        boxes = BoundingBoxes(bbox_tensor, format='xyxy', canvas_size=image.size[::-1])

        # Create tensor for labels based on the class indices
        labels = torch.Tensor([self._class_to_idx[label] for label in self._class_to_idx.keys()])

        return image, {'boxes': boxes, 'labels': labels}
def _fill_and_order_target(self, target):
"""
Fills and orders the target bounding boxes and labels based on the class index.
This method ensures that each target has a bounding box and label for each class,
even if some classes are not present in the original target. Missing classes
are filled with dummy values.
Args:
target (dict): A dictionary containing 'boxes' and 'labels' keys, where
'boxes' is a tensor of bounding boxes and 'labels' is a tensor
of labels corresponding to these boxes.
Returns:
dict: The updated target dictionary with boxes and labels ordered and filled
according to the class index.
"""
        # Initialize new boxes with dummy values for each class
        new_boxes = torch.full((len(self._class_to_idx), 4), self.DUMMY_VALUE)
        # Prepare labels tensor based on the class indices
        new_labels = torch.tensor(list(self._class_to_idx.values()), dtype=torch.float32)

        # Iterate over each class label
        for i, label in enumerate(new_labels):
            # Check if the current label exists in the target's labels
            if label in target['labels']:
                # Find the index of the current label in the target's labels
                idx = (target['labels'] == label).nonzero(as_tuple=True)[0]
                # Assign the corresponding box to the new boxes tensor
                new_boxes[i] = target['boxes'][idx]

        # Update the target dictionary with the new boxes and labels
        target['boxes'] = new_boxes
        target['labels'] = new_labels

        return target
Image Transforms
Here, we will specify and organize all the image transforms to apply during training.
# Compose transforms for data augmentation
data_aug_tfms = transforms.Compose(
    transforms=[
        transforms.ColorJitter(
            brightness=(0.8, 1.125),
            contrast=(0.5, 1.5),
            saturation=(0.5, 1.5),
            hue=(-0.05, 0.05),
        ),
        transforms.RandomGrayscale(),
        transforms.RandomEqualize(),
        RandomPixelCopy(max_pct=0.025),
        transforms.RandomPerspective(distortion_scale=0.15, p=0.5, fill=(123, 117, 104)),
        transforms.RandomRotation(degrees=90, fill=(123, 117, 104)),
        iou_crop,
    ],
)

# Compose transforms to resize and pad input images
resize_pad_tfm = transforms.Compose([
    resize_max,
    pad_square,
    transforms.Resize([train_sz] * 2, antialias=True)
])

# Compose transforms to sanitize bounding boxes and normalize input data
final_tfms = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.SanitizeBoundingBoxes(),
])

# Define the transformations for training and validation datasets
train_tfms = transforms.Compose([
    data_aug_tfms,
    resize_pad_tfm,
    final_tfms
])
valid_tfms = transforms.Compose([resize_pad_tfm, final_tfms])
Initialize Datasets
Now, we can create the dataset objects for the training and validation sets using the image dictionary, the annotation DataFrame, and the image transforms.
# Create a mapping from class names to class indices
class_to_idx = {c: i for i, c in enumerate(class_names)}

# Instantiate the dataset using the defined transformations
train_dataset = LabelMeKeypointDataset(train_keys, annotation_df, img_dict, class_to_idx, train_tfms)
valid_dataset = LabelMeKeypointDataset(val_keys, annotation_df, img_dict, class_to_idx, valid_tfms)

# Print the number of samples in the training and validation datasets
pd.Series({'Training dataset size:': len(train_dataset),
           'Validation dataset size:': len(valid_dataset)}
         ).to_frame().style.hide(axis='columns')
Training dataset size: | 180 |
---|---|
Validation dataset size: | 20 |
Inspect Samples
Let’s verify the dataset objects work correctly by inspecting the first samples from the training and validation sets.
Inspect training set sample
Since our custom dataset fills missing annotations with dummy values, we will pass the target dictionary through the SanitizeBoundingBoxes
function again.
# Get a sample image and its target annotations
dataset_sample = train_dataset[0]

# Sanitize bounding boxes to remove dummy values
targets = dataset_sample[1]
targets['boxes'] = BoundingBoxes(targets['boxes'], format='xyxy', canvas_size=dataset_sample[0].shape[1:])
sanitized_image, sanitized_targets = transforms.SanitizeBoundingBoxes()(dataset_sample[0], targets)

# Annotate the sample image with the sanitized annotations
annotated_tensor = draw_bboxes(
    image=(sanitized_image*255).to(dtype=torch.uint8),
    boxes=sanitized_targets['boxes'],
    labels=[class_names[int(i.item())] for i in sanitized_targets['labels']],
    colors=[int_colors[int(i.item())] for i in sanitized_targets['labels']]
)

tensor_to_pil(annotated_tensor)
Inspect validation set sample
dataset_sample = valid_dataset[0]

annotated_tensor = draw_bboxes(
    image=(dataset_sample[0]*255).to(dtype=torch.uint8),
    boxes=dataset_sample[1]['boxes'],
    labels=[class_names[int(i.item())] for i in dataset_sample[1]['labels']],
    colors=[int_colors[int(i.item())] for i in dataset_sample[1]['labels']]
)

tensor_to_pil(annotated_tensor)
Initialize DataLoaders
The last step before training is to instantiate the DataLoaders for the training and validation sets. Try decreasing the batch size if you encounter memory limitations.
# Set the training batch size
bs = 4

# Set the number of worker processes for loading data. This should be the number of CPUs available.
num_workers = multiprocessing.cpu_count()

# Define parameters for DataLoader
data_loader_params = {
    'batch_size': bs,  # Batch size for data loading
    'num_workers': num_workers,  # Number of subprocesses to use for data loading
    'persistent_workers': True,  # If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the worker dataset instances alive.
    'pin_memory': 'cuda' in device,  # If True, the data loader will copy Tensors into CUDA pinned memory before returning them. Useful when using a GPU.
    'pin_memory_device': device if 'cuda' in device else '',  # Specifies the device where the data should be loaded. Commonly set to use the GPU.
    'collate_fn': lambda batch: tuple(zip(*batch)),
}

# Create DataLoader for training data. Data is shuffled for every epoch.
train_dataloader = DataLoader(train_dataset, **data_loader_params, shuffle=True)

# Create DataLoader for validation data. Shuffling is not necessary for validation data.
valid_dataloader = DataLoader(valid_dataset, **data_loader_params)
# Print the number of batches in the training and validation DataLoaders
print(f'Number of batches in train DataLoader: {len(train_dataloader)}')
print(f'Number of batches in validation DataLoader: {len(valid_dataloader)}')
Number of batches in train DataLoader: 45
Number of batches in validation DataLoader: 5
Fine-tuning the Model
In this section, we will implement the training code and fine-tune our model.
Define Utility Functions
First, we need to define a couple of utility functions.
Define a function to create a bounding box that encapsulates the key points
The Keypoint R-CNN model expects a bounding box encapsulating the points associated with a given person/object. We could include these bounding box annotations in our dataset (e.g., have bounding boxes around each face). However, dynamically making one large enough to contain the key points will suffice.
def keypoints_to_bbox(keypoints, offset=10):
"""
Convert a tensor of keypoint coordinates to a bounding box.
Args:
keypoints (Tensor): A tensor of shape (N, 2), where N is the number of keypoints.
Returns:
Tensor: A tensor representing the bounding box [xmin, ymin, xmax, ymax].
"""
    x_coordinates, y_coordinates = keypoints[:, 0], keypoints[:, 1]

    xmin = torch.min(x_coordinates)
    ymin = torch.min(y_coordinates)
    xmax = torch.max(x_coordinates)
    ymax = torch.max(y_coordinates)

    bbox = torch.tensor([xmin-offset, ymin-offset, xmax+offset, ymax+offset])

    return bbox
Define a conditional autocast context manager
The autocast context manager that handles mixed-precision training on CPUs does not fully support the Keypoint R-CNN model. Therefore, we will only use mixed-precision training when not using the CPU.
@contextmanager
def conditional_autocast(device):
"""
A context manager for conditional automatic mixed precision (AMP).
This context manager applies automatic mixed precision for operations if the
specified device is not a CPU. It's a no-op (does nothing) if the device is a CPU.
Mixed precision can speed up computations and reduce memory usage on compatible
hardware, primarily GPUs.
Parameters:
device (str): The device type, e.g., 'cuda' or 'cpu', which determines whether
autocasting is applied.
Yields:
None - This function does not return any value but enables the wrapped code
block to execute under the specified precision context.
"""
# Check if the specified device is not a CPU
if 'cpu' not in device:
# If the device is not a CPU, enable autocast for the specified device type.
# Autocast will automatically choose the precision (e.g., float16) for certain
# operations to improve performance.
with autocast(device_type=device):
yield
else:
# If the device is a CPU, autocast is not applied.
# This yields control back to the with-block with no changes.
yield
Define the Training Loop
The following function performs a single pass through the training or validation set.
As mentioned earlier, the Keypoint R-CNN model expects values for key points even when not visible. We indicate which key points are visible, with a 1
for visible and a 0
for not.
The model has different behavior when in training
mode versus evaluation
mode. In training mode, it calculates the loss internally for the key point estimation task and returns a dictionary with the individual loss values. We can sum up these separate values to get the total loss.
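As a rough sketch of what a single training target looks like and what the model returns in training mode (shapes follow the torchvision detection API; the tensor values below are placeholders, not from the dataset):

import torch

# Placeholder target for one image with N = 1 object and K = 3 keypoints
example_target = {
    'boxes': torch.tensor([[50., 60., 200., 220.]]),  # FloatTensor[N, 4] in xyxy format
    'labels': torch.tensor([1], dtype=torch.int64),   # Int64Tensor[N]
    'keypoints': torch.tensor([[[100., 120., 1.],     # FloatTensor[N, K, 3]: (x, y, visibility)
                                [150., 118., 1.],
                                [125., 160., 0.]]]),  # visibility 0 marks a hidden keypoint
}

# In training mode, model(images, targets) returns a dictionary of losses, e.g.,
# 'loss_keypoint', 'loss_classifier', 'loss_box_reg', 'loss_objectness', and
# 'loss_rpn_box_reg'; summing the dictionary values gives the total loss.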
# Function to run a single training/validation epoch
def run_epoch(model, dataloader, optimizer, lr_scheduler, device, scaler, epoch_id, is_training):
"""
Function to run a single training or evaluation epoch.
Args:
model: A PyTorch model to train or evaluate.
dataloader: A PyTorch DataLoader providing the data.
optimizer: The optimizer to use for training the model.
loss_func: The loss function used for training.
device: The device (CPU or GPU) to run the model on.
scaler: Gradient scaler for mixed-precision training.
is_training: Boolean flag indicating whether the model is in training or evaluation mode.
Returns:
The average loss for the epoch.
"""
    # Set model to training mode
    model.train()

    # Initialize the average loss for the current epoch
    epoch_loss = 0
    # Initialize progress bar with total number of batches in the dataloader
    progress_bar = tqdm(total=len(dataloader), desc="Train" if is_training else "Eval")

    # Iterate over data batches
    for batch_id, (inputs, targets) in enumerate(dataloader):
        # Move inputs and targets to the specified device
        inputs = torch.stack(inputs).to(device)
        # Extract the ground truth bounding boxes and labels
        gt_bboxes, gt_labels = zip(*[(d['boxes'].to(device), d['labels'].to(device)) for d in targets])

        # Convert ground truth bounding boxes from 'xyxy' to 'cxcywh' format and only keep center coordinates
        gt_keypoints = torchvision.ops.box_convert(torch.stack(gt_bboxes), 'xyxy', 'cxcywh')[:,:,:2]

        # Initialize a visibility tensor with ones, indicating all keypoints are visible
        visibility = torch.ones(len(inputs), gt_keypoints.shape[1], 1).to(device)
        # Create a visibility mask based on whether the bounding boxes are valid (greater than or equal to 0)
        visibility_mask = (torch.stack(gt_bboxes) >= 0.)[..., 0].view(visibility.shape).to(device)

        # Concatenate the keypoints with the visibility mask, adding a visibility channel to keypoints
        gt_keypoints_with_visibility = torch.concat((
            gt_keypoints,
            visibility*visibility_mask
        ), dim=2)

        # Convert keypoints to bounding boxes for each input and move them to the specified device
        gt_object_bboxes = torch.vstack([keypoints_to_bbox(keypoints) for keypoints in gt_keypoints]).to(device)
        # Initialize ground truth labels as tensor of ones and move them to the specified device
        gt_labels = torch.ones(len(inputs), dtype=torch.int64).to(device)

        # Prepare the targets for the Keypoint R-CNN model
        # This includes bounding boxes, labels, and keypoints with visibility for each input image
        keypoint_rcnn_targets = [
            {'boxes': boxes[None], 'labels': labels[None], 'keypoints': keypoints[None]}
            for boxes, labels, keypoints in zip(gt_object_bboxes, gt_labels, gt_keypoints_with_visibility)
        ]

        # Forward pass with Automatic Mixed Precision (AMP) context manager
        with conditional_autocast(torch.device(device).type):
            if is_training:
                losses = model(inputs.to(device), move_data_to_device(keypoint_rcnn_targets, device))
            else:
                with torch.no_grad():
                    losses = model(inputs.to(device), move_data_to_device(keypoint_rcnn_targets, device))

            # Compute the loss
            loss = sum([loss for loss in losses.values()])  # Sum up the losses

        # If in training mode
        if is_training:
            if scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                old_scaler = scaler.get_scale()
                scaler.update()
                new_scaler = scaler.get_scale()
                if new_scaler >= old_scaler:
                    lr_scheduler.step()
            else:
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

            optimizer.zero_grad()

        loss_item = loss.item()
        epoch_loss += loss_item
        # Update progress bar
        progress_bar.set_postfix(loss=loss_item,
                                 avg_loss=epoch_loss/(batch_id+1),
                                 lr=lr_scheduler.get_last_lr()[0] if is_training else "")
        progress_bar.update()

        # If loss is NaN or infinity, stop training
        if is_training:
            stop_training_message = f"Loss is NaN or infinite at epoch {epoch_id}, batch {batch_id}. Stopping training."
            assert not math.isnan(loss_item) and math.isfinite(loss_item), stop_training_message

    progress_bar.close()
    return epoch_loss / (batch_id + 1)
Next, we define the train_loop
function, which executes the main training loop. It iterates over each epoch, runs through the training and validation sets, and saves the best model based on the validation loss.
def train_loop(model,
               train_dataloader,
               valid_dataloader,
               optimizer,
               lr_scheduler,
               device,
               epochs,
               checkpoint_path,
               use_scaler=False):
    """
    Main training loop.

    Args:
        model: A PyTorch model to train.
        train_dataloader: A PyTorch DataLoader providing the training data.
        valid_dataloader: A PyTorch DataLoader providing the validation data.
        optimizer: The optimizer to use for training the model.
        lr_scheduler: The learning rate scheduler.
        device: The device (CPU or GPU) to run the model on.
        epochs: The number of epochs to train for.
        checkpoint_path: The path where to save the best model checkpoint.
        use_scaler: Whether to scale gradients when using a CUDA device

    Returns:
        None
    """
    # Initialize a gradient scaler for mixed-precision training if the device is a CUDA GPU
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' and use_scaler else None
    best_loss = float('inf')  # Initialize the best validation loss

    # Loop over the epochs
    for epoch in tqdm(range(epochs), desc="Epochs"):
        # Run a training epoch and get the training loss
        train_loss = run_epoch(model, train_dataloader, optimizer, lr_scheduler, device, scaler, epoch, is_training=True)
        # Run an evaluation epoch and get the validation loss
        with torch.no_grad():
            valid_loss = run_epoch(model, valid_dataloader, None, None, device, scaler, epoch, is_training=False)

        # If the validation loss is lower than the best validation loss seen so far, save the model checkpoint
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

            # Save metadata about the training process
            training_metadata = {
                'epoch': epoch,
                'train_loss': train_loss,
                'valid_loss': valid_loss,
                'learning_rate': lr_scheduler.get_last_lr()[0],
                'model_architecture': model.name
            }
            with open(Path(checkpoint_path.parent/'training_metadata.json'), 'w') as f:
                json.dump(training_metadata, f)

    # If the device is a GPU, empty the cache
    if device.type != 'cpu':
        getattr(torch, device.type).empty_cache()
Set the Model Checkpoint Path
Before we proceed with training, let’s generate a timestamp for the training session and create a directory to save the checkpoints during training.
# Generate timestamp for the training session (Year-Month-Day_Hour_Minute_Second)
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Create a directory to store the checkpoints if it does not already exist
checkpoint_dir = Path(project_dir/f"{timestamp}")

# Create the checkpoint directory if it does not already exist
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# The model checkpoint path
checkpoint_path = checkpoint_dir/f"{model.name}.pth"

print(checkpoint_path)
pytorch-keypoint-r-cnn/2024-01-28_17-07-09/keypointrcnn_resnet50_fpn.pth
Let’s also save a copy of the colormap for the current dataset in the training folder for future use.
Save the Color Map
# Create a color map and write it to a JSON file
color_map = {'items': [{'label': label, 'color': color} for label, color in zip(class_names, colors)]}
with open(f"{checkpoint_dir}/{dataset_path.name}-colormap.json", "w") as file:
    json.dump(color_map, file)

# Print the name of the file that the color map was written to
print(f"{checkpoint_dir}/{dataset_path.name}-colormap.json")
pytorch-keypoint-r-cnn/2024-01-28_17-07-09/labelme-keypoint-eyes-noses-dataset-colormap.json
Configure the Training Parameters
Now, we can configure the parameters for training. We must specify the learning rate and number of training epochs. We will also instantiate the optimizer and learning rate scheduler.
# Learning rate for the model
lr = 5e-4

# Number of training epochs
epochs = 70

# AdamW optimizer; includes weight decay for regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Learning rate scheduler; adjusts the learning rate during training
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                   max_lr=lr,
                                                   total_steps=epochs*len(train_dataloader))
Train the Model
Finally, we can train the model using the train_loop
function. Training time will depend on the available hardware.
Training usually takes around 30 minutes on the free GPU tier of Google Colab.
train_loop(model=model,
           train_dataloader=train_dataloader,
           valid_dataloader=valid_dataloader,
           optimizer=optimizer,
           lr_scheduler=lr_scheduler,
           device=torch.device(device),
           epochs=epochs,
           checkpoint_path=checkpoint_path,
           use_scaler=True)
Epochs: 100% |██████████| 70/70 [07:29<00:00, 6.55s/it]
Train: 100% |██████████| 45/45 [00:07<00:00, 8.58it/s, avg_loss=6.95, loss=6.07, lr=2.27e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 8.31it/s, avg_loss=5.17, loss=5.31, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.50it/s, avg_loss=5.42, loss=4.87, lr=3.07e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.40it/s, avg_loss=4.3, loss=4.14, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 7.09it/s, avg_loss=4.85, loss=4.88, lr=4.38e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.72it/s, avg_loss=4.54, loss=4.73, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.12it/s, avg_loss=4.55, loss=4.27, lr=6.18e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 13.73it/s, avg_loss=4.16, loss=3.78, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.77it/s, avg_loss=4.37, loss=4.64, lr=8.42e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.63it/s, avg_loss=3.79, loss=3.36, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.46it/s, avg_loss=4.53, loss=6.24, lr=0.000111]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.26it/s, avg_loss=3.81, loss=3.25, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.30it/s, avg_loss=4.39, loss=4.33, lr=0.00014]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.50it/s, avg_loss=3.93, loss=3.63, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.64it/s, avg_loss=4.2, loss=4.98, lr=0.000173]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.93it/s, avg_loss=3.85, loss=3.1, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.43it/s, avg_loss=4.37, loss=4.64, lr=0.000207]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.57it/s, avg_loss=4.49, loss=4.54, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.54it/s, avg_loss=4.26, loss=3.53, lr=0.000242]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.36it/s, avg_loss=4.11, loss=4.03, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.48it/s, avg_loss=4.38, loss=4.53, lr=0.000278]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.60it/s, avg_loss=4.34, loss=3.82, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.51it/s, avg_loss=4.58, loss=4.45, lr=0.000314]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.03it/s, avg_loss=4.42, loss=4.41, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.58it/s, avg_loss=4.47, loss=3.38, lr=0.000348]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.05it/s, avg_loss=4.24, loss=3.27, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.53it/s, avg_loss=4.44, loss=5.01, lr=0.00038]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.75it/s, avg_loss=4.22, loss=4.14, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.58it/s, avg_loss=4.54, loss=4.36, lr=0.00041]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.24it/s, avg_loss=4.02, loss=3.7, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.83it/s, avg_loss=4.55, loss=3.89, lr=0.000436]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.35it/s, avg_loss=4.04, loss=3.33, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.58it/s, avg_loss=4.57, loss=4.49, lr=0.000459]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.79it/s, avg_loss=4.68, loss=4.85, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.84it/s, avg_loss=4.57, loss=4.47, lr=0.000477]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.73it/s, avg_loss=3.98, loss=3.36, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.83it/s, avg_loss=4.4, loss=4.59, lr=0.00049]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.27it/s, avg_loss=4.11, loss=3.59, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.76it/s, avg_loss=4.59, loss=4.98, lr=0.000497]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.49it/s, avg_loss=3.98, loss=3.41, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.85it/s, avg_loss=4.35, loss=4.5, lr=0.0005]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.15it/s, avg_loss=4, loss=3.34, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.05it/s, avg_loss=4.6, loss=5.02, lr=0.000499]
Eval: 100% |██████████| 5/5 [00:00<00:00, 13.54it/s, avg_loss=4.14, loss=3.99, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.92it/s, avg_loss=4.5, loss=3.75, lr=0.000498]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.00it/s, avg_loss=4.38, loss=4.55, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.39it/s, avg_loss=4.25, loss=3.95, lr=0.000495]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.46it/s, avg_loss=3.72, loss=3.16, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.55it/s, avg_loss=4.26, loss=5.19, lr=0.000492]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.02it/s, avg_loss=4.54, loss=4.14, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.46it/s, avg_loss=4.15, loss=3.68, lr=0.000487]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.73it/s, avg_loss=3.94, loss=3.61, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.77it/s, avg_loss=4.3, loss=3.22, lr=0.000482]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.81it/s, avg_loss=3.71, loss=3.57, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.42it/s, avg_loss=4.08, loss=3.55, lr=0.000475]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.04it/s, avg_loss=3.88, loss=3.6, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.42it/s, avg_loss=4.18, loss=3.19, lr=0.000468]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.78it/s, avg_loss=3.84, loss=3.7, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.52it/s, avg_loss=4.09, loss=3.7, lr=0.000459]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.19it/s, avg_loss=3.91, loss=3.65, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.16it/s, avg_loss=3.93, loss=4.28, lr=0.00045]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.02it/s, avg_loss=3.8, loss=3.52, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.15it/s, avg_loss=4.04, loss=3.38, lr=0.00044]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.09it/s, avg_loss=3.88, loss=4.04, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.36it/s, avg_loss=4.1, loss=3.53, lr=0.000429]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.44it/s, avg_loss=3.7, loss=2.95, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.62it/s, avg_loss=4.05, loss=4.06, lr=0.000418]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.18it/s, avg_loss=3.78, loss=3.28, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.20it/s, avg_loss=3.95, loss=3.53, lr=0.000406]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.48it/s, avg_loss=3.44, loss=3.38, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.27it/s, avg_loss=3.86, loss=2.82, lr=0.000393]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.31it/s, avg_loss=3.63, loss=3, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.35it/s, avg_loss=3.97, loss=3.48, lr=0.000379]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.21it/s, avg_loss=3.62, loss=3.22, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.46it/s, avg_loss=3.72, loss=3.94, lr=0.000365]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.22it/s, avg_loss=3.45, loss=2.83, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.30it/s, avg_loss=3.75, loss=3.34, lr=0.000351]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.41it/s, avg_loss=3.52, loss=3.38, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.63it/s, avg_loss=3.7, loss=4.19, lr=0.000336]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.00it/s, avg_loss=3.56, loss=2.9, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.47it/s, avg_loss=3.65, loss=4.22, lr=0.000321]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.94it/s, avg_loss=3.67, loss=3.11, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.50it/s, avg_loss=3.58, loss=4.13, lr=0.000305]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.86it/s, avg_loss=3.55, loss=2.98, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.55it/s, avg_loss=3.54, loss=3.29, lr=0.00029]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.17it/s, avg_loss=3.42, loss=2.62, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.25it/s, avg_loss=3.51, loss=3.97, lr=0.000274]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.54it/s, avg_loss=3.33, loss=2.68, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.39it/s, avg_loss=3.5, loss=2.83, lr=0.000258]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.36it/s, avg_loss=3.27, loss=2.94, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.22it/s, avg_loss=3.45, loss=4.09, lr=0.000242]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.64it/s, avg_loss=3.63, loss=3.29, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.37it/s, avg_loss=3.44, loss=2.97, lr=0.000226]
Eval: 100% |██████████| 5/5 [00:00<00:00, 13.97it/s, avg_loss=3.44, loss=2.87, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.26it/s, avg_loss=3.35, loss=2.87, lr=0.00021]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.52it/s, avg_loss=3.35, loss=2.94, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.29it/s, avg_loss=3.32, loss=3.1, lr=0.000194]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.92it/s, avg_loss=3.58, loss=3.28, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.67it/s, avg_loss=3.21, loss=3.25, lr=0.000179]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.04it/s, avg_loss=3.36, loss=2.86, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.12it/s, avg_loss=3.29, loss=2.95, lr=0.000163]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.03it/s, avg_loss=3.36, loss=2.87, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.07it/s, avg_loss=3.21, loss=3.99, lr=0.000148]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.45it/s, avg_loss=3.32, loss=2.96, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.23it/s, avg_loss=3.21, loss=2.92, lr=0.000134]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.38it/s, avg_loss=3.15, loss=2.81, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.05it/s, avg_loss=3.13, loss=2.58, lr=0.00012]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.48it/s, avg_loss=3.39, loss=2.86, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.27it/s, avg_loss=3.07, loss=2.13, lr=0.000107]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.16it/s, avg_loss=3.15, loss=2.68, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 7.66it/s, avg_loss=3.12, loss=3.1, lr=9.39e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.06it/s, avg_loss=3.27, loss=2.85, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 7.56it/s, avg_loss=3.02, loss=3.05, lr=8.17e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.53it/s, avg_loss=3.24, loss=2.74, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.42it/s, avg_loss=2.99, loss=2.36, lr=7.02e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.08it/s, avg_loss=3.1, loss=2.56, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.24it/s, avg_loss=2.93, loss=2.53, lr=5.94e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.33it/s, avg_loss=3.21, loss=2.85, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.11it/s, avg_loss=2.98, loss=2.77, lr=4.94e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.40it/s, avg_loss=3.31, loss=2.95, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.16it/s, avg_loss=3.04, loss=3.37, lr=4.03e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.74it/s, avg_loss=3.15, loss=2.93, lr=]
Train: 100% |██████████| 45/45 [00:05<00:00, 8.33it/s, avg_loss=3, loss=3.06, lr=3.2e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.16it/s, avg_loss=3.1, loss=2.8, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.36it/s, avg_loss=2.92, loss=2.94, lr=2.46e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.99it/s, avg_loss=3.23, loss=2.85, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.45it/s, avg_loss=2.86, loss=2.2, lr=1.81e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 15.14it/s, avg_loss=3.06, loss=2.78, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.53it/s, avg_loss=2.94, loss=2.69, lr=1.26e-5]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.66it/s, avg_loss=3.07, loss=2.53, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.29it/s, avg_loss=2.86, loss=2.94, lr=8.09e-6]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.19it/s, avg_loss=3.04, loss=2.48, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.56it/s, avg_loss=2.79, loss=2.45, lr=4.54e-6]
Eval: 100% |██████████| 5/5 [00:00<00:00, 13.92it/s, avg_loss=3.15, loss=2.65, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 7.95it/s, avg_loss=2.87, loss=2.57, lr=2.01e-6]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.27it/s, avg_loss=3.02, loss=2.29, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.31it/s, avg_loss=2.93, loss=2.63, lr=4.93e-7]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.89it/s, avg_loss=2.96, loss=2.65, lr=]
Train: 100% |██████████| 45/45 [00:06<00:00, 8.32it/s, avg_loss=2.87, loss=2.75, lr=2.25e-9]
Eval: 100% |██████████| 5/5 [00:00<00:00, 14.32it/s, avg_loss=3.07, loss=2.65, lr=]
At last, we have our fine-tuned Keypoint R-CNN model. To wrap up the tutorial, we can test our model by performing inference on individual images.
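If you want to keep the fine-tuned weights around for later sessions, you can save the model's state dictionary before moving on. This is a minimal sketch rather than part of the original training code, and the file name is only a placeholder:

# Save the fine-tuned weights (the file name here is only a placeholder)
torch.save(model.state_dict(), "keypoint-rcnn-finetuned.pth")

# Later, recreate the model and reload the saved weights
# model.load_state_dict(torch.load("keypoint-rcnn-finetuned.pth"))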
Making Predictions with the Model
In this final part of the tutorial, we will cover how to perform inference on individual images with our Keypoint R-CNN model and filter the predictions.
Prepare Input Data
Let’s use an image from the validation set. That way, we have some ground truth annotation data to compare against. Unlike during training, we won’t stick to square input dimensions for inference.
# Choose an item from the validation set
file_id = val_keys[0]

# Retrieve the image file path associated with the file ID
test_file = img_dict[file_id]

# Open the test file
test_img = Image.open(test_file).convert('RGB')

# Resize the test image based on the target size used during training
input_img = resize_img(test_img, target_sz=train_sz, divisor=1)

# Calculate the scale between the source image and the resized image
min_img_scale = min(test_img.size) / min(input_img.size)

display(test_img)

# Print the image dimensions as a Pandas DataFrame for easy formatting
pd.Series({"Source Image Size:": test_img.size,
           "Input Dims:": input_img.size,
           "Min Image Scale:": min_img_scale,
           "Input Image Size:": input_img.size
          }).to_frame().style.hide(axis='columns')
Source Image Size: | (512, 768) |
---|---|
Input Dims: | (512, 768) |
Min Image Scale: | 1.000000 |
Input Image Size: | (512, 768) |
Get Target Annotation Data
# Extract the source annotations for the test image
gt_labels = [shape['label'] for shape in annotation_df.loc[file_id]['shapes']]
gt_keypoints = torch.tensor(np.array([shape['points'] for shape in annotation_df.loc[file_id]['shapes']])).reshape(-1,2)
gt_keypoints_bboxes = torch.cat((gt_keypoints, torch.ones(len(gt_keypoints), 2)*BBOX_DIM), dim=1)
Pass Input Data to the Model
Now, we can convert the test image to a tensor and pass it to the model. Ensure the model is set to evaluation mode to get predictions instead of loss values.
# Set the model to evaluation mode
model.eval();

# Ensure the model and input data are on the same device
model.to(device);
input_tensor = transforms.Compose([transforms.ToImage(),
                                   transforms.ToDtype(torch.float32, scale=True)])(input_img)[None].to(device)

# Make a prediction with the model
with torch.no_grad():
    model_output = model(input_tensor)[0]
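If you want to see what the model returned before filtering, a quick and purely illustrative way is to list the keys and tensor shapes in model_output. For a torchvision Keypoint R-CNN in evaluation mode, you should see entries such as boxes, labels, scores, keypoints, and keypoints_scores:

# Inspect the keys and tensor shapes in the raw model output (illustrative check)
{key: tuple(value.shape) for key, value in model_output.items()}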
Filter the Model Output
The model performs most post-processing steps internally, so we only need to filter the output based on the desired confidence threshold. The model returns predictions as a list of dictionaries. Each dictionary stores bounding boxes, label indices, confidence scores, and key points for a single sample in the input batch.
Since we resized the test image, we must scale the key points to the source resolution.
# Set the confidence threshold
conf_threshold = 0.8

# Filter the output based on the confidence threshold
scores_mask = model_output['scores'] > conf_threshold

# Extract and scale the predicted keypoints
predicted_keypoints = (model_output['keypoints'][scores_mask])[:,:,:-1].reshape(-1,2)*min_img_scale
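As another optional check, not part of the original code, you can verify how many detections survive the threshold and inspect their confidence scores:

# Count the detections that pass the confidence threshold (illustrative check)
num_detections = scores_mask.sum().item()
print(f"Detections above {conf_threshold}: {num_detections}")
print(f"Scores: {model_output['scores'][scores_mask].cpu().numpy()}")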
Compare Model Predictions with the Source Annotations
Finally, we can compare the model predictions with the ground-truth annotations.
# Annotate the test image with the ground-truth annotations
gt_annotated_tensor = draw_bboxes(
    image=transforms.PILToTensor()(test_img),
    boxes=torchvision.ops.box_convert(torch.Tensor(gt_keypoints_bboxes), 'cxcywh', 'xyxy'),
    # labels=gt_labels,
    colors=[int_colors[i] for i in [class_names.index(label) for label in gt_labels]]
)

# Prepare the labels and bounding box annotations for the test image
labels = class_names*sum(scores_mask).item()
keypoints_bboxes = torch.cat((predicted_keypoints.cpu(), torch.ones(len(predicted_keypoints), 2)), dim=1)

# Annotate the test image with the model predictions
annotated_tensor = draw_bboxes(
    image=transforms.PILToTensor()(test_img),
    boxes=torchvision.ops.box_convert(torch.Tensor(keypoints_bboxes), 'cxcywh', 'xyxy'),
    # labels=labels,
    colors=[int_colors[i] for i in [class_names.index(label) for label in labels]]
)

stack_imgs([tensor_to_pil(gt_annotated_tensor), tensor_to_pil(annotated_tensor)])
The model appears to have learned to detect eyes and noses as desired.
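If you would like a rough quantitative check to accompany the visual comparison, one option (again an addition, not part of the original tutorial) is to measure the pixel distance from each ground-truth keypoint to the nearest predicted keypoint:

# Measure the distance from each ground-truth keypoint to the nearest prediction (illustrative check)
dists = torch.cdist(gt_keypoints.float(), predicted_keypoints.cpu().float())
nearest_dists = dists.min(dim=1).values
print(f"Mean distance to nearest prediction: {nearest_dists.mean().item():.2f} px")
print(f"Max distance to nearest prediction:  {nearest_dists.max().item():.2f} px")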
Conclusion
Congratulations on completing this tutorial for training Keypoint R-CNN models in PyTorch! The skills and knowledge you acquired here provide a solid foundation for future projects.
As a next step, perhaps try annotating a keypoint dataset with LabelMe for your own Keypoint R-CNN model or experiment with the data augmentations to see how they impact model accuracy.
Recommended Tutorials
- Exporting Keypoint R-CNN Models from PyTorch to ONNX: Learn how to export Keypoint R-CNN models from PyTorch to ONNX and perform inference using ONNX Runtime.
- Training Mask R-CNN Models with PyTorch: Learn how to train Mask R-CNN models on custom datasets with PyTorch.
Feel free to post questions or problems related to this tutorial in the comments below. I try to make time to address them on Thursdays and Fridays.