Training Mask R-CNN Models with PyTorch
- Introduction
- Getting Started with the Code
- Setting Up Your Python Environment
- Importing the Required Dependencies
- Setting Up the Project
- Loading and Exploring the Dataset
- Loading the Mask R-CNN Model
- Preparing the Data
- Fine-tuning the Model
- Making Predictions with the Model
- Conclusion
Introduction
Welcome to this hands-on guide to training Mask R-CNN models in PyTorch! Mask R-CNN models can identify and locate multiple objects within images and generate segmentation masks for each detected object.
For this tutorial, we will fine-tune a Mask R-CNN model from the torchvision
library on a small sample dataset of annotated student ID card images.
This tutorial is suitable for anyone with rudimentary PyTorch experience. If you are new to PyTorch and want to start with a beginner-focused project, check out my tutorial on fine-tuning image classifiers.
I updated the tutorial code for torchvision 0.16.0
.
Getting Started with the Code
The tutorial code is available as a Jupyter Notebook, which you can run locally or in a cloud-based environment like Google Colab. I have dedicated tutorials for those new to these platforms or who need guidance setting up:
Platform | Jupyter Notebook | Utility File |
---|---|---|
Google Colab | Open In Colab | |
Linux | GitHub Repository | |
Windows | GitHub Repository | windows_utils.py |
The code in this tutorial targets Linux platforms, but most of it should also work on macOS and Windows.
However, Python multiprocessing works differently on those platforms, requiring some changes to leverage multi-processing for the DataLoader
objects.
I’ve made a dedicated version of the tutorial code to run on Windows. The included changes should also work on macOS, but I don’t have a Mac to verify.
Setting Up Your Python Environment
Before diving into the code, we’ll cover the steps to create a local Python environment and install the necessary dependencies. The dedicated Colab Notebook includes the code to install the required dependencies in Google Colab.
Creating a Python Environment
First, we’ll create a Python environment using Conda/Mamba. Open a terminal with Conda/Mamba installed and run the following commands:
# Create a new Python 3.10 environment
conda create --name pytorch-env python=3.10 -y
# Activate the environment
conda activate pytorch-env
# Create a new Python 3.10 environment
mamba create --name pytorch-env python=3.10 -y
# Activate the environment
mamba activate pytorch-env
Installing PyTorch
Next, we’ll install PyTorch. Run the appropriate command for your hardware and operating system.
# Install PyTorch with CUDA
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# MPS (Metal Performance Shaders) acceleration is available on MacOS 12.3+
pip install torch torchvision torchaudio
# Install PyTorch for CPU only
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# Install PyTorch for CPU only
pip install torch torchvision torchaudio
Installing Additional Libraries
We also need to install some additional libraries for our project.
Package | Description |
---|---|
jupyter |
An open-source web application that allows you to create and share documents that contain live code, equations, visualizations, and narrative text. (link) |
matplotlib |
This package provides a comprehensive collection of visualization tools to create high-quality plots, charts, and graphs for data exploration and presentation. (link) |
pandas |
This package provides fast, powerful, and flexible data analysis and manipulation tools. (link) |
pillow |
The Python Imaging Library adds image processing capabilities. (link) |
torchtnt |
A library for PyTorch training tools and utilities. (link) |
tqdm |
A Python library that provides fast, extensible progress bars for loops and other iterable objects in Python. (link) |
tabulate |
Pretty-print tabular data in Python. (link) |
distinctipy |
A lightweight python package providing functions to generate colours that are visually distinct from one another. (link) |
Run the following commands to install these additional libraries:
# Install additional dependencies
pip install distinctipy jupyter matplotlib pandas pillow torchtnt==0.2.0 tqdm tabulate
Installing Utility Packages
We’ll also install some utility packages I made to help us handle images, interact with PyTorch, and work with Pandas DataFrames. These utility packages provide shortcuts for routine tasks and keep our code clean and readable.
Package | Description |
---|---|
cjm_pandas_utils |
Some utility functions for working with Pandas. (link) |
cjm_pil_utils |
Some PIL utility functions I frequently use. (link) |
cjm_psl_utils |
Some utility functions using the Python Standard Library. (link) |
cjm_pytorch_utils |
Some utility functions for working with PyTorch. (link) |
cjm_torchvision_tfms |
Some custom Torchvision tranforms. (link) |
Run the following commands to install the utility packages:
# Install additional utility packages
pip install cjm_pandas_utils cjm_pil_utils cjm_psl_utils cjm_pytorch_utils cjm_torchvision_tfms
Importing the Required Dependencies
With our environment set up, let’s dive into the code. First, we will import the necessary Python packages into our Jupyter Notebook.
# Import Python Standard Library dependencies
import datetime
from functools import partial
from glob import glob
import json
import math
import multiprocessing
import os
from pathlib import Path
import random
from typing import Any, Dict, Optional
# Import utility functions
from cjm_psl_utils.core import download_file, file_extract, get_source_code
from cjm_pil_utils.core import resize_img, get_img_files, stack_imgs
from cjm_pytorch_utils.core import pil_to_tensor, tensor_to_pil, get_torch_device, set_seed, denorm_img_tensor, move_data_to_device
from cjm_pandas_utils.core import markdown_to_pandas, convert_to_numeric, convert_to_string
from cjm_torchvision_tfms.core import ResizeMax, PadSquare, CustomRandomIoUCrop
# Import the distinctipy module
from distinctipy import distinctipy
# Import matplotlib for creating plots
import matplotlib.pyplot as plt
# Import numpy
import numpy as np
# Import the pandas package
import pandas as pd
# Set options for Pandas DataFrame display
'max_colwidth', None) # Do not truncate the contents of cells in the DataFrame
pd.set_option('display.max_rows', None) # Display all rows in the DataFrame
pd.set_option('display.max_columns', None) # Display all columns in the DataFrame
pd.set_option(
# Import PIL for image manipulation
from PIL import Image, ImageDraw
# Import PyTorch dependencies
import torch
from torch.amp import autocast
from torch.cuda.amp import GradScaler
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchtnt.utils import get_module_summary
import torchvision
torchvision.disable_beta_transforms_warning()from torchvision.tv_tensors import BoundingBoxes, Mask
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
import torchvision.transforms.v2 as transforms
from torchvision.transforms.v2 import functional as TF
# Import Mask R-CNN
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
# Import tqdm for progress bar
from tqdm.auto import tqdm
Setting Up the Project
In this section, we set up some basics for our project, such as initializing random number generators, setting the PyTorch device to run the model, and preparing the folders for our project and datasets.
Setting a Random Number Seed
First, we set the seed for generating random numbers using the set_seed function from the cjm_pytorch_utils
package.
# Set the seed for generating random numbers in PyTorch, NumPy, and Python's random module.
= 1234
seed set_seed(seed)
Setting the Device and Data Type
Next, we determine the device to use for training using the get_torch_device function from the cjm_pytorch_utils
package and set the data type of our tensors.
= get_torch_device()
device = torch.float32
dtype device, dtype
('cuda', torch.float32)
Setting the Directory Paths
We can then set up a directory for our project to store our results and other related files. The following code creates the folder in the current directory (./
). Update the path if that is not suitable for you.
We also need a place to store our dataset. Readers following the tutorial on their local machine should select a location with read-and-write access to store datasets. For a cloud service like Google Colab, you can set it to the current directory.
# The name for the project
= f"pytorch-mask-r-cnn-instance-segmentation"
project_name
# The path for the project folder
= Path(f"./{project_name}/")
project_dir
# Create the project directory if it does not already exist
=True, exist_ok=True)
project_dir.mkdir(parents
# Define path to store datasets
= Path("./Datasets/")
dataset_dir # Create the dataset directory if it does not exist
=True, exist_ok=True)
dataset_dir.mkdir(parents
pd.Series({"Project Directory:": project_dir,
"Dataset Directory:": dataset_dir
='columns') }).to_frame().style.hide(axis
Project Directory: | pytorch-mask-r-cnn-instance-segmentation |
---|---|
Dataset Directory: | Datasets |
Double-check the project and dataset directories exist in the specified paths and that you can add files to them before continuing. At this point, our project is set up and ready to go. In the next section, we will download and explore the dataset.
Loading and Exploring the Dataset
Now that we set up the project, we can start working with our dataset. The dataset is originally from the following GitHub repository:
I made a fork of the original repository with only the files needed for this tutorial, which takes up approximately 77 MB.
The segmentation masks for this dataset uses the LabelMe annotation format. You can learn more about this format and how to work with such annotations in the tutorial linked below:
Setting the Dataset Path
We first need to construct the name for the GitHub repository and define the path to the subfolder with the dataset.
# Set the name of the dataset
= 'pytorch-for-information-extraction'
dataset_name
# Construct the GitHub repository name
= f'cj-mills/{dataset_name}'
gh_repo
# Create the path to the directory where the dataset will be extracted
= Path(f'{dataset_dir}/{dataset_name}/code/datasets/detection/student-id/')
dataset_path
pd.Series({"GitHub Repository:": gh_repo,
"Dataset Path:": dataset_path
='columns') }).to_frame().style.hide(axis
GitHub Repository: | cj-mills/pytorch-for-information-extraction |
---|---|
Dataset Path: | Datasets/pytorch-for-information-extraction/code/datasets/detection/student-id |
Downloading the Dataset
We can now clone the repository to the dataset directory we defined earlier.
# Clone the dataset repository from GitHub
!git clone {f'https://github.com/{gh_repo}.git'} {dataset_dir/dataset_name}
Getting the Image and Annotation Files
The dataset folder contains sample images and annotation files. Each sample image has its own JSON annotation file.
# Get a list of image files in the dataset
= get_img_files(dataset_path)
img_file_paths
# Get a list of JSON files in the dataset
= list(dataset_path.glob('*.json'))
annotation_file_paths
# Display the names of the folders using a Pandas DataFrame
"Image File": [file.name for file in img_file_paths],
pd.DataFrame({"Annotation File":[file.name for file in annotation_file_paths]}).head()
Image File | Annotation File | |
---|---|---|
0 | 10134.jpg | 10134.json |
1 | 10135.jpg | 10135.json |
2 | 10136.jpg | 10136.json |
3 | 10137.jpg | 10137.json |
4 | 10138.jpg | 10138.json |
Get Image File Paths
Each image file has a unique name that we can use to locate the corresponding annotation data. Let’s make a dictionary that maps image names to file paths. The dictionary will allow us to retrieve the file path for a given image more efficiently.
# Create a dictionary that maps file names to file paths
= {file.stem : file for file in img_file_paths}
img_dict
# Print the number of image files
print(f"Number of Images: {len(img_dict)}")
# Display the first five entries from the dictionary using a Pandas DataFrame
='index').head() pd.DataFrame.from_dict(img_dict, orient
Number of Images: 150
0 | |
---|---|
10134 | Datasets/pytorch-for-information-extraction/code/datasets/detection/student-id/10134.jpg |
10135 | Datasets/pytorch-for-information-extraction/code/datasets/detection/student-id/10135.jpg |
10136 | Datasets/pytorch-for-information-extraction/code/datasets/detection/student-id/10136.jpg |
10137 | Datasets/pytorch-for-information-extraction/code/datasets/detection/student-id/10137.jpg |
10138 | Datasets/pytorch-for-information-extraction/code/datasets/detection/student-id/10138.jpg |
Get Image Annotations
Next, we read the contents of the JSON annotation files into a Pandas DataFrame so we can easily query the annotations.
# Create a generator that yields Pandas DataFrames containing the data from each JSON file
= (pd.read_json(f, orient='index').transpose() for f in tqdm(annotation_file_paths))
cls_dataframes
# Concatenate the DataFrames into a single DataFrame
= pd.concat(cls_dataframes, ignore_index=False)
annotation_df
# Assign the image file name as the index for each row
'index'] = annotation_df.apply(lambda row: row['imagePath'].split('.')[0], axis=1)
annotation_df[= annotation_df.set_index('index')
annotation_df
# Keep only the rows that correspond to the filenames in the 'img_dict' dictionary
= annotation_df.loc[list(img_dict.keys())]
annotation_df
# Print the first 5 rows of the DataFrame
annotation_df.head()
version | flags | shapes | lineColor | fillColor | imagePath | imageData | imageHeight | imageWidth | |
---|---|---|---|---|---|---|---|---|---|
index | |||||||||
10134 | 3.21.1 | {} | [{‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[83.7142857142857, 133.57142857142856], [86.57142857142856, 123.57142857142856], [95.14285714285714, 117.14285714285714], [595.1428571428571, 125.71428571428571], [604.4285714285713, 127.85714285714285], [607.2857142857142, 138.57142857142856], [619.4285714285713, 443.57142857142856], [612.2857142857142, 449.2857142857142], [97.99999999999997, 469.2857142857142], [85.14285714285714, 465.71428571428567], [78.0, 457.1428571428571]], ‘shape_type’: ‘polygon’, ‘flags’: {}}] | [0, 255, 0, 128] | [255, 0, 0, 128] | 10134.jpg | 480 | 640 | |
10135 | 3.21.1 | {} | [{‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[125.85714285714283, 288.57142857142856], [391.57142857142856, 24.285714285714285], [459.4285714285714, 7.857142857142857], [612.2857142857142, 166.42857142857142], [612.2857142857142, 174.28571428571428], [334.4285714285714, 477.85714285714283], [321.57142857142856, 478.5714285714285], [127.99999999999997, 297.1428571428571]], ‘shape_type’: ‘polygon’, ‘flags’: {}}] | [0, 255, 0, 128] | [255, 0, 0, 128] | 10135.jpg | 480 | 640 | |
10136 | 3.21.1 | {} | [{‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[62.28571428571428, 44.285714285714285], [70.85714285714285, 39.99999999999999], [571.5714285714286, 81.42857142857142], [582.9999999999999, 90.71428571428571], [634.4285714285713, 374.99999999999994], [634.4285714285713, 389.2857142857142], [622.9999999999999, 394.2857142857142], [46.571428571428555, 427.1428571428571], [35.85714285714285, 424.99999999999994], [30.857142857142847, 414.99999999999994]], ‘shape_type’: ‘polygon’, ‘flags’: {}}] | [0, 255, 0, 128] | [255, 0, 0, 128] | 10136.jpg | 480 | 640 | |
10137 | 3.21.1 | {} | [{‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[81.57142857142856, 137.85714285714283], [84.42857142857142, 129.28571428571428], [273.71428571428567, 29.999999999999996], [284.4285714285714, 29.999999999999996], [549.4285714285713, 277.85714285714283], [550.8571428571428, 288.57142857142856], [362.2857142857142, 472.85714285714283], [354.4285714285714, 472.85714285714283], [345.1428571428571, 467.1428571428571]], ‘shape_type’: ‘polygon’, ‘flags’: {}}, {‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[324.4285714285714, 69.28571428571428], [340.1428571428571, 0.7142857142857141], [525.8571428571428, 27.857142857142854], [529.4285714285713, 177.14285714285714], [395.1428571428571, 135.0]], ‘shape_type’: ‘polygon’, ‘flags’: {}}] | [0, 255, 0, 128] | [255, 0, 0, 128] | 10137.jpg | 480 | 640 | |
10138 | 3.21.1 | {} | [{‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[202.28571428571425, 12.142857142857142], [208.71428571428567, 2.857142857142856], [434.4285714285714, 70.0], [411.57142857142856, 392.1428571428571], [407.2857142857142, 445.71428571428567], [402.99999999999994, 453.57142857142856], [392.2857142857142, 454.99999999999994], [165.85714285714283, 470.71428571428567], [155.85714285714283, 467.85714285714283], [152.28571428571428, 459.2857142857142]], ‘shape_type’: ‘polygon’, ‘flags’: {}}] | [0, 255, 0, 128] | [255, 0, 0, 128] | 10138.jpg | 480 | 640 |
The shapes
column contains the point coordinates to draw the segmentation masks. We will also use this information to generate the associated bounding box annotations.
Inspecting the Class Distribution
Now that we have the annotation data, we can extract the unique class names and inspect the class distribution. This small sample dataset only has one object class, but reviewing the class distribution is still good practice for other datasets.
Get image classes
# Explode the 'shapes' column in the annotation_df dataframe
# Convert the resulting series to a dataframe and rename the 'shapes' column to 'shapes'
# Apply the pandas Series function to the 'shapes' column of the dataframe
= annotation_df['shapes'].explode().to_frame().shapes.apply(pd.Series) shapes_df
# Get a list of unique labels in the 'annotation_df' DataFrame
= shapes_df['label'].unique().tolist()
class_names
# Display labels using a Pandas DataFrame
pd.DataFrame(class_names)
0 | |
---|---|
0 | student_id |
Visualize the class distribution
# Get the number of samples for each object class
= shapes_df['label'].value_counts()
class_counts
# Plot the distribution
='bar')
class_counts.plot(kind'Class distribution')
plt.title('Count')
plt.ylabel('Classes')
plt.xlabel(range(len(class_counts.index)), class_names, rotation=75) # Set the x-axis tick labels
plt.xticks( plt.show()
Add a background class
The Mask R-CNN model provided with the torchvision library expects datasets to have a background
class. We can prepend one to the list of class names.
# Prepend a `background` class to the list of class names
= ['background']+class_names
class_names
# Display labels using a Pandas DataFrame
pd.DataFrame(class_names)
0 | |
---|---|
0 | background |
1 | student_id |
Visualizing Image Annotations
Lastly, we will visualize the segmentation masks and bounding boxes for one of the sample images to demonstrate how to interpret the annotations.
Generate a color map
While not required, assigning a unique color to segmentation masks and bounding boxes for each object class enhances visual distinction, allowing for easier identification of different objects in the scene. We can use the distinctipy
package to generate a visually distinct colormap.
# Generate a list of colors with a length equal to the number of labels
= distinctipy.get_colors(len(class_names))
colors
# Make a copy of the color map in integer format
= [tuple(int(c*255) for c in color) for color in colors]
int_colors
# Generate a color swatch to visualize the color map
distinctipy.color_swatch(colors)
Download a font file
The draw_bounding_boxes
function included with torchvision uses a pretty small font size. We can increase the font size if we use a custom font. Font files are available on sites like Google Fonts, or we can use one included with the operating system.
# Set the name of the font file
= 'KFOlCnqEu92Fr1MmEU9vAw.ttf'
font_file
# Download the font file
f"https://fonts.gstatic.com/s/roboto/v30/{font_file}", "./") download_file(
Define the bounding box annotation function
Let’s make a partial function using draw_bounding_boxes
since we’ll use the same box thickness and font each time we visualize bounding boxes.
= partial(draw_bounding_boxes, fill=False, width=2, font=font_file, font_size=25) draw_bboxes
Selecting a Sample Image
We can use the unique ID for an image in the image dictionary to get the image’s file path and the associated annotations from the annotation DataFrame.
Load the sample image
# Get the file ID of the first image file
= list(img_dict.keys())[56]
file_id
# Open the associated image file as a RGB image
= Image.open(img_dict[file_id]).convert('RGB')
sample_img
# Print the dimensions of the image
print(f"Image Dims: {sample_img.size}")
# Show the image
sample_img
Image Dims: (640, 480)
Inspect the corresponding annotation data
# Get the row from the 'annotation_df' DataFrame corresponding to the 'file_id'
annotation_df.loc[file_id].to_frame()
10067 | |
---|---|
version | 3.21.1 |
flags | {} |
shapes | [{‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[90.85714285714283, 22.142857142857142], [414.4285714285714, 17.857142857142854], [427.2857142857142, 19.285714285714285], [430.1428571428571, 24.999999999999996], [437.99999999999994, 222.85714285714283], [432.99999999999994, 227.1428571428571], [270.1428571428571, 231.42857142857142], [101.57142857142856, 234.28571428571428], [92.28571428571428, 232.85714285714283], [88.0, 227.85714285714283], [89.42857142857142, 44.99999999999999], [88.0, 31.428571428571427]], ‘shape_type’: ‘polygon’, ‘flags’: {}}, {‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[0.14285714285713802, 226.42857142857142], [85.85714285714283, 107.14285714285714], [87.28571428571428, 234.99999999999997], [101.57142857142856, 235.7142857142857], [266.57142857142856, 231.42857142857142], [91, 479], [38, 479], [0, 453]], ‘shape_type’: ‘polygon’, ‘flags’: {}}, {‘label’: ‘student_id’, ‘line_color’: None, ‘fill_color’: None, ‘points’: [[287.99999999999994, 257.1428571428571], [293.71428571428567, 246.42857142857142], [314.4285714285714, 246.42857142857142], [615.1428571428571, 247.1428571428571], [621.5714285714286, 253.57142857142856], [626.5714285714286, 456.4285714285714], [617.9999999999999, 461.4285714285714], [311.57142857142856, 471.4285714285714], [297.2857142857143, 472.1428571428571], [290.1428571428571, 469.99999999999994], [285.1428571428571, 464.99999999999994]], ‘shape_type’: ‘polygon’, ‘flags’: {}}] |
lineColor | [0, 255, 0, 128] |
fillColor | [255, 0, 0, 128] |
imagePath | 10067.jpg |
imageData | |
imageHeight | 480 |
imageWidth | 640 |
The lists of point coordinates in the shapes column are the vertices of a polygon for the individual segmentation masks. We can use these to generate images for each segmentation mask.
Define a function to convert segmentation polygons to images
def create_polygon_mask(image_size, vertices):
"""
Create a grayscale image with a white polygonal area on a black background.
Parameters:
- image_size (tuple): A tuple representing the dimensions (width, height) of the image.
- vertices (list): A list of tuples, each containing the x, y coordinates of a vertex
of the polygon. Vertices should be in clockwise or counter-clockwise order.
Returns:
- PIL.Image.Image: A PIL Image object containing the polygonal mask.
"""
# Create a new black image with the given dimensions
= Image.new('L', image_size, 0)
mask_img
# Draw the polygon on the image. The area inside the polygon will be white (255).
'L').polygon(vertices, fill=(255))
ImageDraw.Draw(mask_img,
# Return the image with the drawn polygon
return mask_img
Annotate sample image
The torchvision library provides a draw_segmentation_masks
function to annotate images with segmentation masks. We can use the masks_to_boxes
function included with torchvision to generate bounding box annotations in the [top-left X, top-left Y, bottom-right X, bottom-right Y]
format from the segmentation masks. That is the same format the draw_bounding_boxes
function expects so we can use the output directly.
# Extract the labels for the sample
= [shape['label'] for shape in annotation_df.loc[file_id]['shapes']]
labels # Extract the polygon points for segmentation mask
= [shape['points'] for shape in annotation_df.loc[file_id]['shapes']]
shape_points # Format polygon points for PIL
= [[tuple(p) for p in points] for points in shape_points]
xy_coords # Generate mask images from polygons
= [create_polygon_mask(sample_img.size, xy) for xy in xy_coords]
mask_imgs # Convert mask images to tensors
= torch.concat([Mask(transforms.PILToTensor()(mask_img), dtype=torch.bool) for mask_img in mask_imgs])
masks # Generate bounding box annotations from segmentation masks
= torchvision.ops.masks_to_boxes(masks)
bboxes
# Annotate the sample image with segmentation masks
= draw_segmentation_masks(
annotated_tensor =transforms.PILToTensor()(sample_img),
image=masks,
masks=0.3,
alpha=[int_colors[i] for i in [class_names.index(label) for label in labels]]
colors
)
# Annotate the sample image with labels and bounding boxes
= draw_bboxes(
annotated_tensor =annotated_tensor,
image=bboxes,
boxes=labels,
labels=[int_colors[i] for i in [class_names.index(label) for label in labels]]
colors
)
tensor_to_pil(annotated_tensor)
We have explored the dataset and visualized the annotations for a sample image. In the next section, we will load and prepare our model.
Loading the Mask R-CNN Model
TorchVision provides checkpoints for the Mask R-CNN model trained on the COCO (Common Objects in Context) dataset. We can initialize a model with these pretrained weights using the maskrcnn_resnet50_fpn_v2
function. We must then replace the bounding box and segmentation mask predictors for the pretrained model with new ones for our dataset.
# Initialize a Mask R-CNN model with pretrained weights
= maskrcnn_resnet50_fpn_v2(weights='DEFAULT')
model
# Get the number of input features for the classifier
= model.roi_heads.box_predictor.cls_score.in_features
in_features_box = model.roi_heads.mask_predictor.conv5_mask.in_channels
in_features_mask
# Get the numbner of output channels for the Mask Predictor
= model.roi_heads.mask_predictor.conv5_mask.out_channels
dim_reduced
# Replace the box predictor
= FastRCNNPredictor(in_channels=in_features_box, num_classes=len(class_names))
model.roi_heads.box_predictor
# Replace the mask predictor
= MaskRCNNPredictor(in_channels=in_features_mask, dim_reduced=dim_reduced, num_classes=len(class_names))
model.roi_heads.mask_predictor
# Set the model's device and data type
=device, dtype=dtype);
model.to(device
# Add attributes to store the device and model name for later reference
= device
model.device = 'maskrcnn_resnet50_fpn_v2' model.name
The model internally normalizes input using the mean and standard deviation values used during the pretraining process, so we do not need to keep track of them separately.
Summarizing the Model
Before moving on, let’s generate a summary of our model to get an overview of its performance characteristics. We can use this to gauge the computational requirements for deploying the model.
= torch.randn(1, 3, 256, 256).to(device)
test_inp
= markdown_to_pandas(f"{get_module_summary(model.eval(), [test_inp])}")
summary_df
# # Filter the summary to only contain Conv2d layers and the model
= summary_df[summary_df.index == 0]
summary_df
'In size', 'Out size', 'Contains Uninitialized Parameters?'], axis=1) summary_df.drop([
Type | # Parameters | # Trainable Parameters | Size (bytes) | Forward FLOPs | |
---|---|---|---|---|---|
0 | MaskRCNN | 45.9 M | 45.7 M | 183 M | 331 G |
The above table shows the model has approximately 45.7
million trainable parameters. It takes up 183
Megabytes and performs around 331
billion floating point operations for a single 256x256
RGB image. This model internally resizes input images and executes the same number of floating point operations for different input resolutions.
That completes the model setup. In the next section, we will prepare our dataset for training.
Preparing the Data
The data preparation involves several steps, such as applying data augmentation techniques, setting up the train-validation split for the dataset, resizing and padding the images, defining the training dataset class, and initializing DataLoaders to feed data to the model.
Training-Validation Split
Let’s begin by defining the training-validation split. We’ll randomly select 80% of the available samples for the training set and use the remaining 20% for the validation set.
# Get the list of image IDs
= list(img_dict.keys())
img_keys
# Shuffle the image IDs
random.shuffle(img_keys)
# Define the percentage of the images that should be used for training
= 0.8
train_pct = 0.2
val_pct
# Calculate the index at which to split the subset of image paths into training and validation sets
= int(len(img_keys)*train_pct)
train_split = int(len(img_keys)*(train_pct+val_pct))
val_split
# Split the subset of image paths into training and validation sets
= img_keys[:train_split]
train_keys = img_keys[train_split:]
val_keys
# Print the number of images in the training and validation sets
pd.Series({"Training Samples:": len(train_keys),
"Validation Samples:": len(val_keys)
='columns') }).to_frame().style.hide(axis
Training Samples: | 120 |
---|---|
Validation Samples: | 30 |
Data Augmentation
Next, we can define what data augmentations to apply to images during training. I created a few custom image transforms to help streamline the code.
The first extends torchvision’s RandomIoUCrop
transform to give the user more control over how much it crops into bounding box areas. The second resizes images based on their largest dimension rather than their smallest. The third applies square padding and allows the padding to be applied equally on both sides or randomly split between the two sides.
All three are available through the cjm-torchvision-tfms
package.
Set training image size
First, we’ll set the size to use for training. The ResizeMax
transform will resize images so that the longest dimension equals this value while preserving the aspect ratio. The PadSquare
transform will then pad the other side to make all the input squares.
# Set training image size
= 512 train_sz
Initialize the transforms
Now, we can initialize the transform objects. The jitter_factor
parameter for the CustomRandomIoUCrop
transform controls how much the center coordinates for the crop area can deviate from the center of a bounding box. Setting this to a value greater than zero allows the transform to crop into the bounding box area. We’ll keep this value small as cutting into the hand gestures too much will change their meaning.
# Create a RandomIoUCrop object
= CustomRandomIoUCrop(min_scale=0.3,
iou_crop =1.0,
max_scale=0.5,
min_aspect_ratio=2.0,
max_aspect_ratio=[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0],
sampler_options=400,
trials=0.25) jitter_factor
# Create a `ResizeMax` object
= ResizeMax(max_sz=train_sz)
resize_max
# Create a `PadSquare` object
= PadSquare(shift=True, fill=0) pad_square
We must use a scalar value for the fill
parameter when applying the PadSquare
transform to images with segmentation masks.
Test the transforms
We’ll pass input through the CustomRandomIoUCrop
transform first and then through ResizeMax
and PadSquare
. We can pass the result through a final resize operation to ensure both sides match the train_sz
value.
# Extract the labels for the sample
= [shape['label'] for shape in annotation_df.loc[file_id]['shapes']]
labels # Extract the polygon points for segmentation mask
= [shape['points'] for shape in annotation_df.loc[file_id]['shapes']]
shape_points # Format polygon points for PIL
= [[tuple(p) for p in points] for points in shape_points]
xy_coords # Generate mask images from polygons
= [create_polygon_mask(sample_img.size, xy) for xy in xy_coords]
mask_imgs # Convert mask images to tensors
= torch.concat([Mask(transforms.PILToTensor()(mask_img), dtype=torch.bool) for mask_img in mask_imgs])
masks # Generate bounding box annotations from segmentation masks
= BoundingBoxes(data=torchvision.ops.masks_to_boxes(masks), format='xyxy', canvas_size=sample_img.size[::-1])
bboxes
# Get colors for dataset sample
= [int_colors[i] for i in [class_names.index(label) for label in labels]]
sample_colors
# Prepare mask and bounding box targets
= {
targets 'masks': Mask(masks),
'boxes': bboxes,
'labels': torch.Tensor([class_names.index(label) for label in labels])
}
# Crop the image
= iou_crop(sample_img, targets)
cropped_img, targets
# Resize the image
= resize_max(cropped_img, targets)
resized_img, targets
# Pad the image
= pad_square(resized_img, targets)
padded_img, targets
# Ensure the padded image is the target size
= transforms.Resize([train_sz] * 2, antialias=True)
resize = resize(padded_img, targets)
resized_padded_img, targets = transforms.SanitizeBoundingBoxes()(resized_padded_img, targets)
sanitized_img, targets
# Annotate the sample image with segmentation masks
= draw_segmentation_masks(
annotated_tensor =transforms.PILToTensor()(sanitized_img),
image=targets['masks'],
masks=0.3,
alpha=sample_colors
colors
)
# Annotate the sample image with labels and bounding boxes
= draw_bboxes(
annotated_tensor =annotated_tensor,
image=targets['boxes'],
boxes=[class_names[int(label.item())] for label in targets['labels']],
labels=sample_colors
colors
)
# # Display the annotated image
display(tensor_to_pil(annotated_tensor))
pd.Series({"Source Image:": sample_img.size,
"Cropped Image:": cropped_img.size,
"Resized Image:": resized_img.size,
"Padded Image:": padded_img.size,
"Resized Padded Image:": resized_padded_img.size,
='columns') }).to_frame().style.hide(axis
Source Image: | (640, 480) |
---|---|
Cropped Image: | (286, 387) |
Resized Image: | (378, 511) |
Padded Image: | (511, 511) |
Resized Padded Image: | (512, 512) |
Training Dataset Class
Now, we can define a custom dataset class to load images, extract the segmentation masks, generate the bounding box annotations, and apply the image transforms during training.
class StudentIDDataset(Dataset):
"""
This class represents a PyTorch Dataset for a collection of images and their annotations.
The class is designed to load images along with their corresponding segmentation masks, bounding box annotations, and labels.
"""
def __init__(self, img_keys, annotation_df, img_dict, class_to_idx, transforms=None):
"""
Constructor for the HagridDataset class.
Parameters:
img_keys (list): List of unique identifiers for images.
annotation_df (DataFrame): DataFrame containing the image annotations.
img_dict (dict): Dictionary mapping image identifiers to image file paths.
class_to_idx (dict): Dictionary mapping class labels to indices.
transforms (callable, optional): Optional transform to be applied on a sample.
"""
super(Dataset, self).__init__()
self._img_keys = img_keys # List of image keys
self._annotation_df = annotation_df # DataFrame containing annotations
self._img_dict = img_dict # Dictionary mapping image keys to image paths
self._class_to_idx = class_to_idx # Dictionary mapping class names to class indices
self._transforms = transforms # Image transforms to be applied
def __len__(self):
"""
Returns the length of the dataset.
Returns:
int: The number of items in the dataset.
"""
return len(self._img_keys)
def __getitem__(self, index):
"""
Fetch an item from the dataset at the specified index.
Parameters:
index (int): Index of the item to fetch from the dataset.
Returns:
tuple: A tuple containing the image and its associated target (annotations).
"""
# Retrieve the key for the image at the specified index
= self._img_keys[index]
img_key # Get the annotations for this image
= self._annotation_df.loc[img_key]
annotation # Load the image and its target (segmentation masks, bounding boxes and labels)
= self._load_image_and_target(annotation)
image, target
# Apply the transformations, if any
if self._transforms:
= self._transforms(image, target)
image, target
return image, target
def _load_image_and_target(self, annotation):
"""
Load an image and its target (bounding boxes and labels).
Parameters:
annotation (pandas.Series): The annotations for an image.
Returns:
tuple: A tuple containing the image and a dictionary with 'boxes' and 'labels' keys.
"""
# Retrieve the file path of the image
= self._img_dict[annotation.name]
filepath # Open the image file and convert it to RGB
= Image.open(filepath).convert('RGB')
image
# Convert the class labels to indices
= [shape['label'] for shape in annotation['shapes']]
labels = torch.Tensor([self._class_to_idx[label] for label in labels])
labels = labels.to(dtype=torch.int64)
labels
# Convert polygons to mask images
= [shape['points'] for shape in annotation['shapes']]
shape_points = [[tuple(p) for p in points] for points in shape_points]
xy_coords = [create_polygon_mask(image.size, xy) for xy in xy_coords]
mask_imgs = Mask(torch.concat([Mask(transforms.PILToTensor()(mask_img), dtype=torch.bool) for mask_img in mask_imgs]))
masks
# Generate bounding box annotations from segmentation masks
= BoundingBoxes(data=torchvision.ops.masks_to_boxes(masks), format='xyxy', canvas_size=image.size[::-1])
bboxes
return image, {'masks': masks,'boxes': bboxes, 'labels': labels}
Image Transforms
We’ll add additional data augmentations with the IoU crop transform to help the model generalize.
Transform | Description |
---|---|
ColorJitter |
Randomly change the brightness, contrast, saturation and hue of an image or video. (link) |
RandomGrayscale |
Randomly convert image or videos to grayscale with a probability of p (default 0.1). (link) |
RandomEqualize |
Equalize the histogram of the given image or video with a given probability. (link) |
RandomPosterize |
Randomly posterize an image by reducing the number of bits for each color channel. (link) |
# Compose transforms for data augmentation
= transforms.Compose(
data_aug_tfms =[
transforms
iou_crop,
transforms.ColorJitter(= (0.875, 1.125),
brightness = (0.5, 1.5),
contrast = (0.5, 1.5),
saturation = (-0.05, 0.05),
hue
),
transforms.RandomGrayscale(),
transforms.RandomEqualize(),=3, p=0.5),
transforms.RandomPosterize(bits=0.5),
transforms.RandomHorizontalFlip(p
],
)
# Compose transforms to resize and pad input images
= transforms.Compose([
resize_pad_tfm
resize_max,
pad_square,* 2, antialias=True)
transforms.Resize([train_sz]
])
# Compose transforms to sanitize bounding boxes and normalize input data
= transforms.Compose([
final_tfms
transforms.ToImage(), =True),
transforms.ToDtype(torch.float32, scale
transforms.SanitizeBoundingBoxes(),
])
# Define the transformations for training and validation datasets
= transforms.Compose([
train_tfms
data_aug_tfms,
resize_pad_tfm,
final_tfms
])= transforms.Compose([resize_pad_tfm, final_tfms]) valid_tfms
We do not need to include a Normalize
transform as the model internally normalizes input.
Always use the SanitizeBoundingBoxes
transform to clean up annotations after using data augmentations that alter bounding boxes (e.g., cropping, warping, etc.).
Initialize Datasets
Now, we can create our training and validation dataset objects using the dataset splits and transforms.
# Create a mapping from class names to class indices
= {c: i for i, c in enumerate(class_names)}
class_to_idx
# Instantiate the datasets using the defined transformations
= StudentIDDataset(train_keys, annotation_df, img_dict, class_to_idx, train_tfms)
train_dataset = StudentIDDataset(val_keys, annotation_df, img_dict, class_to_idx, valid_tfms)
valid_dataset
# Print the number of samples in the training and validation datasets
pd.Series({'Training dataset size:': len(train_dataset),
'Validation dataset size:': len(valid_dataset)}
='columns') ).to_frame().style.hide(axis
Training dataset size: | 120 |
---|---|
Validation dataset size: | 30 |
Inspect Samples
Let’s verify the dataset objects work correctly by inspecting the first samples from the training and validation sets.
Inspect training set sample
= train_dataset[0]
dataset_sample
# Get colors for dataset sample
= [int_colors[int(i.item())] for i in dataset_sample[1]['labels']]
sample_colors
# Annotate the sample image with segmentation masks
= draw_segmentation_masks(
annotated_tensor =(dataset_sample[0]*255).to(dtype=torch.uint8),
image=dataset_sample[1]['masks'],
masks=0.3,
alpha=sample_colors
colors
)
# Annotate the sample image with bounding boxes
= draw_bboxes(
annotated_tensor =annotated_tensor,
image=dataset_sample[1]['boxes'],
boxes=[class_names[int(i.item())] for i in dataset_sample[1]['labels']],
labels=sample_colors
colors
)
tensor_to_pil(annotated_tensor)
Inspect validation set sample
= valid_dataset[0]
dataset_sample
# Get colors for dataset sample
= [int_colors[int(i.item())] for i in dataset_sample[1]['labels']]
sample_colors
# Annotate the sample image with segmentation masks
= draw_segmentation_masks(
annotated_tensor =(dataset_sample[0]*255).to(dtype=torch.uint8),
image=dataset_sample[1]['masks'],
masks=0.3,
alpha=sample_colors
colors
)
# Annotate the sample image with bounding boxes
= draw_bboxes(
annotated_tensor =annotated_tensor,
image=dataset_sample[1]['boxes'],
boxes=[class_names[int(i.item())] for i in dataset_sample[1]['labels']],
labels=sample_colors
colors
)
tensor_to_pil(annotated_tensor)
Initialize DataLoaders
The last step before training is to instantiate the DataLoaders for the training and validation sets. Try decreasing the batch size if you encounter memory limitations.
# Set the training batch size
= 4
bs
# Set the number of worker processes for loading data.
= multiprocessing.cpu_count()//2
num_workers
# Define parameters for DataLoader
= {
data_loader_params 'batch_size': bs, # Batch size for data loading
'num_workers': num_workers, # Number of subprocesses to use for data loading
'persistent_workers': True, # If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the worker dataset instances alive.
'pin_memory': 'cuda' in device, # If True, the data loader will copy Tensors into CUDA pinned memory before returning them. Useful when using GPU.
'pin_memory_device': device if 'cuda' in device else '', # Specifies the device where the data should be loaded. Commonly set to use the GPU.
'collate_fn': lambda batch: tuple(zip(*batch)),
}
# Create DataLoader for training data. Data is shuffled for every epoch.
= DataLoader(train_dataset, **data_loader_params, shuffle=True)
train_dataloader
# Create DataLoader for validation data. Shuffling is not necessary for validation data.
= DataLoader(valid_dataset, **data_loader_params)
valid_dataloader
# Print the number of batches in the training and validation DataLoaders
pd.Series({'Number of batches in train DataLoader:': len(train_dataloader),
'Number of batches in validation DataLoader:': len(valid_dataloader)}
='columns') ).to_frame().style.hide(axis
Number of batches in train DataLoader: | 30 |
---|---|
Number of batches in validation DataLoader: | 8 |
That completes the data preparation. Now, we can finally train our Mask R-CNN model.
Fine-tuning the Model
In this section, we will implement the training code and fine-tune our model.
Define the Training Loop
The following function performs a single pass through the training or validation set.
The model has different behavior when in training
mode versus evaluation
mode. In training mode, it calculates the loss internally for the object detection and segmentation tasks and returns a dictionary with the individual loss values. We can sum up these separate values to get the total loss.
def run_epoch(model, dataloader, optimizer, lr_scheduler, device, scaler, epoch_id, is_training):
"""
Function to run a single training or evaluation epoch.
Args:
model: A PyTorch model to train or evaluate.
dataloader: A PyTorch DataLoader providing the data.
optimizer: The optimizer to use for training the model.
loss_func: The loss function used for training.
device: The device (CPU or GPU) to run the model on.
scaler: Gradient scaler for mixed-precision training.
is_training: Boolean flag indicating whether the model is in training or evaluation mode.
Returns:
The average loss for the epoch.
"""
# Set the model to training mode
model.train()
= 0 # Initialize the total loss for this epoch
epoch_loss = tqdm(total=len(dataloader), desc="Train" if is_training else "Eval") # Initialize a progress bar
progress_bar
# Loop over the data
for batch_id, (inputs, targets) in enumerate(dataloader):
# Move inputs and targets to the specified device
= torch.stack(inputs).to(device)
inputs
# Forward pass with Automatic Mixed Precision (AMP) context manager
with autocast(torch.device(device).type):
if is_training:
= model(inputs.to(device), move_data_to_device(targets, device))
losses else:
with torch.no_grad():
= model(inputs.to(device), move_data_to_device(targets, device))
losses
# Compute the loss
= sum([loss for loss in losses.values()]) # Sum up the losses
loss
# If in training mode, backpropagate the error and update the weights
if is_training:
if scaler:
scaler.scale(loss).backward()
scaler.step(optimizer)= scaler.get_scale()
old_scaler
scaler.update()= scaler.get_scale()
new_scaler if new_scaler >= old_scaler:
lr_scheduler.step()else:
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Update the total loss
= loss.item()
loss_item += loss_item
epoch_loss
# Update the progress bar
= dict(loss=loss_item, avg_loss=epoch_loss/(batch_id+1))
progress_bar_dict if is_training:
=lr_scheduler.get_last_lr()[0])
progress_bar_dict.update(lr
progress_bar.set_postfix(progress_bar_dict)
progress_bar.update()
# If loss is NaN or infinity, stop training
if is_training:
= f"Loss is NaN or infinite at epoch {epoch_id}, batch {batch_id}. Stopping training."
stop_training_message assert not math.isnan(loss_item) and math.isfinite(loss_item), stop_training_message
# Cleanup and close the progress bar
progress_bar.close()
# Return the average loss for this epoch
return epoch_loss / (batch_id + 1)
Next, we define the train_loop
function, which executes the main training loop. It iterates over each epoch, runs through the training and validation sets, and saves the best model based on the validation loss.
def train_loop(model,
train_dataloader,
valid_dataloader,
optimizer,
lr_scheduler,
device,
epochs,
checkpoint_path, =False):
use_scaler"""
Main training loop.
Args:
model: A PyTorch model to train.
train_dataloader: A PyTorch DataLoader providing the training data.
valid_dataloader: A PyTorch DataLoader providing the validation data.
optimizer: The optimizer to use for training the model.
lr_scheduler: The learning rate scheduler.
device: The device (CPU or GPU) to run the model on.
epochs: The number of epochs to train for.
checkpoint_path: The path where to save the best model checkpoint.
use_scaler: Whether to scale graidents when using a CUDA device
Returns:
None
"""
# Initialize a gradient scaler for mixed-precision training if the device is a CUDA GPU
= torch.cuda.amp.GradScaler() if device.type == 'cuda' and use_scaler else None
scaler = float('inf') # Initialize the best validation loss
best_loss
# Loop over the epochs
for epoch in tqdm(range(epochs), desc="Epochs"):
# Run a training epoch and get the training loss
= run_epoch(model, train_dataloader, optimizer, lr_scheduler, device, scaler, epoch, is_training=True)
train_loss # Run an evaluation epoch and get the validation loss
with torch.no_grad():
= run_epoch(model, valid_dataloader, None, None, device, scaler, epoch, is_training=False)
valid_loss
# If the validation loss is lower than the best validation loss seen so far, save the model checkpoint
if valid_loss < best_loss:
= valid_loss
best_loss
torch.save(model.state_dict(), checkpoint_path)
# Save metadata about the training process
= {
training_metadata 'epoch': epoch,
'train_loss': train_loss,
'valid_loss': valid_loss,
'learning_rate': lr_scheduler.get_last_lr()[0],
'model_architecture': model.name
}with open(Path(checkpoint_path.parent/'training_metadata.json'), 'w') as f:
json.dump(training_metadata, f)
# If the device is a GPU, empty the cache
if device.type != 'cpu':
getattr(torch, device.type).empty_cache()
Set the Model Checkpoint Path
Before we proceed with training, let’s generate a timestamp for the training session and create a directory to save the checkpoints during training.
# Generate timestamp for the training session (Year-Month-Day_Hour_Minute_Second)
= datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
timestamp
# Create a directory to store the checkpoints if it does not already exist
= Path(project_dir/f"{timestamp}")
checkpoint_dir
# Create the checkpoint directory if it does not already exist
=True, exist_ok=True)
checkpoint_dir.mkdir(parents
# The model checkpoint path
= checkpoint_dir/f"{model.name}.pth"
checkpoint_path
print(checkpoint_path)
pytorch-mask-r-cnn-instance-segmentation/2023-09-19_15-17-57/maskrcnn_resnet50_fpn_v2.pth
Let’s also save a copy of the colormap for the current dataset in the training folder for future use.
Save the Color Map
# Create a color map and write it to a JSON file
= {'items': [{'label': label, 'color': color} for label, color in zip(class_names, colors)]}
color_map with open(f"{checkpoint_dir}/{dataset_path.name}-colormap.json", "w") as file:
file)
json.dump(color_map,
# Print the name of the file that the color map was written to
print(f"{checkpoint_dir}/{dataset_path.name}-colormap.json")
pytorch-mask-r-cnn-instance-segmentation/2023-09-19_15-17-57/student-id-colormap.json
Configure the Training Parameters
Now, we can configure the parameters for training. We must specify the learning rate and number of training epochs. We will also instantiate the optimizer and learning rate scheduler.
# Learning rate for the model
= 5e-4
lr
# Number of training epochs
= 40
epochs
# AdamW optimizer; includes weight decay for regularization
= torch.optim.AdamW(model.parameters(), lr=lr)
optimizer
# Learning rate scheduler; adjusts the learning rate during training
= torch.optim.lr_scheduler.OneCycleLR(optimizer,
lr_scheduler =lr,
max_lr=epochs*len(train_dataloader)) total_steps
Train the Model
Finally, we can train the model using the train_loop
function. Training time will depend on the available hardware.
Training usually takes around 13 minutes on the free GPU tier of Google Colab.
=model,
train_loop(model=train_dataloader,
train_dataloader=valid_dataloader,
valid_dataloader=optimizer,
optimizer=lr_scheduler,
lr_scheduler=torch.device(device),
device=epochs,
epochs=checkpoint_path,
checkpoint_path=True) use_scaler
Epochs: 100%|██████████| 40/40 [03:22<00:00, 5.04s/it]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.51it/s, loss=0.472, avg_loss=0.917, lr=2.82e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.23it/s, loss=0.346, avg_loss=0.421]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.10it/s, loss=0.134, avg_loss=0.35, lr=5.23e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.53it/s, loss=0.138, avg_loss=0.209]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.67it/s, loss=0.371, avg_loss=0.207, lr=9.07e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 10.86it/s, loss=0.0978, avg_loss=0.147]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.08it/s, loss=0.125, avg_loss=0.191, lr=0.000141]
Eval: 100%|██████████| 8/8 [00:00<00:00, 10.68it/s, loss=0.0925, avg_loss=0.182]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.17it/s, loss=0.17, avg_loss=0.227, lr=0.000199]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.88it/s, loss=0.0917, avg_loss=0.205]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.87it/s, loss=0.116, avg_loss=0.193, lr=0.000261]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.42it/s, loss=0.075, avg_loss=0.13]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.81it/s, loss=0.154, avg_loss=0.2, lr=0.000323]
Eval: 100%|██████████| 8/8 [00:00<00:00, 10.71it/s, loss=0.111, avg_loss=0.144]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.56it/s, loss=0.17, avg_loss=0.206, lr=0.000381]
Eval: 100%|██████████| 8/8 [00:00<00:00, 14.12it/s, loss=0.136, avg_loss=0.177]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.12it/s, loss=0.668, avg_loss=0.252, lr=0.000431]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.77it/s, loss=0.15, avg_loss=0.357]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.13it/s, loss=0.297, avg_loss=0.3, lr=0.000469]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.99it/s, loss=0.139, avg_loss=0.22]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.16it/s, loss=0.357, avg_loss=0.254, lr=0.000492]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.99it/s, loss=0.135, avg_loss=0.193]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.85it/s, loss=0.471, avg_loss=0.253, lr=0.0005]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.07it/s, loss=0.0909, avg_loss=0.165]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.19it/s, loss=0.454, avg_loss=0.216, lr=0.000498]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.56it/s, loss=0.104, avg_loss=0.172]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.03it/s, loss=0.165, avg_loss=0.225, lr=0.000494]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.85it/s, loss=0.0873, avg_loss=0.14]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.82it/s, loss=0.0918, avg_loss=0.215, lr=0.000486]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.76it/s, loss=0.0951, avg_loss=0.137]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.16it/s, loss=0.1, avg_loss=0.179, lr=0.000475]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.66it/s, loss=0.0898, avg_loss=0.131]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.03it/s, loss=0.205, avg_loss=0.164, lr=0.000461]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.70it/s, loss=0.103, avg_loss=0.131]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.20it/s, loss=0.118, avg_loss=0.201, lr=0.000445]
Eval: 100%|██████████| 8/8 [00:00<00:00, 10.57it/s, loss=0.0923, avg_loss=0.168]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.15it/s, loss=0.0848, avg_loss=0.197, lr=0.000426]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.11it/s, loss=0.0793, avg_loss=0.135]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.94it/s, loss=0.138, avg_loss=0.169, lr=0.000405]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.12it/s, loss=0.0741, avg_loss=0.134]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.00it/s, loss=0.3, avg_loss=0.195, lr=0.000382]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.81it/s, loss=0.133, avg_loss=0.163]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.09it/s, loss=0.29, avg_loss=0.126, lr=0.000358]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.96it/s, loss=0.0742, avg_loss=0.12]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.97it/s, loss=0.113, avg_loss=0.136, lr=0.000332]
Eval: 100%|██████████| 8/8 [00:00<00:00, 14.36it/s, loss=0.0749, avg_loss=0.116]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.12it/s, loss=0.0852, avg_loss=0.137, lr=0.000305]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.19it/s, loss=0.0689, avg_loss=0.114]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.16it/s, loss=0.118, avg_loss=0.142, lr=0.000277]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.83it/s, loss=0.0643, avg_loss=0.117]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.01it/s, loss=0.0898, avg_loss=0.134, lr=0.000249]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.22it/s, loss=0.0726, avg_loss=0.105]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.02it/s, loss=0.0792, avg_loss=0.122, lr=0.000221]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.67it/s, loss=0.0679, avg_loss=0.1]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.13it/s, loss=0.0842, avg_loss=0.127, lr=0.000193]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.75it/s, loss=0.0724, avg_loss=0.101]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.11it/s, loss=0.0794, avg_loss=0.126, lr=0.000167]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.94it/s, loss=0.0656, avg_loss=0.0925]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.02it/s, loss=0.0992, avg_loss=0.113, lr=0.000141]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.72it/s, loss=0.0586, avg_loss=0.0914]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.05it/s, loss=0.15, avg_loss=0.117, lr=0.000116]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.15it/s, loss=0.0593, avg_loss=0.089]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.06it/s, loss=0.23, avg_loss=0.107, lr=9.34e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.97it/s, loss=0.0559, avg_loss=0.0899]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.93it/s, loss=0.0678, avg_loss=0.0973, lr=7.26e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.98it/s, loss=0.0631, avg_loss=0.0831]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.11it/s, loss=0.0892, avg_loss=0.0847, lr=5.4e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.03it/s, loss=0.0587, avg_loss=0.0813]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.08it/s, loss=0.0662, avg_loss=0.0854, lr=3.78e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.68it/s, loss=0.06, avg_loss=0.0842]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.96it/s, loss=0.065, avg_loss=0.0936, lr=2.44e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.23it/s, loss=0.0532, avg_loss=0.0795]
Train: 100%|██████████| 30/30 [00:04<00:00, 8.70it/s, loss=0.226, avg_loss=0.0877, lr=1.37e-5]
Eval: 100%|██████████| 8/8 [00:00<00:00, 12.32it/s, loss=0.0544, avg_loss=0.0792]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.10it/s, loss=0.0718, avg_loss=0.0849, lr=6.06e-6]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.96it/s, loss=0.0556, avg_loss=0.0769]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.03it/s, loss=0.0538, avg_loss=0.0903, lr=1.47e-6]
Eval: 100%|██████████| 8/8 [00:00<00:00, 10.89it/s, loss=0.0526, avg_loss=0.0778]
Train: 100%|██████████| 30/30 [00:04<00:00, 9.13it/s, loss=0.0511, avg_loss=0.0853, lr=3.75e-9]
Eval: 100%|██████████| 8/8 [00:00<00:00, 11.66it/s, loss=0.0526, avg_loss=0.0772]
At last, we have our fine-tuned Mask R-CNN model. To wrap up the tutorial, we can test our model by performing inference on individual images.
Making Predictions with the Model
In this final part of the tutorial, we will cover how to perform inference on individual images with our Mask R-CNN model and filter the predictions.
Preparing Input Data
Let’s use a random image from the validation set. That way, we have some ground truth annotation data to compare against. Unlike during training, we won’t stick to square input dimensions for inference.
# Choose a random item from the validation set
= random.choice(val_keys)
file_id
# Retrieve the image file path associated with the file ID
= img_dict[file_id]
test_file
# Open the test file
= Image.open(test_file).convert('RGB')
test_img
# Resize the test image
= resize_img(test_img, target_sz=train_sz, divisor=1)
input_img
# Calculate the scale between the source image and the resized image
= min(test_img.size) / min(input_img.size)
min_img_scale
display(test_img)
# Print the prediction data as a Pandas DataFrame for easy formatting
pd.Series({"Source Image Size:": test_img.size,
"Input Dims:": input_img.size,
"Min Image Scale:": min_img_scale,
"Input Image Size:": input_img.size
='columns') }).to_frame().style.hide(axis
Source Image Size: | (480, 640) |
---|---|
Input Dims: | (512, 682) |
Min Image Scale: | 0.937500 |
Input Image Size: | (512, 682) |
Get the target annotation data
# Extract the polygon points for segmentation mask
= [shape['points'] for shape in annotation_df.loc[file_id]['shapes']]
target_shape_points # Format polygon points for PIL
= [[tuple(p) for p in points] for points in target_shape_points]
target_xy_coords # Generate mask images from polygons
= [create_polygon_mask(test_img.size, xy) for xy in target_xy_coords]
target_mask_imgs # Convert mask images to tensors
= Mask(torch.concat([Mask(transforms.PILToTensor()(mask_img), dtype=torch.bool) for mask_img in target_mask_imgs]))
target_masks
# Get the target labels and bounding boxes
= [shape['label'] for shape in annotation_df.loc[file_id]['shapes']]
target_labels = BoundingBoxes(data=torchvision.ops.masks_to_boxes(target_masks), format='xyxy', canvas_size=test_img.size[::-1]) target_bboxes
Pass the input data to the model
Now, we can convert the test image to a tensor and pass it to the model. Ensure the model is set to evaluation mode to get predictions instead of loss values.
# Set the model to evaluation mode
eval();
model.
# Ensure the model and input data are on the same device
model.to(device)= transforms.Compose([transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)])(input_img)[None].to(device)
input_tensor
# Make a prediction with the model
with torch.no_grad():
= model(input_tensor) model_output
Filter the model output
The model performs most post-processing steps internally, so we only need to filter the output based on the desired confidence threshold. The model returns predictions as a list of dictionaries. Each dictionary stores bounding boxes, label indices, confidence scores, and segmentation masks for a single sample in the input batch.
Since we resized the test image, we must scale the bounding boxes and segmentation masks to the source resolution.
# Set the confidence threshold
= 0.5
threshold
# Move model output to the CPU
= move_data_to_device(model_output, 'cpu')
model_output
# Filter the output based on the confidence threshold
= model_output[0]['scores'] > threshold
scores_mask
# Scale the predicted bounding boxes
= BoundingBoxes(model_output[0]['boxes'][scores_mask]*min_img_scale, format='xyxy', canvas_size=input_img.size[::-1])
pred_bboxes
# Get the class names for the predicted label indices
= [class_names[int(label)] for label in model_output[0]['labels'][scores_mask]]
pred_labels
# Extract the confidence scores
= model_output[0]['scores']
pred_scores
# Scale and stack the predicted segmentation masks
= F.interpolate(model_output[0]['masks'][scores_mask], size=test_img.size[::-1])
pred_masks = torch.concat([Mask(torch.where(mask >= threshold, 1, 0), dtype=torch.bool) for mask in pred_masks]) pred_masks
Annotate the image using the model predictions
# Get the annotation colors for the targets and predictions
=[int_colors[i] for i in [class_names.index(label) for label in target_labels]]
target_colors=[int_colors[i] for i in [class_names.index(label) for label in pred_labels]]
pred_colors
# Convert the test images to a tensor
= transforms.PILToTensor()(test_img)
img_tensor
# Annotate the test image with the target segmentation masks
= draw_segmentation_masks(image=img_tensor, masks=target_masks, alpha=0.3, colors=target_colors)
annotated_tensor # Annotate the test image with the target bounding boxes
= draw_bboxes(image=annotated_tensor, boxes=target_bboxes, labels=target_labels, colors=target_colors)
annotated_tensor # Display the annotated test image
= tensor_to_pil(annotated_tensor)
annotated_test_img
# Annotate the test image with the predicted segmentation masks
= draw_segmentation_masks(image=img_tensor, masks=pred_masks, alpha=0.3, colors=pred_colors)
annotated_tensor # Annotate the test image with the predicted labels and bounding boxes
= draw_bboxes(
annotated_tensor =annotated_tensor,
image=pred_bboxes,
boxes=[f"{label}\n{prob*100:.2f}%" for label, prob in zip(pred_labels, pred_scores)],
labels=pred_colors
colors
)
# Display the annotated test image with the predicted bounding boxes
display(stack_imgs([annotated_test_img, tensor_to_pil(annotated_tensor)]))
# Print the prediction data as a Pandas DataFrame for easy formatting
pd.Series({"Target BBoxes:": [f"{label}:{bbox}" for label, bbox in zip(target_labels, np.round(target_bboxes.numpy(), decimals=3))],
"Predicted BBoxes:": [f"{label}:{bbox}" for label, bbox in zip(pred_labels, pred_bboxes.round(decimals=3).numpy())],
"Confidence Scores:": [f"{label}: {prob*100:.2f}%" for label, prob in zip(pred_labels, pred_scores)]
='columns') }).to_frame().style.hide(axis
Target BBoxes: | [‘student_id:[ 11. 164. 451. 455.]’] |
---|---|
Predicted BBoxes: | [‘student_id:[ 10.621 162.709 452.722 453.537]’] |
Confidence Scores: | [‘student_id: 99.94%’] |
The segmentation mask has a few rough spots, but the model appears to have learned to detect and segment ID cards as desired.
- Don’t forget to download the model checkpoint and class labels from the Colab Environment’s file browser. (tutorial link)
- Once you finish training and download the files, turn off hardware acceleration for the Colab Notebook to save GPU time. (tutorial link)
Conclusion
Congratulations on completing this tutorial for training Mask R-CNN models in PyTorch! The skills and knowledge you’ve acquired here serve as a solid foundation for future projects.
Recommended Tutorials
- Exporting Mask R-CNN Models from PyTorch to ONNX: Learn how to export Mask R-CNN models from PyTorch to ONNX and perform inference using ONNX Runtime.
- Working with LabelMe Segmentation Annotations in Torchvision: Learn how to work with LabelMe segmentation annotations in torchvision for instance segmentation tasks.
- Feel free to post questions or problems related to this tutorial in the comments below. I try to make time to address them on Thursdays and Fridays.
- I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.
- I help clients leverage cutting-edge AI technologies to solve real-world problems.
- Learn more about me or reach out via email at [email protected] to discuss your project.