Exporting Keypoint R-CNN Models from PyTorch to ONNX
- Introduction
- Getting Started with the Code
- Setting Up Your Python Environment
- Importing the Required Dependencies
- Setting Up the Project
- Loading the Checkpoint Data
- Exporting the Model to ONNX
- Performing Inference with ONNX Runtime
- Conclusion
Introduction
Welcome back to this series on training Keypoint R-CNN models with PyTorch. Previously, we demonstrated how to fine-tune a Keypoint R-CNN model by training it to identify the locations of human noses and faces. This tutorial builds on that by showing how to export the model to ONNX and perform inference using ONNX Runtime.
ONNX (Open Neural Network Exchange) is an open format to represent machine learning models and make them portable across various platforms. ONNX Runtime is a cross-platform inference accelerator that provides interfaces to hardware-specific libraries. By exporting our model to ONNX, we can deploy it to multiple devices and leverage hardware acceleration for faster inference. The Keypoint R-CNN model is computationally intensive, so any improvements to inference speed are welcome.
Additionally, we’ll implement the functionality to annotate images with keypoints without relying on PyTorch as a dependency. By the end of this tutorial, you will have an ONNX version of our Keypoint R-CNN model that you can deploy to servers and edge devices using ONNX Runtime.
Getting Started with the Code
As with the previous tutorial, the code is available as a Jupyter Notebook.
Jupyter Notebook | Google Colab |
---|---|
GitHub Repository | Open In Colab |
Setting Up Your Python Environment
We’ll need to add a few new libraries to our Python environment for working with ONNX models.
Run the following command to install these additional libraries:
# Install ONNX packages
pip install onnx onnxruntime onnx-simplifier
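To confirm the new packages installed correctly, a quick optional check (not part of the original notebook) can look up each distribution's version with the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical; it returns `None` for anything that isn't installed instead of raising an error.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Return {distribution name: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions

# Distribution names as passed to pip above
print(installed_versions(["onnx", "onnxruntime", "onnx-simplifier"]))
```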
Importing the Required Dependencies
With our environment updated, we can dive into the code. First, we will import the necessary Python dependencies into our Jupyter Notebook.
# Import Python Standard Library dependencies
import json
from pathlib import Path
import random
# Import utility functions
from cjm_psl_utils.core import download_file
from cjm_pil_utils.core import resize_img
# Import numpy
import numpy as np
# Import the pandas package
import pandas as pd
# Import PIL for image manipulation
from PIL import Image, ImageDraw, ImageFont
# Import PyTorch dependencies
import torch
# Import Keypoint R-CNN
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
# Import ONNX dependencies
import onnx # Import the onnx module
from onnxsim import simplify # Import the method to simplify ONNX models
import onnxruntime as ort # Import the ONNX Runtime
Setting Up the Project
In this section, we’ll set the folder locations for our project and for the training session that contains the PyTorch checkpoint. We’ll also make sure we have a font file for annotating images.
Set the Directory Paths
# The name for the project
project_name = f"pytorch-keypoint-r-cnn"
# The path for the project folder
project_dir = Path(f"./{project_name}/")
# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)
# The path to the checkpoint folder
checkpoint_dir = Path(project_dir/f"2024-01-30_10-44-52")
pd.Series({"Project Directory:": project_dir,
           "Checkpoint Directory:": checkpoint_dir,
          }).to_frame().style.hide(axis='columns')
Project Directory: | pytorch-keypoint-r-cnn |
---|---|
Checkpoint Directory: | pytorch-keypoint-r-cnn/2024-01-30_10-44-52 |
Download a Font File
# Set the name of the font file
font_file = 'KFOlCnqEu92Fr1MmEU9vAw.ttf'
# Download the font file
download_file(f"https://fonts.gstatic.com/s/roboto/v30/{font_file}", "./")
Loading the Checkpoint Data
Now, we can load the colormap used during training and initialize a Keypoint R-CNN model with the saved checkpoint.
Load the Colormap
# The colormap path
colormap_path = list(checkpoint_dir.glob('*colormap.json'))[0]
# Load the JSON colormap data
with open(colormap_path, 'r') as file:
    colormap_json = json.load(file)
# Convert the JSON data to a dictionary
colormap_dict = {item['label']: item['color'] for item in colormap_json['items']}
# Extract the class names from the colormap
class_names = list(colormap_dict.keys())
# Make a copy of the colormap in integer format
int_colors = [tuple(int(c*255) for c in color) for color in colormap_dict.values()]
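The colormap code above expects a JSON file with an `items` list, where each entry has a `label` and a `color` given as floats in [0, 1] (inferred from the `int(c*255)` conversion). The sketch below writes a hypothetical colormap with made-up labels and colors to a temporary folder and runs the same steps, so you can see what `class_names` and `int_colors` end up holding without the real checkpoint folder.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical colormap contents; the real file comes from the training session
sample_colormap = {
    "items": [
        {"label": "nose", "color": [1.0, 0.0, 0.0]},
        {"label": "face", "color": [0.0, 0.5, 1.0]},
    ]
}

with tempfile.TemporaryDirectory() as tmp:
    checkpoint_dir = Path(tmp)
    (checkpoint_dir / "example-colormap.json").write_text(json.dumps(sample_colormap))

    # Same steps as the notebook: find the colormap file and load it
    colormap_path = list(checkpoint_dir.glob('*colormap.json'))[0]
    with open(colormap_path, 'r') as file:
        colormap_json = json.load(file)

# Build the label -> color lookup, the class list, and 0-255 integer colors
colormap_dict = {item['label']: item['color'] for item in colormap_json['items']}
class_names = list(colormap_dict.keys())
int_colors = [tuple(int(c*255) for c in color) for color in colormap_dict.values()]

print(class_names)  # ['nose', 'face']
print(int_colors)   # [(255, 0, 0), (0, 127, 255)]
```

Note that `int()` truncates, so a channel value of 0.5 becomes 127 rather than 128.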
Load the Model Checkpoint
# The model checkpoint path
checkpoint_path = list(checkpoint_dir.glob('*.pth'))[0]
# Load the model checkpoint onto the CPU
model_checkpoint = torch.load(checkpoint_path, map_location='cpu')
Load the Trained Keypoint R-CNN Model
# Load a pre-trained model
model = keypointrcnn_resnet50_fpn(weights='DEFAULT')
# Replace the keypoint predictor head with one sized for our number of keypoints
in_features = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(in_channels=in_features, num_keypoints=len(class_names))
# Initialize the model with the checkpoint parameters and buffers
model.load_state_dict(model_checkpoint)
<All keys matched successfully>
Exporting the Model to ONNX
Before exporting the model, let’s ensure the model is in evaluation mode.
model.eval();
Prepare the Input Tensor
We need a sample input tensor for the export process.
input_tensor = torch.randn(1, 3, 256, 256)
Export the Model to ONNX
We can export the model using PyTorch’s torch.onnx.export() function. This function performs a single pass through the model and records all operations to generate a TorchScript graph. It then exports this graph to ONNX by decomposing each graph node (which contains a PyTorch operator) into a series of ONNX operators.
If we want the ONNX model to support different input sizes, we must set the width and height input axes as dynamic.
# Set a filename for the ONNX model
onnx_file_path = f"{checkpoint_dir}/{colormap_path.stem.removesuffix('-colormap')}-{checkpoint_path.stem}.onnx"
# Export the PyTorch model to ONNX format
torch.onnx.export(model.cpu(),
                  input_tensor.cpu(),
                  onnx_file_path,
                  export_params=True,
                  do_constant_folding=False,
                  input_names = ['input'],
                  output_names = ['boxes', 'labels', 'scores', 'keypoints', 'keypoints_scores'],
                  dynamic_axes={'input': {2 : 'height', 3 : 'width'}}
                 )
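The filename expression strips the extension from both the colormap and checkpoint files and drops the `-colormap` suffix, so the ONNX file is named after the colormap and checkpoint it came from. The file names below are hypothetical stand-ins used only to show what the expression produces; `str.removesuffix` requires Python 3.9 or newer.

```python
from pathlib import Path

# Hypothetical file names; substitute the ones in your checkpoint folder
checkpoint_dir = Path("pytorch-keypoint-r-cnn/2024-01-30_10-44-52")
colormap_path = checkpoint_dir / "my-dataset-colormap.json"
checkpoint_path = checkpoint_dir / "my-model.pth"

# .stem drops the file extension; removesuffix drops the trailing '-colormap'
onnx_file_path = f"{checkpoint_dir}/{colormap_path.stem.removesuffix('-colormap')}-{checkpoint_path.stem}.onnx"
print(onnx_file_path)  # ends with my-dataset-my-model.onnx
```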
The export function may emit some UserWarning messages when we export the model. We can ignore these warnings, as the exported model functions as expected.
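If the warnings clutter the notebook output, one option is to silence UserWarning messages only for the duration of the export call. The sketch below wraps the standard library's `warnings.catch_warnings` in a hypothetical context manager, `ignore_user_warnings`; the previous warning filters are restored when the block exits. It's still worth reading the warnings once before hiding them.

```python
import warnings
from contextlib import contextmanager

@contextmanager
def ignore_user_warnings():
    """Ignore UserWarning inside the block; catch_warnings restores the old filters on exit."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        yield

# Usage with the export call from above (remaining arguments as in the notebook):
# with ignore_user_warnings():
#     torch.onnx.export(model.cpu(), input_tensor.cpu(), onnx_file_path, ...)
```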
Simplify the ONNX Model
The ONNX models generated by PyTorch are not always the most concise. We can use the onnx-simplifier package to tidy up the exported model.
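# Load the ONNX model from onnx_file_path
onnx_model = onnx.load(onnx_file_path)
# Simplify the model; check reports whether the simplified model passed validation
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
# Save the simplified model back to onnx_file_path
onnx.save(model_simp, onnx_file_path)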
Performing Inference with ONNX Runtime
Now that we have our ONNX model, it’s time to test it with ONNX Runtime.
Create an Inference Session
We interact with models in ONNX Runtime through an InferenceSession object. Here we can specify which Execution Providers to use for inference, along with other configuration options. Execution Providers are the interfaces to hardware-specific inference engines, such as TensorRT for NVIDIA and OpenVINO for Intel. With the standard CPU-only onnxruntime package, the InferenceSession uses the generic CPUExecutionProvider by default; here we specify it explicitly.
# Load the model and create an InferenceSession
session = ort.InferenceSession(onnx_file_path, providers=['CPUExecutionProvider'])
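ONNX Runtime tries the providers passed to `InferenceSession` in order of precedence, so a common pattern is to request accelerated providers first and keep the CPU as a fallback. The helper below, `pick_providers`, is a hypothetical sketch: it keeps only the preferred providers that the installed build reports, then appends `CPUExecutionProvider`. The commented usage relies on `ort.get_available_providers()`, which lists the providers available in the current onnxruntime build.

```python
def pick_providers(preferred, available):
    """Keep preferred providers that are available, in priority order, plus a CPU fallback."""
    chosen = [p for p in preferred if p in available]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen

# Usage with ONNX Runtime (not executed here):
# import onnxruntime as ort
# providers = pick_providers(["TensorrtExecutionProvider", "CUDAExecutionProvider"],
#                            ort.get_available_providers())
# session = ort.InferenceSession(onnx_file_path, providers=providers)
```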
Define Annotation Function
Next, we need a function to annotate images with keypoints. PIL’s ImageDraw module can draw circles on images, which we’ll use to mark each keypoint.
def draw_keypoints_pil(image, keypoints, labels, colors, radius:int=5):
    """
    Annotates an image with keypoints, each marked by a circle and associated with specific labels and colors.
    This function draws circles on the provided image at given keypoint coordinates. Each keypoint is associated
    with a label and a color. The radius of the circles can be adjusted.
    Parameters:
    image (PIL.Image): The input image on which annotations will be drawn.
    keypoints (list of tuples): A list of (x, y) tuples representing the coordinates of each keypoint.
    labels (list of str): A list of labels corresponding to each keypoint.
    colors (list of tuples): A list of RGB tuples for each keypoint, defining the color of the circle to be drawn.
    radius (int, optional): The radius of the circles to be drawn for each keypoint. Defaults to 5.
    Returns:
    annotated_image (PIL.Image): The image annotated with keypoints, each represented as a colored circle.
    """
    # Create a copy of the image
    annotated_image = image.copy()
    # Create an ImageDraw object for drawing on the image
    draw = ImageDraw.Draw(annotated_image)
    # Loop through the keypoints and their labels
    for i in range(len(labels)):
        # Get the keypoint coordinates
        x, y = keypoints[i]
        # Draw a circle centered on the keypoint
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=colors[i])
    return annotated_image
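PIL's `ellipse()` takes a bounding box, so a circle of radius r centered on (x, y) uses the box (x - r, y - r, x + r, y + r). If PIL isn't available on a target device, the same filled-circle annotation can be sketched with NumPy alone by coloring every pixel within r of a keypoint. This is a substitute technique, not the tutorial's method, and the hypothetical `draw_keypoints_np` below is not pixel-identical to PIL's rasterization.

```python
import numpy as np

def draw_keypoints_np(image, keypoints, colors, radius=5):
    """Return a copy of an (H, W, 3) uint8 array with a filled circle at each (x, y) keypoint."""
    annotated = image.copy()
    height, width = annotated.shape[:2]
    # Open coordinate grids: ys has shape (H, 1), xs has shape (1, W)
    ys, xs = np.ogrid[:height, :width]
    for (x, y), color in zip(keypoints, colors):
        # Pixels within `radius` of the keypoint form a filled disk
        mask = (xs - x) ** 2 + (ys - y) ** 2 <= radius ** 2
        annotated[mask] = color
    return annotated
```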
Select a Test Image
We can download an image from one of my Hugging Face repositories to verify the exported model performs as expected.
test_img_name = "pexels-2769554-man-doing-rock-and-roll-sign.jpg"
test_img_url = f"https://huggingface.co/datasets/cj-mills/pexel-hand-gesture-test-images/resolve/main/{test_img_name}"
download_file(test_img_url, './', False)
test_img = Image.open(test_img_name)
display(test_img)
pd.Series({"Test Image Size:": test_img.size,
          }).to_frame().style.hide(axis='columns')
Test Image Size: | (640, 960) |
---|---|
Prepare the Test Image
# Set test image size
test_sz = 512
# Resize the test image
input_img = resize_img(test_img, target_sz=test_sz, divisor=1)
# Calculate the scale between the source image and the resized image
min_img_scale = min(test_img.size) / min(input_img.size)
display(input_img)
# Display the image sizes and scale as a Pandas Series for easy formatting
pd.Series({"Source Image Size:": test_img.size,
           "Input Dims:": input_img.size,
           "Min Image Scale:": min_img_scale,
           "Input Image Size:": input_img.size
          }).to_frame().style.hide(axis='columns')
Source Image Size: | (640, 960) |
---|---|
Input Dims: | (512, 768) |
Min Image Scale: | 1.250000 |
Input Image Size: | (512, 768) |
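The scale factor maps keypoints predicted on the resized image back to the source image. Assuming `resize_img` with `divisor=1` scales the shorter side to `target_sz` while keeping the aspect ratio (consistent with the sizes in the table above), the arithmetic can be reproduced in plain Python. `shortest_side_resize` is a hypothetical helper, not part of cjm_pil_utils.

```python
def shortest_side_resize(size, target_sz):
    """Scale a (width, height) size so the shorter side equals target_sz, keeping the aspect ratio.

    Assumption: approximates resize_img(..., divisor=1) from cjm_pil_utils.
    """
    width, height = size
    scale = target_sz / min(width, height)
    return round(width * scale), round(height * scale)

source_size = (640, 960)  # PIL reports (width, height)
input_size = shortest_side_resize(source_size, 512)
min_img_scale = min(source_size) / min(input_size)
print(input_size, min_img_scale)  # (512, 768) 1.25

# A keypoint predicted at (100, 200) on the resized image maps back to the source image as:
print(100 * min_img_scale, 200 * min_img_scale)  # 125.0 250.0
```

Comparing the shorter sides keeps the calculation independent of whether the image is in portrait or landscape orientation.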
Prepare the Input Tensor
When we convert the PIL input image to a NumPy array, we need to reorder the array dimensions to channels-first format, scale the values from [0, 255] to [0, 1], and add a batch dimension.
# Convert the input image to NumPy format
input_tensor_np = np.array(input_img, dtype=np.float32).transpose((2, 0, 1))[None]/255
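The same conversion can be checked on a synthetic array. PIL reports size as (width, height), while `np.array` produces (height, width, channels), so the 512x768 input image becomes a (768, 512, 3) array before the transpose. `to_input_tensor` is a hypothetical helper wrapping the one-liner above.

```python
import numpy as np

def to_input_tensor(img_array):
    """Convert an (H, W, 3) uint8 RGB array to a (1, 3, H, W) float32 array scaled to [0, 1]."""
    chw = np.asarray(img_array, dtype=np.float32).transpose((2, 0, 1))  # HWC -> CHW
    return chw[None] / 255  # add the batch dimension and scale [0, 255] -> [0, 1]

# Synthetic stand-in for np.array(input_img): height 768, width 512, 3 channels
dummy = np.full((768, 512, 3), 255, dtype=np.uint8)
tensor = to_input_tensor(dummy)
print(tensor.shape, tensor.dtype, tensor.max())  # (1, 3, 768, 512) float32 1.0
```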
Compute the Predictions
Now, we can finally perform inference with our ONNX model.
# Run inference
model_output = session.run(None, {"input": input_tensor_np})
# Set the confidence threshold
conf_threshold = 0.8
# Filter the output based on the confidence threshold
# (outputs follow output_names: boxes, labels, scores, keypoints, keypoints_scores)
scores_mask = model_output[2] > conf_threshold
# Extract and scale the predicted keypoints, dropping the visibility column
predicted_keypoints = (model_output[3][scores_mask])[:,:,:-1].reshape(-1,2)*min_img_scale
predicted_keypoints
# Repeat the class names once for each detection that passed the threshold
labels = class_names*sum(scores_mask).item()
draw_keypoints_pil(test_img,
                   predicted_keypoints,
                   labels=labels,
                   colors=[int_colors[i] for i in [class_names.index(label) for label in labels]],
                  )
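The post-processing above relies on the output order set by `output_names` during export: index 2 holds the per-detection scores with shape (N,), and index 3 holds the keypoints with shape (N, K, 3), where the last column is a visibility flag. The sketch below runs the same masking, slicing, and scaling on small synthetic arrays (made-up numbers, not real model output) so the shapes are easy to follow; `filter_keypoints` and the label names are hypothetical.

```python
import numpy as np

def filter_keypoints(scores, keypoints, conf_threshold, scale):
    """Keep detections above conf_threshold and map their keypoints to source-image pixels.

    scores: (N,) detection confidences; keypoints: (N, K, 3) as (x, y, visibility).
    Returns (points, num_kept), where points has shape (num_kept * K, 2).
    """
    scores_mask = scores > conf_threshold
    # Drop the visibility column, flatten the detections, and undo the resize
    points = keypoints[scores_mask][:, :, :-1].reshape(-1, 2) * scale
    return points, int(scores_mask.sum())

# Synthetic outputs: 2 detections with 2 keypoints each
scores = np.array([0.95, 0.40], dtype=np.float32)
keypoints = np.array([[[100, 200, 1], [110, 190, 1]],
                      [[ 10,  20, 1], [ 15,  25, 1]]], dtype=np.float32)
class_names = ["nose", "face"]  # hypothetical label order

points, num_kept = filter_keypoints(scores, keypoints, conf_threshold=0.8, scale=1.25)
labels = class_names * num_kept
print(points.tolist())  # [[125.0, 250.0], [137.5, 237.5]]
print(labels)           # ['nose', 'face']
```

Because the keypoint predictor emits one keypoint per entry in `class_names`, in the same order, repeating `class_names` once per kept detection lines each label up with its keypoint.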
The model appears to work as intended, even on this new image.
- Don’t forget to download the ONNX model from the Colab environment’s file browser.
Conclusion
Congratulations on reaching the end of this tutorial! We previously trained a Keypoint R-CNN model in PyTorch, and now we’ve exported that model to ONNX. With this, we can streamline our deployment process and leverage platform-specific hardware optimizations through ONNX Runtime.
As you move forward, consider exploring more about ONNX and its ecosystem. Check out the available Execution Providers that provide flexible interfaces to different hardware acceleration libraries.
- Feel free to post questions or problems related to this tutorial in the comments below. I try to make time to address them on Thursdays and Fridays.
- I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.
- I help clients leverage cutting-edge AI technologies to solve real-world problems.
- Learn more about me or reach out via email to discuss your project.