Exporting Keypoint R-CNN Models from PyTorch to ONNX

Learn how to export Keypoint R-CNN models from PyTorch to ONNX and perform inference using ONNX Runtime.

Christian Mills


January 30, 2024


February 22, 2024

This post is part of the following series:


Welcome back to this series on training Keypoint R-CNN models with PyTorch. Previously, we demonstrated how to fine-tune a Keypoint R-CNN model by training it to identify the locations of human noses and faces. This tutorial builds on that by showing how to export the model to ONNX and perform inference using ONNX Runtime.

ONNX (Open Neural Network Exchange) is an open format to represent machine learning models and make them portable across various platforms. ONNX Runtime is a cross-platform inference accelerator that provides interfaces to hardware-specific libraries. By exporting our model to ONNX, we can deploy it to multiple devices and leverage hardware acceleration for faster inference. The Keypoint R-CNN model is computationally intensive, so any improvements to inference speed are welcome.

Additionally, we’ll implement the functionality to annotate images with keypoints without relying on PyTorch as a dependency. By the end of this tutorial, you will have an ONNX version of our Keypoint R-CNN model that you can deploy to servers and edge devices using ONNX Runtime.

This post assumes the reader has completed the previous tutorial linked below:

Getting Started with the Code

As with the previous tutorial, the code is available as a Jupyter Notebook.

Jupyter Notebook: GitHub Repository
Google Colab: Open In Colab

Setting Up Your Python Environment

We’ll need to add a few new libraries to our Python environment for working with ONNX models.

Package | Description
onnx | This package provides a Python API for working with ONNX models.
onnxruntime | ONNX Runtime is a runtime accelerator for machine learning models.
onnx-simplifier | This package helps simplify ONNX models.

Run the following command to install these additional libraries:

# Install ONNX packages
pip install onnx onnxruntime onnx-simplifier
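To confirm the installation before moving on, a small check like the one below can report which of these packages are present in the active environment. It is a sketch that relies only on the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(package_names):
    """Map each distribution name to its installed version string, or None if it is missing."""
    versions = {}
    for name in package_names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in the current environment
            versions[name] = None
    return versions


# Report the status of the ONNX packages installed above
for name, ver in installed_versions(["onnx", "onnxruntime", "onnx-simplifier"]).items():
    print(f"{name}: {ver or 'not installed'}")
```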

Importing the Required Dependencies

With our environment updated, we can dive into the code. First, we will import the necessary Python dependencies into our Jupyter Notebook.

# Import Python Standard Library dependencies
import json
from pathlib import Path
import random

# Import utility functions
from cjm_psl_utils.core import download_file
from cjm_pil_utils.core import resize_img

# Import numpy
import numpy as np

# Import the pandas package
import pandas as pd

# Import PIL for image manipulation
from PIL import Image, ImageDraw, ImageFont

# Import PyTorch dependencies
import torch

# Import Keypoint R-CNN
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor

# Import ONNX dependencies
import onnx # Import the onnx module
from onnxsim import simplify # Import the method to simplify ONNX models
import onnxruntime as ort # Import the ONNX Runtime

Setting Up the Project

In this section, we’ll set the folder locations for our project and for the training session that contains the PyTorch checkpoint. We’ll also make sure we have a font file for annotating images.

Set the Directory Paths

# The name for the project
project_name = f"pytorch-keypoint-r-cnn"

# The path for the project folder
project_dir = Path(f"./{project_name}/")

# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)

# The path to the checkpoint folder
checkpoint_dir = Path(project_dir/f"2024-01-30_10-44-52")

# Display the directory paths
pd.Series({
    "Project Directory:": project_dir,
    "Checkpoint Directory:": checkpoint_dir,
}).to_frame().style.hide(axis='columns')

Project Directory: pytorch-keypoint-r-cnn
Checkpoint Directory: pytorch-keypoint-r-cnn/2024-01-30_10-44-52

I made a model checkpoint available on Hugging Face Hub in the repository linked below:

Those following along on Google Colab can drag the contents of their checkpoint folder into Colab’s file browser. Keep in mind that the model checkpoint file is large.

Download a Font File

# Set the name of the font file
font_file = 'KFOlCnqEu92Fr1MmEU9vAw.ttf'

# Download the font file
download_file(f"https://fonts.gstatic.com/s/roboto/v30/{font_file}", "./")

Loading the Checkpoint Data

Now, we can load the colormap used during training and initialize a Keypoint R-CNN model with the saved checkpoint.

Load the Colormap

# The colormap path
colormap_path = list(checkpoint_dir.glob('*colormap.json'))[0]

# Load the JSON colormap data
with open(colormap_path, 'r') as file:
    colormap_json = json.load(file)

# Convert the JSON data to a dictionary
colormap_dict = {item['label']: item['color'] for item in colormap_json['items']}

# Extract the class names from the colormap
class_names = list(colormap_dict.keys())

# Make a copy of the colormap in integer format
int_colors = [tuple(int(c*255) for c in color) for color in colormap_dict.values()]
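The code above assumes the colormap file has an `items` list whose entries carry a `label` and a `color` given as floats in the 0 to 1 range. The sketch below runs the same conversion on a small, made-up JSON string so the resulting structures are easy to inspect; the labels and colors in it are hypothetical, not the ones stored with the trained checkpoint.

```python
import json

# Hypothetical colormap in the layout the loading code expects
sample_json = """
{
  "items": [
    {"label": "nose", "color": [1.0, 0.0, 0.0]},
    {"label": "face", "color": [0.0, 0.5, 1.0]}
  ]
}
"""

colormap_json = json.loads(sample_json)

# Map each label to its float RGB color
colormap_dict = {item['label']: item['color'] for item in colormap_json['items']}

# Keypoint names in file order (dicts preserve insertion order)
class_names = list(colormap_dict.keys())

# Scale float colors in [0, 1] to 8-bit integer tuples for PIL
int_colors = [tuple(int(c * 255) for c in color) for color in colormap_dict.values()]

print(class_names)  # ['nose', 'face']
print(int_colors)   # [(255, 0, 0), (0, 127, 255)]
```

Note that `int()` truncates rather than rounds, so 0.5 maps to 127; using `round()` would give 128. Either is fine for drawing annotations.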

Load the Model Checkpoint

# The model checkpoint path
checkpoint_path = list(checkpoint_dir.glob('*.pth'))[0]

# Load the model checkpoint onto the CPU
model_checkpoint = torch.load(checkpoint_path, map_location='cpu')

Load the Trained Keypoint R-CNN Model

# Load a pre-trained model
model = keypointrcnn_resnet50_fpn(weights='DEFAULT')

# Replace the keypoint predictor head to match the number of keypoints in our dataset
in_features = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(in_channels=in_features, num_keypoints=len(class_names))

# Initialize the model with the checkpoint parameters and buffers
model.load_state_dict(model_checkpoint)

<All keys matched successfully>

Exporting the Model to ONNX

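Before exporting the model, let’s ensure the model is in evaluation mode. In training mode, torchvision’s detection models expect target annotations and return losses rather than predictions.

model.eval();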

Prepare the Input Tensor

We need a sample input tensor for the export process.

input_tensor = torch.randn(1, 3, 256, 256)

Export the Model to ONNX

We can export the model using PyTorch’s torch.onnx.export() function. This function performs a single pass through the model and records all operations to generate a TorchScript graph. It then exports this graph to ONNX by decomposing each graph node (which contains a PyTorch operator) into a series of ONNX operators.

If we want the ONNX model to support different input sizes, we must set the width and height input axes as dynamic.
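With `dynamic_axes={'input': {2: 'height', 3: 'width'}}`, the exported graph input keeps the batch and channel sizes of the sample tensor but replaces dimensions 2 and 3 with named symbolic dimensions. The pure-Python sketch below mimics that bookkeeping to show which input shapes the exported model should then accept. It does not call the exporter, and the helper names are hypothetical.

```python
def symbolic_input_shape(sample_shape, dynamic_dims):
    """Replace the dynamic dimensions of a sample shape with their symbolic names."""
    return [dynamic_dims.get(i, size) for i, size in enumerate(sample_shape)]


def shape_is_compatible(symbolic_shape, actual_shape):
    """Fixed dimensions must match exactly; symbolic (str) dimensions accept any size."""
    if len(symbolic_shape) != len(actual_shape):
        return False
    return all(isinstance(dim, str) or dim == size
               for dim, size in zip(symbolic_shape, actual_shape))


# Sample tensor used for export: batch 1, 3 channels, 256x256
spec = symbolic_input_shape((1, 3, 256, 256), {2: 'height', 3: 'width'})
print(spec)  # [1, 3, 'height', 'width']

print(shape_is_compatible(spec, (1, 3, 768, 512)))  # True: only height and width differ
print(shape_is_compatible(spec, (2, 3, 256, 256)))  # False: the batch size stays fixed at 1
```

This is why an input with a different height and width than the 256x256 sample remains valid, while a batch of two images would not match the fixed batch dimension.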

# Set a filename for the ONNX model
onnx_file_path = f"{checkpoint_dir}/{colormap_path.stem.removesuffix('-colormap')}-{checkpoint_path.stem}.onnx"

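# Export the PyTorch model to ONNX format
torch.onnx.export(model,                   # Model being exported
                  input_tensor,            # Sample input tensor used to trace the model
                  onnx_file_path,          # Filename for the exported ONNX model
                  export_params=True,      # Store the trained weights inside the model file
                  input_names=['input'],
                  output_names=['boxes', 'labels', 'scores', 'keypoints', 'keypoints_scores'],
                  dynamic_axes={'input': {2: 'height', 3: 'width'}}
                 )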

The export function may emit some UserWarning messages. We can ignore them here, as the exported model still produces the expected results when we test it with ONNX Runtime below.

Simplify the ONNX Model

The ONNX models generated by PyTorch are not always the most concise. We can use the onnx-simplifier package to tidy up the exported model.

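# Load the exported ONNX model
onnx_model = onnx.load(onnx_file_path)

# Simplify the model; `check` reports whether the simplified model passed validation
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"

# Overwrite the exported model with the simplified version
onnx.save(model_simp, onnx_file_path)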

Performing Inference with ONNX Runtime

Now that we have our ONNX model, it’s time to test it with ONNX Runtime.

Create an Inference Session

We interact with models in ONNX Runtime through an InferenceSession object. When creating a session, we can specify which Execution Providers to use for inference, along with other configuration options. Execution Providers are the interfaces to hardware-specific inference engines, such as TensorRT for NVIDIA GPUs and OpenVINO for Intel hardware. With the standard CPU-only onnxruntime package, the InferenceSession uses the generic CPUExecutionProvider by default; here, we request it explicitly.

# Load the model and create an InferenceSession
session = ort.InferenceSession(onnx_file_path, providers=['CPUExecutionProvider'])
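On machines with a GPU or other accelerator build of ONNX Runtime, `ort.get_available_providers()` lists the Execution Providers compiled into the installed package. A minimal sketch for picking providers in order of preference, always falling back to the CPU provider, might look like this; the helper name and the preference list are assumptions, not part of ONNX Runtime.

```python
def choose_providers(preferred, available):
    """Keep the available preferred providers in preference order, always ending with CPU."""
    chosen = [p for p in preferred if p in available and p != 'CPUExecutionProvider']
    chosen.append('CPUExecutionProvider')
    return chosen


# Example preference order (an assumption; adjust for your hardware)
preferred = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']

# In practice, `available` would come from ort.get_available_providers()
available = ['CUDAExecutionProvider', 'CPUExecutionProvider']

print(choose_providers(preferred, available))  # ['CUDAExecutionProvider', 'CPUExecutionProvider']
```

The resulting list can be passed as the `providers` argument of `ort.InferenceSession`, which treats it as a priority order when assigning graph nodes to providers.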

Define Annotation Function

Next, we need a function to annotate images with keypoints. PIL’s ImageDraw module can draw filled circles (ellipses with equal width and height) at each keypoint location.

def draw_keypoints_pil(image, keypoints, labels, colors, radius:int=5):
    """
    Annotates an image with keypoints, each marked by a circle and associated with specific labels and colors.

    This function draws circles on the provided image at given keypoint coordinates. Each keypoint is associated 
    with a label and a color. The radius of the circles can be adjusted.

    Parameters:
    image (PIL.Image): The input image on which annotations will be drawn.
    keypoints (list of tuples): A list of (x, y) tuples representing the coordinates of each keypoint.
    labels (list of str): A list of labels corresponding to each keypoint.
    colors (list of tuples): A list of RGB tuples for each keypoint, defining the color of the circle to be drawn.
    radius (int, optional): The radius of the circles to be drawn for each keypoint. Defaults to 5.

    Returns:
    annotated_image (PIL.Image): The image annotated with keypoints, each represented as a colored circle.
    """
    # Create a copy of the image
    annotated_image = image.copy()

    # Create an ImageDraw object for drawing on the image
    draw = ImageDraw.Draw(annotated_image)

    # Loop through the keypoints and their labels
    for i in range(len(labels)):
        # Get the keypoint coordinates
        x, y = keypoints[i]

        # Draw a filled circle centered on the keypoint
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=colors[i])

    return annotated_image

Select a Test Image

We can download an image from one of my Hugging Face repositories to verify the exported model performs as expected.

test_img_name = "pexels-2769554-man-doing-rock-and-roll-sign.jpg"
test_img_url = f"https://huggingface.co/datasets/cj-mills/pexel-hand-gesture-test-images/resolve/main/{test_img_name}"

download_file(test_img_url, './', False)

test_img = Image.open(test_img_name)

# Display the test image size
pd.Series({
    "Test Image Size:": test_img.size,
}).to_frame().style.hide(axis='columns')

Test Image Size: (640, 960)

Prepare the Test Image

# Set test image size
test_sz = 512

# Resize the test image so its shorter side matches test_sz
input_img = resize_img(test_img, target_sz=test_sz, divisor=1)

# Calculate the scale between the source image and the resized image
min_img_scale = min(test_img.size) / min(input_img.size)

# Display the image sizes and scale factor
pd.Series({
    "Source Image Size:": test_img.size,
    "Input Dims:": input_img.size,
    "Min Image Scale:": min_img_scale,
    "Input Image Size:": input_img.size
}).to_frame().style.hide(axis='columns')

Source Image Size: (640, 960)
Input Dims: (512, 768)
Min Image Scale: 1.250000
Input Image Size: (512, 768)
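The printed values follow from resizing so the shorter side becomes `test_sz` while keeping the aspect ratio, then measuring how much the image shrank. The sketch below reproduces that arithmetic in plain Python; `target_size` is a hypothetical helper written to match the output above, not the implementation of `resize_img` from cjm_pil_utils.

```python
def target_size(src_size, target_sz):
    """Scale (width, height) so the shorter side equals target_sz, preserving aspect ratio."""
    w, h = src_size
    scale = target_sz / min(w, h)
    return round(w * scale), round(h * scale)


src_size = (640, 960)               # (width, height) of the source image
input_size = target_size(src_size, 512)

# Factor that maps coordinates on the resized image back to the source image
min_img_scale = min(src_size) / min(input_size)

print(input_size)     # (512, 768)
print(min_img_scale)  # 1.25
```

Multiplying predicted coordinates by this factor, as the inference step does, places the keypoints back on the original 640x960 image.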

Prepare the Input Tensor

When we convert the PIL input image to a NumPy array, we need to reorder the array values to channels-first format, scale the values from [0,255] to [0,1], and add a batch dimension.

# Convert the input image to NumPy format
input_tensor_np = np.array(input_img, dtype=np.float32).transpose((2, 0, 1))[None]/255
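The one-liner packs three steps together. Spelled out on a synthetic array, the shape changes look like this. The random array is a stand-in for `np.array(input_img)`, assuming the PIL image is in RGB mode so the array is height x width x 3.

```python
import numpy as np

# Stand-in for np.array(input_img): a 768x512 RGB image stored as uint8 (H, W, C)
hwc = np.random.randint(0, 256, size=(768, 512, 3), dtype=np.uint8)

# 1. Convert to float32 and reorder to channels-first (C, H, W)
chw = hwc.astype(np.float32).transpose((2, 0, 1))

# 2. Scale pixel values from [0, 255] to [0, 1]
chw /= 255

# 3. Add a batch dimension -> (1, C, H, W)
nchw = chw[None]

print(nchw.shape, nchw.dtype)  # (1, 3, 768, 512) float32
```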

Compute the Predictions

Now, we can finally perform inference with our ONNX model.

# Run inference
model_output = session.run(None, {"input": input_tensor_np})

# Set the confidence threshold
conf_threshold = 0.8

# Filter the output based on the confidence threshold
scores_mask = model_output[2] > conf_threshold

# Extract and scale the predicted keypoints
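predicted_keypoints = (model_output[3][scores_mask])[:,:,:-1].reshape(-1,2)*min_img_scale

# Each detection yields one keypoint per entry in class_names, in the same order
labels = class_names * int(scores_mask.sum())

# Annotate the source image with the predicted keypoints
draw_keypoints_pil(test_img,
                   predicted_keypoints,
                   labels=labels,
                   colors=[int_colors[i] for i in [class_names.index(label) for label in labels]],
                  )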

The model appears to work as intended, even on this new image.
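The post-processing relies on the layout of the model outputs: `scores` holds one value per detection, and `keypoints` has shape (detections, keypoints per detection, 3), where the last column is a visibility value that the code discards. The numpy sketch below walks through the same filtering, flattening, and rescaling on made-up arrays (two detections with two keypoints each, and hypothetical keypoint names) and shows why repeating `class_names` once per kept detection lines the labels up with the flattened points.

```python
import numpy as np

class_names = ['nose', 'face']           # hypothetical keypoint names
scores = np.array([0.95, 0.40])          # one confidence score per detection
keypoints = np.array([                   # (detections, keypoints, [x, y, visibility])
    [[100.0, 200.0, 1.0], [110.0, 230.0, 1.0]],
    [[300.0, 50.0, 1.0], [310.0, 80.0, 1.0]],
])
min_img_scale = 1.25
conf_threshold = 0.8

# Keep only detections above the confidence threshold
scores_mask = scores > conf_threshold

# Drop the visibility column, flatten to (N * K, 2), and map back to source-image pixels
points = keypoints[scores_mask][:, :, :-1].reshape(-1, 2) * min_img_scale

# Flattened points come out detection by detection, each in class_names order
labels = class_names * int(scores_mask.sum())

for label, (x, y) in zip(labels, points):
    print(label, x, y)
```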

Google Colab Users
  1. Don’t forget to download the ONNX model from the Colab Environment’s file browser. (tutorial link)


Congratulations on reaching the end of this tutorial! We previously trained a Keypoint R-CNN model in PyTorch, and now we’ve exported that model to ONNX. With this, we can streamline our deployment process and leverage platform-specific hardware optimizations through ONNX Runtime.

As you move forward, consider exploring ONNX and its ecosystem further. In particular, check out the available Execution Providers, which offer flexible interfaces to different hardware acceleration libraries.

Next Steps
  • Feel free to post questions or problems related to this tutorial in the comments below. I try to make time to address them on Thursdays and Fridays.
  • If you would like to explore my services for your project, you can reach out via email.