Exporting timm Image Classifiers from PyTorch to ONNX
- Introduction
- Getting Started with the Code
- Setting Up Your Python Environment
- Importing the Required Dependencies
- Setting Up the Project
- Loading the Checkpoint Data
- Exporting the Model to ONNX
- Performing Inference with ONNX Runtime
- Conclusion
Introduction
Welcome back to this series on fine-tuning image classifiers with PyTorch and the timm library. Previously, we demonstrated how to fine-tune a ResNet18-D model from the timm library in PyTorch by creating a hand gesture classifier. This tutorial builds on that by showing how to export the model to ONNX and perform inference using ONNX Runtime.
ONNX (Open Neural Network Exchange) is an open format to represent machine learning models and make them portable across various platforms. ONNX Runtime is a cross-platform inference accelerator that provides interfaces to hardware-specific libraries. By exporting our model to ONNX, we can deploy it to multiple devices and leverage hardware acceleration for faster inference.
Additionally, we’ll wrap the PyTorch model with the required preprocessing and post-processing steps to include them in the ONNX model. By the end of this tutorial, you will have an ONNX version of our ResNet18-D model that you can deploy to servers and edge devices using ONNX Runtime.
Getting Started with the Code
As with the previous tutorial, the code is available as a Jupyter Notebook.
Jupyter Notebook | Google Colab |
---|---|
GitHub Repository | Open In Colab |
Setting Up Your Python Environment
We’ll need to add a few new libraries to our Python environment for working with ONNX models.
Run the following command to install these additional libraries:
# Install ONNX packages
pip install onnx onnxruntime onnx-simplifier
Importing the Required Dependencies
With our environment updated, we can dive into the code. First, we will import the necessary Python dependencies into our Jupyter Notebook.
# Import Python Standard Library dependencies
import json
from pathlib import Path
import random
# Import utility functions
from cjm_psl_utils.core import download_file, get_source_code
from cjm_pil_utils.core import resize_img
# Import numpy
import numpy as np
# Import the pandas package
import pandas as pd
# Import PIL for image manipulation
from PIL import Image
# Import timm library
import timm
# Import PyTorch dependencies
import torch
from torch import nn
# Import ONNX dependencies
import onnx # Import the onnx module
from onnxsim import simplify # Import the method to simplify ONNX models
import onnxruntime as ort # Import the ONNX Runtime
Setting Up the Project
In this section, we’ll set the folder locations for our project and for the training session that contains the PyTorch checkpoint.
Set the Directory Paths
# The name for the project
project_name = "pytorch-timm-image-classifier"
# The path for the project folder
project_dir = Path(f"./{project_name}/")
# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)
# The path to the checkpoint folder
checkpoint_dir = project_dir/"2024-02-02_15-41-23"
pd.Series({"Project Directory:": project_dir,
           "Checkpoint Directory:": checkpoint_dir,
          }).to_frame().style.hide(axis='columns')
Project Directory: | pytorch-timm-image-classifier |
---|---|
Checkpoint Directory: | pytorch-timm-image-classifier/2024-02-02_15-41-23 |
Loading the Checkpoint Data
Now, we can load the class labels and initialize a ResNet18-D model with the saved checkpoint.
Load the Class Labels
# The class labels path
class_labels_path = list(checkpoint_dir.glob('*classes.json'))[0]
# Load the JSON class labels data
with open(class_labels_path, 'r') as file:
    class_labels_json = json.load(file)
# Get the list of classes
class_names = class_labels_json['classes']
# Print the list of classes
pd.DataFrame(class_names)
0 | |
---|---|
0 | call |
1 | dislike |
2 | fist |
3 | four |
4 | like |
5 | mute |
6 | no_gesture |
7 | ok |
8 | one |
9 | palm |
10 | peace |
11 | peace_inverted |
12 | rock |
13 | stop |
14 | stop_inverted |
15 | three |
16 | three2 |
17 | two_up |
18 | two_up_inverted |
Load the Model Checkpoint
# The model checkpoint path
checkpoint_path = list(checkpoint_dir.glob('*.pth'))[0]
# Load the model checkpoint onto the CPU
model_checkpoint = torch.load(checkpoint_path, map_location='cpu')
Load the Finetuned Model
# Specify the model configuration
model_type = checkpoint_path.stem.split(".")[0]
# Create a model with the number of output classes equal to the number of class names
model = timm.create_model(model_type, num_classes=len(class_names))
# Initialize the model with the checkpoint parameters and buffers
model.load_state_dict(model_checkpoint)
<All keys matched successfully>
Get the Normalization Stats
# Import the resnet module
from timm.models import resnet
# Get the default configuration of the chosen model
model_cfg = resnet.default_cfgs[model_type].default.to_dict()
# Retrieve normalization statistics (mean and std) specific to the pretrained model
mean, std = model_cfg['mean'], model_cfg['std']
norm_stats = (mean, std)
norm_stats
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
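As a quick sanity check on these statistics, the sketch below computes the range each channel occupies after normalization, assuming pixel values have already been scaled to [0, 1]. It uses only the standard library and the mean and std values printed above; `normalized_range` is a hypothetical helper, not part of timm.

```python
# Normalization statistics printed above (per RGB channel)
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalized_range(mean, std):
    """Return the (min, max) value of each channel after (x - mean) / std,
    assuming the input x lies in [0, 1]."""
    return [((0.0 - m) / s, (1.0 - m) / s) for m, s in zip(mean, std)]

for channel, (low, high) in zip("RGB", normalized_range(MEAN, STD)):
    print(f"{channel}: [{low:.4f}, {high:.4f}]")
```

The network therefore sees values of roughly -2.12 to 2.64 rather than 0 to 1, which is why skipping normalization at deployment time silently degrades predictions.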
Exporting the Model to ONNX
Before exporting the model, we will wrap it with the preprocessing and post-processing steps. These steps will be included in the ONNX model, reducing the code we need to write when deploying the model to other platforms.
Prepare the Model for Inference
Whenever we make predictions with the model, we must normalize the input data and pass the model output through a Softmax function. We can define a wrapper class that automatically performs these steps.
Additionally, we can include options to scale pixel data from the range [0,255] to [0,1] and set the model to expect input tensors in channels-last format. These settings can be helpful when deploying to platforms where tensor operations, such as transposing or rescaling arrays, are less convenient.
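To make the wrapper's behavior concrete, here is a NumPy sketch of the same preprocessing and post-processing steps. It is not part of the exported model; `preprocess` and `reference_forward` are hypothetical reference functions, and `logits_fn` stands in for the network. It also shows why a single set of (1, 3, 1, 1) normalization buffers works for both layouts: the wrapper permutes channels-last input to channels-first before normalizing.

```python
import numpy as np

def softmax(logits, axis=1):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def preprocess(x, mean, std, scale_inp=False, channels_last=False):
    """Mirror InferenceWrapper.preprocess_input with NumPy arrays.

    x: (N, C, H, W) array, or (N, H, W, C) when channels_last=True.
    mean, std: per-channel statistics shaped (1, 3, 1, 1).
    """
    if scale_inp:
        x = x / 255.0                  # [0, 255] -> [0, 1]
    if channels_last:
        x = x.transpose(0, 3, 1, 2)    # NHWC -> NCHW, like permute(0, 3, 1, 2)
    return (x - mean) / std            # broadcasts over batch, height, and width

def reference_forward(x, logits_fn, mean, std, scale_inp=False, channels_last=False):
    """Preprocess, run a stand-in network (logits_fn), then apply softmax."""
    return softmax(logits_fn(preprocess(x, mean, std, scale_inp, channels_last)))

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)

# A channels-last image with [0, 255] values and its channels-first [0, 1] equivalent
rng = np.random.default_rng(0)
img_hwc = rng.integers(0, 256, size=(1, 4, 6, 3)).astype(np.float32)
img_chw = img_hwc.transpose(0, 3, 1, 2) / 255.0

# Both option combinations produce the same normalized tensor
a = preprocess(img_chw, mean, std)
b = preprocess(img_hwc, mean, std, scale_inp=True, channels_last=True)
print(np.allclose(a, b))  # True
```

Passing a real forward pass in place of `logits_fn` (for example, a function that runs the unwrapped model on NumPy input) would give a layout-independent reference to compare against the exported ONNX outputs.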
Define model export wrapper
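class InferenceWrapper(nn.Module):
    def __init__(self, model, normalize_mean, normalize_std, scale_inp=False, channels_last=False):
        super().__init__()
        self.model = model
        # Register the normalization stats as buffers so they move with the model and are stored in its state
        self.register_buffer("normalize_mean", normalize_mean)
        self.register_buffer("normalize_std", normalize_std)
        self.scale_inp = scale_inp
        self.channels_last = channels_last
        # Convert raw logits to class probabilities
        self.softmax = nn.Softmax(dim=1)
    def preprocess_input(self, x):
        # Optionally scale pixel values from [0,255] to [0,1]
        if self.scale_inp:
            x = x / 255.0
        # Optionally convert channels-last (N, H, W, C) input to channels-first (N, C, H, W)
        if self.channels_last:
            x = x.permute(0, 3, 1, 2)
        # Normalize each channel with the model's mean and standard deviation
        x = (x - self.normalize_mean) / self.normalize_std
        return x
    def forward(self, x):
        x = self.preprocess_input(x)
        x = self.model(x)
        x = self.softmax(x)
        return x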
Wrap model with preprocessing and post-processing steps
scale_inp=False, channels_last=False:
# Define the normalization mean and standard deviation
mean_tensor = torch.tensor(norm_stats[0]).view(1, 3, 1, 1)
std_tensor = torch.tensor(norm_stats[1]).view(1, 3, 1, 1)
# Set the model to evaluation mode
model.eval();
# Wrap the model with preprocessing and post-processing steps
wrapped_model = InferenceWrapper(model,
                                 mean_tensor,
                                 std_tensor,
                                 scale_inp=False,  # Whether to scale input values from the range [0,255] to [0,1]
                                 channels_last=False,  # Whether the model expects input in channels-last format
                                )
scale_inp=True, channels_last=True:
# Define the normalization mean and standard deviation
mean_tensor = torch.tensor(norm_stats[0]).view(1, 3, 1, 1)
std_tensor = torch.tensor(norm_stats[1]).view(1, 3, 1, 1)
# Set the model to evaluation mode
model.eval();
# Wrap the model with preprocessing and post-processing steps
wrapped_model = InferenceWrapper(model,
                                 mean_tensor,
                                 std_tensor,
                                 scale_inp=True,  # Whether to scale input values from the range [0,255] to [0,1]
                                 channels_last=True,  # Whether the model expects input in channels-last format
                                )
Prepare the Input Tensor
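We need a sample input tensor for the export process. Its layout must match the option we chose when wrapping the model.
scale_inp=False, channels_last=False:
input_tensor = torch.randn(1, 3, 256, 256)
scale_inp=True, channels_last=True:
input_tensor = torch.randn(1, 256, 256, 3)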
Export the Model to ONNX
We can export the model using the torch.onnx.export() function. This function performs a single pass through the model and records all operations to generate a TorchScript graph. It then exports this graph to ONNX by decomposing each graph node (which contains a PyTorch operator) into a series of ONNX operators.
If we want the ONNX model to support different input sizes, we must set the height and width input axes as dynamic. These are axes 2 and 3 for channels-first input (N, C, H, W) and axes 1 and 2 for channels-last input (N, H, W, C).
scale_inp=False, channels_last=False:
# Set a filename for the ONNX model
onnx_file_path = f"{checkpoint_dir}/{class_labels_path.stem.removesuffix('-classes')}-{model_type}.onnx"
# Export the PyTorch model to ONNX format
torch.onnx.export(wrapped_model.cpu(),
                  input_tensor.cpu(),
                  onnx_file_path,
                  export_params=True,
                  do_constant_folding=False,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {2: 'height', 3: 'width'}}
                 )
scale_inp=True, channels_last=True:
# Set a filename for the ONNX model
onnx_file_path = f"{checkpoint_dir}/{class_labels_path.stem.removesuffix('-classes')}-{model_type}.onnx"
# Export the PyTorch model to ONNX format
torch.onnx.export(wrapped_model.cpu(),
                  input_tensor.cpu(),
                  onnx_file_path,
                  export_params=True,
                  do_constant_folding=False,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {1: 'height', 2: 'width'}}
                 )
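Because the exporter traces a single pass, the `if self.scale_inp` and `if self.channels_last` branches in InferenceWrapper are fixed at export time, so the dummy input shape and the `dynamic_axes` indices have to agree with the chosen option. The helper below, `export_shapes`, is a hypothetical convenience function (not part of the original notebook or of PyTorch) that derives both from the `channels_last` flag.

```python
def export_shapes(channels_last, size=256):
    """Return a (dummy_input_shape, dynamic_axes) pair for torch.onnx.export.

    Height and width sit at indices 2 and 3 in channels-first (N, C, H, W)
    layout and at indices 1 and 2 in channels-last (N, H, W, C) layout.
    """
    if channels_last:
        shape = (1, size, size, 3)
        axes = {1: 'height', 2: 'width'}
    else:
        shape = (1, 3, size, size)
        axes = {2: 'height', 3: 'width'}
    return shape, {'input': axes}

print(export_shapes(channels_last=False))
print(export_shapes(channels_last=True))
```

In the notebook, one could then write `shape, dynamic_axes = export_shapes(channels_last=True)`, create the sample input with `torch.randn(*shape)`, and pass `dynamic_axes=dynamic_axes` to `torch.onnx.export`.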
Simplify the ONNX Model
The ONNX models generated by PyTorch are not always the most concise. We can use the onnx-simplifier package to tidy up the exported model. We then save the updated ONNX model back to disk for later.
# Load the ONNX model from onnx_file_path
onnx_model = onnx.load(onnx_file_path)
# Simplify the model
model_simp, check = simplify(onnx_model)
# Confirm that the simplified model passed onnx-simplifier's validation check
assert check, "Simplified ONNX model could not be validated"
# Save the simplified model to onnx_file_path
onnx.save(model_simp, onnx_file_path)
Performing Inference with ONNX Runtime
Now that we have our ONNX model, it’s time to test it with ONNX Runtime.
Create an Inference Session
We interact with models in ONNX Runtime through an InferenceSession object. Here, we can specify which Execution Providers to use for inference and other configuration information. Execution Providers are the interfaces for hardware-specific inference engines like TensorRT for NVIDIA and OpenVINO for Intel. By default, the InferenceSession uses the generic CPUExecutionProvider.
# Load the model and create an InferenceSession
session = ort.InferenceSession(onnx_file_path, providers=['CPUExecutionProvider'])
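ONNX Runtime treats the `providers` list as a priority order. When the same notebook may run on machines with different hardware, a small filter over the providers the installed build reports can pick the best one that is present. `select_providers` and its preference order are hypothetical choices for illustration; only the provider names themselves come from ONNX Runtime.

```python
def select_providers(available, preferred=("TensorrtExecutionProvider",
                                           "CUDAExecutionProvider",
                                           "CPUExecutionProvider")):
    """Return the preferred Execution Providers that are available, in priority order.

    Always ends with CPUExecutionProvider, which standard ONNX Runtime builds include,
    so the session has a fallback for any operator the accelerators cannot handle.
    """
    chosen = [p for p in preferred if p in available]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen

print(select_providers(["CUDAExecutionProvider", "CPUExecutionProvider"]))
```

In the notebook, this would be used as `ort.InferenceSession(onnx_file_path, providers=select_providers(ort.get_available_providers()))`.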
With our inference session initialized, we can select an image to test our ONNX model.
Select a Test Image
Let’s use the same test image and input size from the previous tutorial to compare the results with the PyTorch model.
test_img_name = 'pexels-elina-volkova-16191659.jpg'
test_img_url = f"https://huggingface.co/datasets/cj-mills/pexel-hand-gesture-test-images/resolve/main/{test_img_name}"
download_file(test_img_url, './', False)
test_img = Image.open(test_img_name)
display(test_img)
pd.Series({"Test Image Size:": test_img.size,
          }).to_frame().style.hide(axis='columns')
Test Image Size: | (637, 960) |
---|
Prepare the Test Image
# Set test image size
test_sz = 288
# Resize image without cropping
input_img = resize_img(test_img.copy(), target_sz=test_sz)
display(input_img)
pd.Series({"Input Image Size:": input_img.size
          }).to_frame().style.hide(axis='columns')
Input Image Size: | (288, 416) |
---|
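The resized image is not square, which the dynamic height and width axes in the exported model allow. One rule that reproduces the sizes shown above, (637, 960) becoming (288, 416), is to scale the shorter side to `target_sz`, keep the aspect ratio, and round both sides down to a multiple of 32. This is an assumption about what `resize_img` does, written as a hypothetical `resize_dims` function, not the library's actual implementation.

```python
def resize_dims(width, height, target_sz, divisor=32):
    """Hypothetical resize rule: shorter side -> target_sz, aspect ratio kept,
    then both sides rounded down to a multiple of divisor.

    Uses integer arithmetic to avoid floating-point rounding surprises.
    """
    if width <= height:
        new_w, new_h = target_sz, height * target_sz // width
    else:
        new_w, new_h = width * target_sz // height, target_sz
    # Round both sides down to the nearest multiple of divisor
    return new_w // divisor * divisor, new_h // divisor * divisor

print(resize_dims(637, 960, 288))  # (288, 416)
```

Here the unrounded height is 434 pixels, and rounding down to a multiple of 32 gives 416.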
Prepare the Input Tensor
When we convert the PIL input image to a NumPy array, we need to reorder the array values to channels-first format, scale the values from [0,255] to [0,1], and add a batch dimension. When we enable the scale_inp and channels_last options, we only need to add a batch dimension.
scale_inp=False, channels_last=False:
# Convert the input image to a channels-first float32 array in [0,1] with a batch dimension
input_tensor_np = np.array(input_img, dtype=np.float32).transpose((2, 0, 1))[None]/255
scale_inp=True, channels_last=True:
# Convert the input image to a float32 array and add a batch dimension
input_tensor_np = np.array(input_img, dtype=np.float32)[None]
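The two conversions can be folded into one function keyed on the same flags used at export time. `to_input_tensor` is a hypothetical helper, not from the original notebook. Note that PIL reports size as (width, height), so the (288, 416) image becomes a (416, 288, 3) array, and the model input is (1, 3, 416, 288) or (1, 416, 288, 3) depending on the layout.

```python
import numpy as np

def to_input_tensor(img_array, scale_inp=False, channels_last=False):
    """Convert an (H, W, 3) RGB array into the layout the exported model expects.

    Any step the wrapper does not perform (scaling, channel reordering) is done here;
    a batch dimension is always added. The result is float32, as the model expects.
    """
    x = np.asarray(img_array, dtype=np.float32)
    if not channels_last:
        x = x.transpose(2, 0, 1)   # (H, W, C) -> (C, H, W)
    if not scale_inp:
        x = x / 255.0              # [0, 255] -> [0, 1]; stays float32
    return x[None]                 # add the batch dimension

# A stand-in for np.array(input_img): height 416, width 288, 3 channels
dummy = np.full((416, 288, 3), 255, dtype=np.uint8)
print(to_input_tensor(dummy).shape)                                    # (1, 3, 416, 288)
print(to_input_tensor(dummy, scale_inp=True, channels_last=True).shape)  # (1, 416, 288, 3)
```

With the notebook's variables, `input_tensor_np = to_input_tensor(np.array(input_img), scale_inp, channels_last)` would replace both one-liners above.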
Compute the Predictions
Now, we can finally perform inference with our ONNX model.
# Run inference
outputs = session.run(None, {"input": input_tensor_np})[0]
# Get the highest confidence score
confidence_score = outputs.max()
# Get the class index with the highest confidence score and convert it to the class name
pred_class = class_names[outputs.argmax()]
# Display the image
display(test_img)
# Store the prediction data in a Pandas Series for easy formatting
pd.Series({"Input Size:": input_img.size,
           "Predicted Class:": pred_class,
           "Confidence Score:": f"{confidence_score*100:.2f}%"
          }).to_frame().style.hide(axis='columns')
Input Size: | (288, 416) |
---|---|
Predicted Class: | mute |
Confidence Score: | 100.00% |
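Since the softmax is part of the exported graph, `outputs` already holds class probabilities, so inspecting the runner-up classes needs no extra post-processing. The sketch below uses a hypothetical `top_k_predictions` helper to list the k most likely classes for a single image.

```python
import numpy as np

def top_k_predictions(probs, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs for one image.

    probs: softmax output shaped (1, num_classes), as returned by session.run(...)[0].
    """
    scores = np.asarray(probs)[0]
    top_idx = np.argsort(scores)[::-1][:k]   # indices sorted by descending probability
    return [(class_names[i], float(scores[i])) for i in top_idx]

print(top_k_predictions(np.array([[0.1, 0.7, 0.2]]), ["a", "b", "c"], k=2))
```

In the notebook, `top_k_predictions(outputs, class_names)` would list the predicted class followed by the next most likely gestures.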
The model predictions should be virtually identical to the PyTorch model, but the confidence scores can sometimes vary slightly.
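To check that claim numerically, one can run the same NumPy input through the wrapped PyTorch model and the ONNX session and compare the two probability arrays. `compare_outputs` is a hypothetical helper, and the default tolerance of 1e-5 is an assumed threshold, not a value from ONNX Runtime or PyTorch.

```python
import numpy as np

def compare_outputs(torch_probs, onnx_probs, atol=1e-5):
    """Compare two probability arrays shaped (N, num_classes).

    Returns (same_top1, max_abs_diff, within_tol): whether the argmax class matches
    for every row, the largest absolute difference, and whether it is <= atol.
    """
    a, b = np.asarray(torch_probs), np.asarray(onnx_probs)
    same_top1 = bool(np.array_equal(a.argmax(axis=-1), b.argmax(axis=-1)))
    max_abs_diff = float(np.abs(a - b).max())
    return same_top1, max_abs_diff, max_abs_diff <= atol

print(compare_outputs(np.array([[0.2, 0.8]]), np.array([[0.2000001, 0.7999999]])))
```

With the notebook's variables, the PyTorch side could be computed inside `torch.no_grad()` as `wrapped_model(torch.from_numpy(input_tensor_np)).numpy()` and passed alongside `outputs`.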
- Don’t forget to download the ONNX model from the Colab environment’s file browser.
Conclusion
Congratulations on reaching the end of this tutorial! We previously fine-tuned a model from the timm library in PyTorch for hand gesture classification and now exported that model to ONNX. With this, we can streamline our deployment process and leverage platform-specific hardware optimizations through ONNX Runtime.
As you move forward, consider exploring more about ONNX and its ecosystem. Check out the available Execution Providers that provide flexible interfaces to different hardware acceleration libraries.
Recommended Tutorials
- Quantizing timm Image Classifiers with ONNX Runtime and TensorRT in Ubuntu: Learn how to quantize timm image classification models with ONNX Runtime and TensorRT for int8 inference.
- Feel free to post questions or problems related to this tutorial in the comments below. I try to make time to address them on Thursdays and Fridays.
- I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.
- I help clients leverage cutting-edge AI technologies to solve real-world problems.
- Learn more about me or reach out via email at [email protected] to discuss your project.