Exporting YOLOX Models from PyTorch to ONNX
- Introduction
- Getting Started with the Code
- Setting Up Your Python Environment
- Importing the Required Dependencies
- Setting Up the Project
- Loading the Checkpoint Data
- Exporting the Model to ONNX
- Performing Inference with ONNX Runtime
- Conclusion
Introduction
Welcome back to this series on training YOLOX models for real-time applications! Previously, we demonstrated how to fine-tune a YOLOX model in PyTorch by creating a hand gesture detector. This tutorial builds on that by showing how to export the model to ONNX and perform inference using ONNX Runtime.
ONNX (Open Neural Network Exchange) is an open format to represent machine learning models and make them portable across various platforms. ONNX Runtime is a cross-platform inference accelerator that provides interfaces to hardware-specific libraries. By exporting our model to ONNX, we can deploy it to multiple devices and leverage hardware acceleration for faster inference. When it comes to real-time applications, even minor speedups have a noticeable impact.
Additionally, we’ll implement the functionality to handle post-processing and draw bounding boxes without relying on PyTorch as a dependency. By the end of this tutorial, you will have an ONNX version of our YOLOX model that you can deploy to servers and edge devices using ONNX Runtime.
Getting Started with the Code
As with the previous tutorial, the code is available as a Jupyter Notebook.
Jupyter Notebook | Google Colab |
---|---|
GitHub Repository | Open In Colab |
Setting Up Your Python Environment
We’ll need to add a few new libraries to our Python environment for working with ONNX models.
Run the following command to install these additional libraries:
# Install ONNX packages
pip install onnx onnxruntime onnx-simplifier
Importing the Required Dependencies
With our environment updated, we can dive into the code. First, we will import the necessary Python dependencies into our Jupyter Notebook.
# Import Python Standard Library dependencies
import json
from pathlib import Path
import random
# Import utility functions
from cjm_psl_utils.core import download_file
from cjm_pil_utils.core import resize_img
# Import YOLOX package
from cjm_yolox_pytorch.model import build_model
from cjm_yolox_pytorch.inference import YOLOXInferenceWrapper
# Import numpy
import numpy as np
# Import the pandas package
import pandas as pd
# Import PIL for image manipulation
from PIL import Image, ImageDraw, ImageFont
# Import PyTorch dependencies
import torch
# Import ONNX dependencies
import onnx # Import the onnx module
from onnxsim import simplify # Import the method to simplify ONNX models
import onnxruntime as ort # Import the ONNX Runtime
Setting Up the Project
In this section, we’ll set the folder locations for our project and training session with the PyTorch checkpoint. Let’s also ensure we have a font file for annotating images.
Set the Directory Paths
# The name for the project
project_name = f"pytorch-yolox-object-detector"
# The path for the project folder
project_dir = Path(f"./{project_name}/")
# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)
# The path to the checkpoint folder
checkpoint_dir = Path(project_dir/f"2024-02-17_11-08-46")
pd.Series({"Project Directory:": project_dir,
           "Checkpoint Directory:": checkpoint_dir,
          }).to_frame().style.hide(axis='columns')
Project Directory: | pytorch-yolox-object-detector |
---|---|
Checkpoint Directory: | pytorch-yolox-object-detector/2024-02-17_11-08-46 |
Download a Font File
# Set the name of the font file
font_file = 'KFOlCnqEu92Fr1MmEU9vAw.ttf'
# Download the font file
download_file(f"https://fonts.gstatic.com/s/roboto/v30/{font_file}", "./")
Loading the Checkpoint Data
Now we can load the colormap and normalization stats used during training and initialize a YOLOX model with the saved checkpoint.
Load the Colormap
# The colormap path
colormap_path = list(checkpoint_dir.glob('*colormap.json'))[0]
# Load the JSON colormap data
with open(colormap_path, 'r') as file:
    colormap_json = json.load(file)
# Convert the JSON data to a dictionary
colormap_dict = {item['label']: item['color'] for item in colormap_json['items']}
# Extract the class names from the colormap
class_names = list(colormap_dict.keys())
# Make a copy of the colormap in integer format
int_colors = [tuple(int(c*255) for c in color) for color in colormap_dict.values()]
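To make the conversion concrete, here is a minimal sketch that runs the same steps on a hypothetical colormap. The structure (an "items" list whose entries hold a "label" and an RGB "color" with components in [0,1]) is inferred from the loading code above; the labels and colors are made up for illustration.

```python
import json

# Hypothetical colormap content, mirroring the structure the loading code expects
colormap_text = """
{"items": [
    {"label": "call", "color": [0.0, 0.5, 1.0]},
    {"label": "rock", "color": [1.0, 0.25, 0.0]}
]}
"""
colormap_json = json.loads(colormap_text)

# Map each label to its float RGB color
colormap_dict = {item['label']: item['color'] for item in colormap_json['items']}

# Class names keep the order of the items in the file
class_names = list(colormap_dict.keys())

# Scale float components in [0,1] to integer components in [0,255]
int_colors = [tuple(int(c*255) for c in color) for color in colormap_dict.values()]

print(class_names)  # ['call', 'rock']
print(int_colors)   # [(0, 127, 255), (255, 63, 0)]
```

Note that int() truncates rather than rounds, so a component of 0.5 becomes 127, not 128. The class order matters because the model's predicted class indices are later looked up in class_names.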
Load the Normalization Statistics
# The normalization stats path
norm_stats_path = checkpoint_dir/'norm_stats.json'
# Read the normalization stats from the JSON file
with open(norm_stats_path, "r") as f:
    norm_stats_dict = json.load(f)
# Convert the dictionary to a tuple
norm_stats = (norm_stats_dict["mean"], norm_stats_dict["std_dev"])
# Print the mean and standard deviation
pd.DataFrame(norm_stats)
 | 0 | 1 | 2 |
---|---|---|---|
0 | 0.5 | 0.5 | 0.5 |
1 | 1.0 | 1.0 | 1.0 |
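Row 0 holds the per-channel mean and row 1 the per-channel standard deviation. As a sketch of how such stats are typically applied, assuming the wrapper uses the usual per-channel (x - mean) / std, here is the same computation in NumPy with the stats reshaped to (1, 3, 1, 1), just like the .view(1, 3, 1, 1) call used when wrapping the model:

```python
import numpy as np

# Normalization stats matching the table above (row 0: mean, row 1: standard deviation)
norm_stats = ([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

# Reshape to (1, 3, 1, 1) so the stats broadcast over a channels-first (N, C, H, W) batch
mean = np.array(norm_stats[0], dtype=np.float32).reshape(1, 3, 1, 1)
std = np.array(norm_stats[1], dtype=np.float32).reshape(1, 3, 1, 1)

# A dummy batch whose pixel values are already scaled to [0,1]
batch = np.full((1, 3, 4, 4), 0.75, dtype=np.float32)

# Standard per-channel normalization
normalized = (batch - mean) / std
```

With these particular stats, normalization shifts values from [0,1] to [-0.5,0.5] and leaves their scale unchanged.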
Load the Model Checkpoint
# The model checkpoint path
checkpoint_path = list(checkpoint_dir.glob('*.pth'))[0]
# Load the model checkpoint onto the CPU
model_checkpoint = torch.load(checkpoint_path, map_location='cpu')
Load the Trained YOLOX Model
# Select the YOLOX model configuration
model_type = checkpoint_path.stem
# Create a YOLOX model with the number of output classes equal to the number of class names
model = build_model(model_type, len(class_names))
# Get stride values for processing output
strides = model.bbox_head.strides
# Initialize the model with the checkpoint parameters and buffers
model.load_state_dict(model_checkpoint)
<All keys matched successfully>
Exporting the Model to ONNX
Before exporting the model, we’ll wrap it with the preprocessing and post-processing steps as we did previously. These steps will be included in the ONNX model, reducing the code we need to write when deploying the model to other platforms.
Prepare the Model for Inference
The YOLOXInferenceWrapper class has some optional settings we did not explore in the previous tutorial. The scale_inp setting will scale pixel data from the range [0,255] to [0,1], and the channels_last setting sets the model to expect input tensors in channels-last format. These settings can be helpful when deploying to platforms where tensor operations are less convenient.
Additionally, we can turn off the post-processing steps if we plan to deploy the model using tools that do not support those operations, like the Barracuda inference library for the Unity game engine.
The post-processing steps require the width and height of the input tensor. The indices for accessing those values depend on the format of the input tensor, so we'll store the slice object for accessing them later.
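To see why the indices depend on the layout, here is a small sketch with plain tuples standing in for tensor shapes (the 576x384 size is only an example). The slices match the slice(2, 4, None) and slice(1, 3, None) values the wrapper reports for the channels-first and channels-last configurations.

```python
# Input tensor shapes for a batch of one 576x384 RGB image in each layout
nchw_shape = (1, 3, 576, 384)  # channels-first: (N, C, H, W)
nhwc_shape = (1, 576, 384, 3)  # channels-last:  (N, H, W, C)

# Slices selecting (height, width) from each layout
channels_first_slice = slice(2, 4)
channels_last_slice = slice(1, 3)

# Both slices recover the same (height, width) pair
height, width = nchw_shape[channels_first_slice]
print(height, width)                    # 576 384
print(nhwc_shape[channels_last_slice])  # (576, 384)
```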
# Convert the normalization stats to tensors
mean_tensor = torch.tensor(norm_stats[0]).view(1, 3, 1, 1)
std_tensor = torch.tensor(norm_stats[1]).view(1, 3, 1, 1)
# Set the model to evaluation mode
model.eval();
# Wrap the model with preprocessing and post-processing steps
wrapped_model = YOLOXInferenceWrapper(model,
                                      mean_tensor,
                                      std_tensor,
                                      scale_inp=False, # Whether to scale input values from the range [0,255] to [0,1]
                                      channels_last=False, # Whether the model expects input in channels-last format
                                      run_box_and_prob_calculation=True # Enable or disable post-processing steps
                                     )
# Get the slice object for extracting the input dimensions
input_dim_slice = wrapped_model.input_dim_slice
input_dim_slice
slice(2, 4, None)
Alternatively, the following configuration scales the input inside the model, expects channels-last input, and disables the post-processing steps:
# Convert the normalization stats to tensors
mean_tensor = torch.tensor(norm_stats[0]).view(1, 3, 1, 1)
std_tensor = torch.tensor(norm_stats[1]).view(1, 3, 1, 1)
# Set the model to evaluation mode
model.eval();
# Wrap the model with preprocessing and post-processing steps
wrapped_model = YOLOXInferenceWrapper(model,
                                      mean_tensor,
                                      std_tensor,
                                      scale_inp=True, # Whether to scale input values from the range [0,255] to [0,1]
                                      channels_last=True, # Whether the model expects input in channels-last format
                                      run_box_and_prob_calculation=False # Enable or disable post-processing steps
                                     )
# Get the slice object for extracting the input dimensions
input_dim_slice = wrapped_model.input_dim_slice
input_dim_slice
slice(1, 3, None)
To disable all three options instead, use these settings:
scale_inp=False
channels_last=False
run_box_and_prob_calculation=False
Prepare the Input Tensor
We need a sample input tensor for the export process.
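# Sample input in channels-first format (channels_last=False)
input_tensor = torch.randn(1, 3, 256, 256)
# Sample input in channels-last format (channels_last=True)
input_tensor = torch.randn(1, 256, 256, 3)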
Export the Model to ONNX
We can export the model using PyTorch's torch.onnx.export() function. This function performs a single pass through the model and records all operations to generate a TorchScript graph. It then exports this graph to ONNX by decomposing each graph node (which contains a PyTorch operator) into a series of ONNX operators.
If we want the ONNX model to support different input sizes, we must set the width and height input axes as dynamic. These axes again depend on the input format, so we’ll use the slice object we saved earlier.
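The dynamic_axes argument maps an input name to a dictionary of {axis index: symbolic name}. This sketch shows what the inline expression in the export call produces for each layout; make_dynamic_axes is a hypothetical helper that only reproduces that expression.

```python
def make_dynamic_axes(input_dim_slice: slice) -> dict:
    """Mark the height and width axes of the 'input' tensor as dynamic (hypothetical helper)."""
    return {'input': {input_dim_slice.start: 'height',
                      input_dim_slice.stop - 1: 'width'}}

print(make_dynamic_axes(slice(2, 4)))  # {'input': {2: 'height', 3: 'width'}}
print(make_dynamic_axes(slice(1, 3)))  # {'input': {1: 'height', 2: 'width'}}
```

Using stop - 1 works because each slice spans exactly two axes, height followed by width.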
# Set a filename for the ONNX model
onnx_file_path = f"{checkpoint_dir}/{colormap_path.stem.removesuffix('-colormap')}-{model_type}.onnx"
# Export the PyTorch model to ONNX format
torch.onnx.export(wrapped_model.cpu(),
                  input_tensor.cpu(),
                  onnx_file_path,
                  export_params=True,
                  do_constant_folding=False,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {input_dim_slice.start : 'height', input_dim_slice.stop-1 : 'width'}}
                 )
============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
The export function will emit a TracerWarning when we export the model with the post-processing steps enabled. The post-processing steps involve iterating over the list of stride values for the YOLOX model, and the exported ONNX model will not support dynamic sizes for that list. We can ignore this warning as the stride values will not change during inference.
Simplify the ONNX model
The ONNX models generated by PyTorch are not always the most concise. We can use the onnx-simplifier package to tidy up the exported model.
This step is usually optional but is necessary for the ONNX model to work with the Barracuda inference library.
# Load the ONNX model from onnx_file_path
onnx_model = onnx.load(onnx_file_path)
# Simplify the model
model_simp, check = simplify(onnx_model)
# Make sure the simplified model passed validation
assert check, "Simplified ONNX model could not be validated"
# Save the simplified model to onnx_file_path
onnx.save(model_simp, onnx_file_path)
Performing Inference with ONNX Runtime
Now that we have our ONNX model, it’s time to test it with ONNX Runtime.
Create an Inference Session
We interact with models in ONNX Runtime through an InferenceSession object. Here we can specify which Execution Providers to use for inference and other configuration information. Execution Providers are the interfaces for hardware-specific inference engines like TensorRT for NVIDIA and OpenVINO for Intel. By default, the InferenceSession uses the generic CPUExecutionProvider.
# Load the model and create an InferenceSession
session = ort.InferenceSession(onnx_file_path, providers=['CPUExecutionProvider'])
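ONNX Runtime tries the providers in the order they are listed. If you want to prefer accelerated providers when the installed build offers them, one option is to filter a preference list against ort.get_available_providers(). A minimal sketch, assuming that preference order suits your hardware; select_providers is a hypothetical helper, not part of onnxruntime:

```python
def select_providers(available, preferred=("TensorrtExecutionProvider",
                                           "CUDAExecutionProvider",
                                           "OpenVINOExecutionProvider",
                                           "CPUExecutionProvider")):
    """Return the preferred providers that are available, in preference order,
    always ending with the CPU provider as a fallback (hypothetical helper)."""
    chosen = [p for p in preferred if p in available]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen

# Usage with ONNX Runtime (not executed here):
# session = ort.InferenceSession(onnx_file_path,
#                                providers=select_providers(ort.get_available_providers()))
print(select_providers(["CUDAExecutionProvider", "CPUExecutionProvider"]))
```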
Define Utility Functions
In the previous tutorial, we used PyTorch to process the model output, filter the predictions, and annotate images with bounding boxes. Now we will implement that functionality using NumPy and PIL.
Define a function to generate the output grids
The YOLOX model uses information from different parts of its backbone model to make predictions. In our case, it uses features from three sections: one early in the backbone, one in the middle, and one at the end. This approach helps the YOLOX model detect objects of different sizes in the image.
We use the stride values to scale predictions from these sections back to the input resolution. Here, we can see the difference in results when using a single stride value in isolation with a YOLOX model trained on the COCO dataset.
The following function generates grids of values using the input dimensions and stride values to scale bounding box predictions to the input resolution.
def generate_output_grids_np(height, width, strides=[8,16,32]):
    """
    Generate a numpy array containing grid coordinates and strides for a given height and width.
    Args:
        height (int): The height of the image.
        width (int): The width of the image.
        strides (list of int): The stride values for each output level.
    Returns:
        np.ndarray: A numpy array containing grid coordinates and strides.
    """
    all_coordinates = []
    for stride in strides:
        # Calculate the grid height and width
        grid_height = height // stride
        grid_width = width // stride
        # Generate grid coordinates
        g1, g0 = np.meshgrid(np.arange(grid_height), np.arange(grid_width), indexing='ij')
        # Create an array of strides
        s = np.full((grid_height, grid_width), stride)
        # Stack the coordinates along with the stride
        coordinates = np.stack((g0.flatten(), g1.flatten(), s.flatten()), axis=-1)
        # Append to the list
        all_coordinates.append(coordinates)
    # Concatenate all arrays in the list along the first dimension
    output_grids = np.concatenate(all_coordinates, axis=0)
    return output_grids
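As a sanity check on the grid sizes, each stride contributes (height // stride) * (width // stride) rows, so the number of proposals grows with the input resolution. Here is that arithmetic for the 576x384 input used later in this tutorial (PIL reports its size as (384, 576) because it lists width first); count_grid_cells is a hypothetical helper.

```python
def count_grid_cells(height, width, strides=(8, 16, 32)):
    """Number of grid rows (one proposal each) produced across all strides."""
    return sum((height // s) * (width // s) for s in strides)

# Stride 8: 72 * 48 = 3456, stride 16: 36 * 24 = 864, stride 32: 18 * 12 = 216
print(count_grid_cells(576, 384))  # 4536
```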
Define a function to calculate bounding boxes and probabilities
Next, we’ll scale the bounding box predictions and extract the predicted class and corresponding confidence score.
def calculate_boxes_and_probs(model_output:np.ndarray, output_grids:np.ndarray) -> np.ndarray:
    """
    Calculate the bounding boxes and their probabilities.
    Parameters:
        model_output (numpy.ndarray): The output of the model.
        output_grids (numpy.ndarray): The output grids.
    Returns:
        numpy.ndarray: The array containing the bounding box coordinates, class labels, and maximum probabilities.
    """
    # Calculate the bounding box coordinates
    box_centroids = (model_output[..., :2] + output_grids[..., :2]) * output_grids[..., 2:]
    box_sizes = np.exp(model_output[..., 2:4]) * output_grids[..., 2:]
    x0, y0 = [t.squeeze(axis=2) for t in np.split(box_centroids - box_sizes / 2, 2, axis=2)]
    w, h = [t.squeeze(axis=2) for t in np.split(box_sizes, 2, axis=2)]
    # Calculate the probabilities for each class
    box_objectness = model_output[..., 4]
    box_cls_scores = model_output[..., 5:]
    box_probs = np.expand_dims(box_objectness, -1) * box_cls_scores
    # Get the maximum probability and corresponding class for each proposal
    max_probs = np.max(box_probs, axis=-1)
    labels = np.argmax(box_probs, axis=-1)
    return np.array([x0, y0, w, h, labels, max_probs]).transpose((1, 2, 0))
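To trace the decoding for a single proposal, here is a scalar sketch with made-up raw outputs and a made-up output_grids row (grid column, grid row, stride). It follows the same formulas as the vectorized function: centroid = (offset + cell index) * stride, size = exp(prediction) * stride, and probability = objectness * class score.

```python
import math

# Hypothetical raw outputs for one proposal: offsets, log-scale sizes, objectness, class scores
dx, dy, log_w, log_h = 0.5, 0.25, math.log(4.0), math.log(2.0)
objectness, cls_scores = 0.9, [0.2, 0.8]

# Hypothetical matching output_grids row: grid column, grid row, stride
grid_x, grid_y, stride = 3, 5, 16

# Centroid: (offset + cell index) * stride
cx = (dx + grid_x) * stride  # (0.5 + 3) * 16 = 56
cy = (dy + grid_y) * stride  # (0.25 + 5) * 16 = 84

# Size: exp(prediction) * stride
w = math.exp(log_w) * stride  # about 4 * 16 = 64
h = math.exp(log_h) * stride  # about 2 * 16 = 32

# Top-left corner from the centroid
x0, y0 = cx - w / 2, cy - h / 2  # about (24, 68)

# Class probabilities: objectness times each class score, then keep the best class
probs = [objectness * s for s in cls_scores]  # about [0.18, 0.72]
label = max(range(len(probs)), key=lambda i: probs[i])
max_prob = probs[label]
```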
Define a function to calculate the intersection-over-union
Previously, we used the nms function included with torchvision to filter bounding box proposals using Non-Maximum Suppression. This approach filters bounding box proposals when they overlap too much with another bounding box with a higher confidence score.
We determine how much a pair of bounding boxes overlap by computing the Intersection over Union (IoU). The following function shows how to do this in NumPy.
def calc_iou(proposals:np.ndarray) -> np.ndarray:
    """
    Calculates the Intersection over Union (IoU) for all pairs of bounding boxes (x,y,w,h) in 'proposals'.
    The IoU is a measure of overlap between two bounding boxes. It is calculated as the area of
    intersection divided by the area of union of the two boxes.
    Parameters:
        proposals (2D np.array): A NumPy array of bounding boxes, where each box is an array [x, y, width, height].
    Returns:
        iou (2D np.array): The IoU matrix where each element i,j represents the IoU of boxes i and j.
    """
    # Calculate coordinates for the intersection rectangles
    x1 = np.maximum(proposals[:, 0], proposals[:, 0][:, None])
    y1 = np.maximum(proposals[:, 1], proposals[:, 1][:, None])
    x2 = np.minimum(proposals[:, 0] + proposals[:, 2], (proposals[:, 0] + proposals[:, 2])[:, None])
    y2 = np.minimum(proposals[:, 1] + proposals[:, 3], (proposals[:, 1] + proposals[:, 3])[:, None])
    # Calculate intersection areas
    intersections = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
    # Calculate union areas
    areas = proposals[:, 2] * proposals[:, 3]
    unions = areas[:, None] + areas - intersections
    # Calculate IoUs
    iou = intersections / unions
    # Return the iou matrix
    return iou
Define a function to filter bounding box proposals using Non-Maximum Suppression
Now we create a function to determine which proposal indices to keep using the calculated IoU values.
def nms_sorted_boxes(iou:np.ndarray, iou_thresh:float=0.45) -> np.ndarray:
    """
    Applies non-maximum suppression (NMS) to sorted bounding boxes.
    It suppresses boxes that have high overlap (as defined by the IoU threshold) with a box that
    has a higher score.
    Parameters:
        iou (np.ndarray): An IoU matrix where each element i,j represents the IoU of boxes i and j.
        iou_thresh (float): The IoU threshold for suppression. Boxes with IoU > iou_thresh are suppressed.
    Returns:
        keep (np.ndarray): The indices of the boxes to keep after applying NMS.
    """
    # Create a boolean mask to keep track of boxes
    mask = np.ones(iou.shape[0], dtype=bool)
    # Apply non-max suppression
    for i in range(iou.shape[0]):
        if mask[i]:
            # Suppress boxes with higher index and IoU > threshold
            mask[(iou[i] > iou_thresh) & (np.arange(iou.shape[0]) > i)] = False
    # Return the indices of the boxes to keep
    return np.arange(iou.shape[0])[mask]
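An equivalent way to phrase the same greedy procedure is to keep a box only if it does not overlap any box kept before it. This sketch uses that formulation (greedy_nms is a hypothetical name) on a made-up IoU matrix for three score-sorted boxes. Box 2 survives even though it overlaps box 1, because box 1 was already suppressed by box 0.

```python
import numpy as np

def greedy_nms(iou: np.ndarray, iou_thresh: float = 0.45) -> np.ndarray:
    """Greedy NMS over boxes already sorted by descending score.

    A box is kept only if its IoU with every previously kept box is <= iou_thresh."""
    keep = []
    for i in range(iou.shape[0]):
        if all(iou[i, k] <= iou_thresh for k in keep):
            keep.append(i)
    return np.array(keep, dtype=int)

# Hypothetical IoU matrix: box 1 overlaps box 0 heavily, box 2 overlaps only box 1
iou = np.array([[1.0, 0.7, 0.1],
                [0.7, 1.0, 0.6],
                [0.1, 0.6, 1.0]])
print(greedy_nms(iou))  # [0 2]
```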
Define a function to annotate an image with bounding boxes
Now that we have implemented the functionality to process and filter the model output, we only need to annotate images with bounding boxes and labels. PIL includes functionality to draw boxes and write text on images. The following function also scales the font size based on the image resolution to keep the relative size consistent across images.
def draw_bboxes_pil(image, boxes, labels, colors, font, width:int=2, font_size:int=18, probs=None):
    """
    Annotates an image with bounding boxes, labels, and optional probability scores.
    This function draws bounding boxes on the provided image using the given box coordinates,
    colors, and labels. If probabilities are provided, they will be added to the labels.
    Parameters:
        image (PIL.Image): The input image on which annotations will be drawn.
        boxes (list of tuples): A list of bounding box coordinates where each tuple is (x, y, w, h).
        labels (list of str): A list of labels corresponding to each bounding box.
        colors (list of tuples): A list of RGB colors for each bounding box and its corresponding label.
        font (str): Path to the font file to be used for displaying the labels.
        width (int, optional): Width of the bounding box lines. Defaults to 2.
        font_size (int, optional): Size of the font for the labels. Defaults to 18.
        probs (list of float, optional): A list of probability scores corresponding to each label. Defaults to None.
    Returns:
        annotated_image (PIL.Image): The image annotated with bounding boxes, labels, and optional probability scores.
    """
    # Define a reference diagonal
    REFERENCE_DIAGONAL = 1000
    # Scale the font size using the hypotenuse of the image
    font_size = int(font_size * (np.hypot(*image.size) / REFERENCE_DIAGONAL))
    # Add probability scores to labels
    if probs is not None:
        labels = [f"{label}: {prob*100:.2f}%" for label, prob in zip(labels, probs)]
    # Create a copy of the image
    annotated_image = image.copy()
    # Create an ImageDraw object for drawing on the image
    draw = ImageDraw.Draw(annotated_image)
    # Load the font file once for all labels
    fnt = ImageFont.truetype(font, font_size)
    # Loop through the bounding boxes and labels
    for i in range(len(labels)):
        # Get the bounding box coordinates
        x, y, w, h = boxes[i]
        # Create a tuple of coordinates for the bounding box
        shape = (x, y, x+w, y+h)
        # Draw the bounding box on the image
        draw.rectangle(shape, outline=colors[i], width=width)
        # Draw the label box on the image
        label_w, label_h = draw.textbbox(xy=(0,0), text=labels[i], font=fnt)[2:]
        draw.rectangle((x, y-label_h, x+label_w, y), outline=colors[i], fill=colors[i], width=width)
        # Draw the label on the image, using black text on light boxes and white text on dark ones
        draw.multiline_text((x, y-label_h), labels[i], font=fnt, fill='black' if np.mean(colors[i]) > 127.5 else 'white')
    return annotated_image
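The two size and color heuristics in this function are easy to check by hand. This sketch mirrors them in plain Python (scaled_font_size and text_color are hypothetical helper names). For the 640x960 test image used below, the diagonal is about 1153.8 pixels, so the default size of 18 scales to 20.

```python
import math

REFERENCE_DIAGONAL = 1000

def scaled_font_size(width, height, font_size=18):
    """Scale a base font size by the image diagonal relative to a 1000-pixel reference."""
    return int(font_size * (math.hypot(width, height) / REFERENCE_DIAGONAL))

def text_color(rgb):
    """Black text on light label boxes, white text on dark ones (mean channel value vs 127.5)."""
    return 'black' if sum(rgb) / len(rgb) > 127.5 else 'white'

print(scaled_font_size(640, 960))  # 20 (diagonal of about 1153.8 pixels)
print(scaled_font_size(600, 800))  # 18 (diagonal of exactly 1000 pixels)
print(text_color((255, 255, 0)), text_color((0, 0, 128)))  # black white
```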
With our utility functions taken care of, we can select an image to test our ONNX model.
Select a Test Image
Let’s use the same test image and input size from the previous tutorial to compare the results with the PyTorch model.
test_img_name = "pexels-2769554-man-doing-rock-and-roll-sign.jpg"
test_img_url = f"https://huggingface.co/datasets/cj-mills/pexel-hand-gesture-test-images/resolve/main/{test_img_name}"
download_file(test_img_url, './', False)
test_img = Image.open(test_img_name)
display(test_img)
pd.Series({"Test Image Size:": test_img.size,
          }).to_frame().style.hide(axis='columns')
Test Image Size: | (640, 960) |
---|---|
Prepare the Test Image
# Set test image size
test_sz = 384
# Resize the image without cropping (divisor=1 applies no rounding to a multiple)
resized_img = resize_img(test_img, target_sz=test_sz, divisor=1)
# Calculate the input dimensions as multiples of the max stride
input_dims = [dim - dim % max(strides) for dim in resized_img.size]
# Calculate the offsets from the resized image dimensions to the input dimensions
offsets = (np.array(resized_img.size) - input_dims)/2
# Calculate the scale between the source image and the resized image
min_img_scale = min(test_img.size) / min(resized_img.size)
# Crop the resized image to the input dimensions
input_img = resized_img.crop(box=[*offsets, *resized_img.size-offsets])
display(input_img)
pd.Series({"Resized Image Size:": resized_img.size,
           "Input Dims:": input_dims,
           "Offsets:": offsets,
           "Min Image Scale:": min_img_scale,
           "Input Image Size:": input_img.size
          }).to_frame().style.hide(axis='columns')
Resized Image Size: | (384, 576) |
---|---|
Input Dims: | [384, 576] |
Offsets: | [0. 0.] |
Min Image Scale: | 1.666667 |
Input Image Size: | (384, 576) |
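For the test image, both resized dimensions are already multiples of 32, so the offsets are zero. To show what happens otherwise, here is the same arithmetic for a hypothetical resized size of (384, 570), written as (width, height) like PIL's size attribute:

```python
import numpy as np

max_stride = 32

# Hypothetical resized size (width, height) whose height is not a multiple of the max stride
resized_size = (384, 570)

# Round each dimension down to a multiple of the max stride: 570 - 570 % 32 = 544
input_dims = [dim - dim % max_stride for dim in resized_size]

# Split the leftover pixels evenly between both sides for a centered crop
offsets = (np.array(resized_size) - input_dims) / 2

# Crop box in PIL's (left, top, right, bottom) convention
crop_box = [*offsets, *(np.array(resized_size) - offsets)]
```

Here the crop trims 13 pixels from the top and bottom, and the same offsets are added back to the predicted boxes before scaling them to the source image.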
Prepare the Input Tensor
When we convert the PIL input image to a NumPy array, we need to reorder the array values to channels-first format, scale the values from [0,255] to [0,1], and add a batch dimension. When we enable the scale_inp and channels_last options, we only need to add a batch dimension.
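# Convert the input image to a channels-first NumPy array scaled to [0,1] (scale_inp=False, channels_last=False)
input_tensor_np = np.array(input_img, dtype=np.float32).transpose((2, 0, 1))[None]/255
# Convert the input image to a channels-last NumPy array with raw [0,255] values (scale_inp=True, channels_last=True)
input_tensor_np = np.array(input_img, dtype=np.float32)[None]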
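The exported graph was traced with a float32 tensor, and ONNX Runtime generally rejects inputs whose dtype does not match the graph, so the dtype=np.float32 argument matters as much as the layout. This sketch uses a random uint8 array in place of the PIL image (np.array on an RGB PIL image gives shape (height, width, 3)) to show the resulting shapes for both configurations:

```python
import numpy as np

# Stand-in for np.array(input_img): a random 576x384 RGB image, shape (height, width, channels)
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(576, 384, 3), dtype=np.uint8)

# Channels-first layout scaled to [0,1] (scale_inp=False, channels_last=False)
nchw = img.astype(np.float32).transpose((2, 0, 1))[None] / 255

# Channels-last layout with raw [0,255] values (scale_inp=True, channels_last=True)
nhwc = img.astype(np.float32)[None]

print(nchw.shape, nchw.dtype)  # (1, 3, 576, 384) float32
print(nhwc.shape, nhwc.dtype)  # (1, 576, 384, 3) float32
```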
Compute the Predictions
Now we can finally perform inference with our ONNX model.
# Run inference
outputs = session.run(None, {"input": input_tensor_np})[0]
# Process the model output
if not wrapped_model.run_box_and_prob_calculation:
    outputs = calculate_boxes_and_probs(outputs, generate_output_grids_np(*input_tensor_np.shape[input_dim_slice], strides))
bbox_conf_thresh = 0.45
iou_thresh = 0.45
# Filter the proposals based on the confidence threshold
max_probs = outputs[:, :, -1]
mask = max_probs > bbox_conf_thresh
proposals = outputs[mask]
# Sort the proposals by probability in descending order
proposals = proposals[proposals[..., -1].argsort()][::-1]
# Apply non-max suppression to the proposals with the specified threshold
proposal_indices = nms_sorted_boxes(calc_iou(proposals[:, :-2]), iou_thresh)
proposals = proposals[proposal_indices]
bbox_list = (proposals[:,:4]+[*offsets, 0, 0])*min_img_scale
label_list = [class_names[int(idx)] for idx in proposals[:,4]]
probs_list = proposals[:,5]
draw_bboxes_pil(
    image=test_img,
    boxes=bbox_list,
    labels=label_list,
    probs=probs_list,
    colors=[int_colors[class_names.index(i)] for i in label_list],
    font=font_file
)
Predicted BBoxes: | [‘rock:[342.625 242.367 111.735 110.166]’, ‘no_gesture:[192.449 518.634 104.243 80.717]’] |
---|---|
Confidence Scores: | [‘rock: 91.29%’, ‘no_gesture: 86.78%’] |
The model predictions should be virtually identical to the PyTorch model, but the probability scores can sometimes vary slightly.
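The filtering and mapping steps can be traced on a tiny synthetic array. This sketch uses made-up proposals, offsets, and scale (not results from the model) to show the confidence threshold, the descending sort, and the mapping from input-image coordinates back to the source image:

```python
import numpy as np

# Hypothetical decoded output: (batch, proposals, [x0, y0, w, h, label, prob])
outputs = np.array([[[10.0, 20.0, 30.0, 40.0, 0.0, 0.30],
                     [12.0, 22.0, 30.0, 40.0, 0.0, 0.90],
                     [50.0, 60.0, 20.0, 20.0, 1.0, 0.60]]])
bbox_conf_thresh = 0.45

# Boolean indexing with a (batch, proposals) mask yields a (kept, 6) array
proposals = outputs[outputs[:, :, -1] > bbox_conf_thresh]

# Sort by probability, highest first
proposals = proposals[proposals[:, -1].argsort()[::-1]]

# Hypothetical crop offsets (x, y) and source-to-resized scale
offsets = np.array([0.0, 13.0])
min_img_scale = 2.0

# Shift x and y by the crop offsets, then scale all four values to the source resolution
bboxes = (proposals[:, :4] + [*offsets, 0, 0]) * min_img_scale
```

Only the x and y columns receive the offsets, since cropping moves the origin but does not change box sizes; all four columns are then scaled.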
- Don't forget to download the ONNX model from the Colab environment's file browser.
Conclusion
Congratulations on reaching the end of this tutorial! We previously trained a YOLOX model in PyTorch for hand gesture detection, and now we’ve exported that model to ONNX. With this, we can streamline our deployment process and leverage platform-specific hardware optimizations through ONNX Runtime.
As you move forward, consider exploring more about ONNX and its ecosystem. Check out the available Execution Providers, which offer flexible interfaces to different hardware acceleration libraries.
Recommended Tutorials
- Quantizing YOLOX with ONNX Runtime and TensorRT in Ubuntu: Learn how to quantize YOLOX models with ONNX Runtime and TensorRT for int8 inference.
- Real-Time Object Tracking with YOLOX and ByteTrack: Learn how to track objects across video frames with YOLOX and ByteTrack.
- Real-Time Object Detection in Unity with ONNX Runtime and DirectML: Learn how to integrate a native plugin within the Unity game engine for real-time object detection using ONNX Runtime.
- Feel free to post questions or problems related to this tutorial in the comments below. I try to make time to address them on Thursdays and Fridays.
I’m Christian Mills, a deep learning consultant specializing in practical AI implementations. I help clients leverage cutting-edge AI technologies to solve real-world problems.
Interested in working together? Fill out my Quick AI Project Assessment form or learn more about me.