Exporting timm Image Classifiers from PyTorch to ONNX

pytorch
onnx
image-classification
tutorial
Learn how to export timm image classification models from PyTorch to ONNX and perform inference using ONNX Runtime.
Author: Christian Mills
Published: August 23, 2023
Modified: September 11, 2024

Introduction

Welcome back to this series on fine-tuning image classifiers with PyTorch and the timm library. Previously, we demonstrated how to fine-tune a ResNet18-D model from the timm library in PyTorch by creating a hand gesture classifier. This tutorial builds on that by showing how to export the model to ONNX and perform inference using ONNX Runtime.

ONNX (Open Neural Network Exchange) is an open format to represent machine learning models and make them portable across various platforms. ONNX Runtime is a cross-platform inference accelerator that provides interfaces to hardware-specific libraries. By exporting our model to ONNX, we can deploy it to multiple devices and leverage hardware acceleration for faster inference.

Additionally, we’ll wrap the PyTorch model with the required preprocessing and post-processing steps to include them in the ONNX model. By the end of this tutorial, you will have an ONNX version of our ResNet18-D model that you can deploy to servers and edge devices using ONNX Runtime.

This post assumes the reader has completed the previous tutorial in this series.

Getting Started with the Code

As with the previous tutorial, the code is available as a Jupyter Notebook.

Jupyter Notebook: GitHub Repository
Google Colab: Open In Colab

Setting Up Your Python Environment

We’ll need to add a few new libraries to our Python environment for working with ONNX models.

Package | Description
onnx | This package provides a Python API for working with ONNX models.
onnxruntime | ONNX Runtime is a runtime accelerator for machine learning models.
onnx-simplifier | This package helps simplify ONNX models.

Run the following command to install these additional libraries:

# Install ONNX packages
pip install onnx onnxruntime onnx-simplifier

Importing the Required Dependencies

With our environment updated, we can dive into the code. First, we will import the necessary Python dependencies into our Jupyter Notebook.

# Import Python Standard Library dependencies
import json
from pathlib import Path
import random

# Import utility functions
from cjm_psl_utils.core import download_file, get_source_code
from cjm_pil_utils.core import resize_img

# Import numpy
import numpy as np

# Import the pandas package
import pandas as pd

# Import PIL for image manipulation
from PIL import Image

# Import timm library
import timm

# Import PyTorch dependencies
import torch
from torch import nn

# Import ONNX dependencies
import onnx # Import the onnx module
from onnxsim import simplify # Import the method to simplify ONNX models
import onnxruntime as ort # Import the ONNX Runtime

Setting Up the Project

In this section, we’ll set the folder locations for our project and training session with the PyTorch checkpoint.

Set the Directory Paths

# The name for the project
project_name = "pytorch-timm-image-classifier"

# The path for the project folder
project_dir = Path(f"./{project_name}/")

# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)

# The path to the checkpoint folder (replace the timestamp with your own training session's folder name)
checkpoint_dir = project_dir/"2024-02-02_15-41-23"

pd.Series({
    "Project Directory:": project_dir,
    "Checkpoint Directory:": checkpoint_dir,
}).to_frame().style.hide(axis='columns')
Project Directory: pytorch-timm-image-classifier
Checkpoint Directory: pytorch-timm-image-classifier/2024-02-02_15-41-23
Those following along on Google Colab can drag the contents of their checkpoint folder into Colab’s file browser.

Loading the Checkpoint Data

Now, we can load the class labels and initialize a ResNet18-D model with the saved checkpoint.

Load the Class Labels

# The class labels path
class_labels_path = list(checkpoint_dir.glob('*classes.json'))[0]

# Load the JSON class labels data
with open(class_labels_path, 'r') as file:
    class_labels_json = json.load(file)

# Get the list of classes
class_names = class_labels_json['classes']

# Print the list of classes
pd.DataFrame(class_names)
0
0 call
1 dislike
2 fist
3 four
4 like
5 mute
6 no_gesture
7 ok
8 one
9 palm
10 peace
11 peace_inverted
12 rock
13 stop
14 stop_inverted
15 three
16 three2
17 two_up
18 two_up_inverted

Load the Model Checkpoint

# The model checkpoint path
checkpoint_path = list(checkpoint_dir.glob('*.pth'))[0]

# Load the model checkpoint onto the CPU
model_checkpoint = torch.load(checkpoint_path, map_location='cpu')

Load the Finetuned Model

# Specify the model configuration
model_type = checkpoint_path.stem.split(".")[0]

# Create a model with the number of output classes equal to the number of class names
model = timm.create_model(model_type, num_classes=len(class_names))

# Initialize the model with the checkpoint parameters and buffers
model.load_state_dict(model_checkpoint)
<All keys matched successfully>
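The model_type line relies on the checkpoint file name starting with a timm model name. A minimal sketch of that parsing, wrapped in a hypothetical model_type_from_checkpoint helper and applied to hypothetical file names (your actual checkpoint name depends on how the previous tutorial saved it):

```python
from pathlib import Path

def model_type_from_checkpoint(checkpoint_path: Path) -> str:
    """Return the timm model name encoded at the start of a checkpoint file name."""
    # Path.stem removes only the final suffix (".pth"); splitting on "."
    # also drops any extra dotted parts, such as a pretrained tag
    return checkpoint_path.stem.split(".")[0]

# Hypothetical file names, for illustration only
print(model_type_from_checkpoint(Path("resnet18d.pth")))           # resnet18d
print(model_type_from_checkpoint(Path("resnet18d.ra2_in1k.pth")))  # resnet18d
```

The resulting string must be a valid timm model name for timm.create_model() and the default_cfgs lookup below to work.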

Get the Normalization Stats

# Import the resnet module
from timm.models import resnet

# Get the default configuration of the chosen model
model_cfg = resnet.default_cfgs[model_type].default.to_dict()

# Retrieve normalization statistics (mean and std) specific to the pretrained model
mean, std = model_cfg['mean'], model_cfg['std']
norm_stats = (mean, std)
norm_stats
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

Exporting the Model to ONNX

Before exporting the model, we will wrap it with the preprocessing and post-processing steps. These steps will be included in the ONNX model, reducing the code we need to write when deploying the model to other platforms.

Prepare the Model for Inference

Whenever we make predictions with the model, we must normalize the input data. To get class probabilities instead of raw scores (logits), we also pass the model output through a Softmax function. We can define a wrapper class that performs these steps automatically.

Additionally, we can include options to scale pixel data from the range [0,255] to [0,1] and to have the model expect input tensors in channels-last format. These options help when deploying to platforms where scaling or reordering tensors before inference is inconvenient.

Define model export wrapper

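class InferenceWrapper(nn.Module):
    """Bundle input preprocessing and Softmax post-processing with a model for export."""
    def __init__(self, model, normalize_mean, normalize_std, scale_inp=False, channels_last=False):
        super().__init__()
        self.model = model
        # Register the normalization stats as buffers so they move with the module
        # and are stored in the exported ONNX model as constants
        self.register_buffer("normalize_mean", normalize_mean)
        self.register_buffer("normalize_std", normalize_std)
        self.scale_inp = scale_inp  # Expect raw pixel values in [0,255]
        self.channels_last = channels_last  # Expect [N, H, W, C] input
        self.softmax = nn.Softmax(dim=1)

    def preprocess_input(self, x):
        # Scale pixel values from [0,255] to [0,1]
        if self.scale_inp:
            x = x / 255.0

        # Reorder [N, H, W, C] input to the [N, C, H, W] layout the model expects
        if self.channels_last:
            x = x.permute(0, 3, 1, 2)

        # Normalize each channel with the pretrained model's mean and std
        x = (x - self.normalize_mean) / self.normalize_std
        return x

    def forward(self, x):
        x = self.preprocess_input(x)
        x = self.model(x)
        # Convert the raw scores (logits) to class probabilities
        x = self.softmax(x)
        return x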
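The two input options should yield the same normalized tensor. The NumPy sketch below (not part of the original notebook) mirrors preprocess_input on a tiny synthetic batch: path A applies the scale_inp and channels_last steps to raw channels-last pixels, while path B prepares channels-first [0,1] input beforehand, as the default settings expect. The mean and std are the ImageNet statistics shown in norm_stats above.

```python
import numpy as np

# ImageNet normalization stats (same values as norm_stats), shaped to broadcast over [N, C, H, W]
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)

rng = np.random.default_rng(0)
# A fake channels-last batch of raw pixel values: [N, H, W, C] in [0, 255]
raw_nhwc = rng.integers(0, 256, size=(1, 4, 5, 3)).astype(np.float32)

# Path A (scale_inp=True, channels_last=True): scale, permute to [N, C, H, W], normalize
path_a = (np.transpose(raw_nhwc / 255.0, (0, 3, 1, 2)) - mean) / std

# Path B (scale_inp=False, channels_last=False): the caller transposes and scales first
prepped_nchw = raw_nhwc.transpose((0, 3, 1, 2)) / 255
path_b = (prepped_nchw - mean) / std

print(path_a.shape)                 # (1, 3, 4, 5)
print(np.allclose(path_a, path_b))  # True
```

Because scaling and transposing commute, moving these steps inside the model changes only who performs them, not the values the network sees.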

Wrap model with preprocessing and post-processing steps

# Define the normalization mean and standard deviation
mean_tensor = torch.tensor(norm_stats[0]).view(1, 3, 1, 1)
std_tensor = torch.tensor(norm_stats[1]).view(1, 3, 1, 1)

# Set the model to evaluation mode
model.eval();

Pick one of the two configurations below. The later steps show matching code for each.

Default settings (scale_inp=False, channels_last=False): the model expects channels-first input with values already scaled to [0,1].

# Wrap the model with preprocessing and post-processing steps
wrapped_model = InferenceWrapper(model,
                                 mean_tensor,
                                 std_tensor,
                                 scale_inp=False, # Expect input values already in the range [0,1]
                                 channels_last=False, # Expect input in channels-first format
                                )

Alternative settings (scale_inp=True, channels_last=True): the model expects channels-last input with raw pixel values in [0,255].

# Wrap the model with preprocessing and post-processing steps
wrapped_model = InferenceWrapper(model,
                                 mean_tensor,
                                 std_tensor,
                                 scale_inp=True, # Scale input values from the range [0,255] to [0,1]
                                 channels_last=True, # Have the model expect input in channels-last format
                                )

Settings for Unity's Barracuda Inference Library: scale_inp=False, channels_last=False

Prepare the Input Tensor

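We need a sample input tensor for the export process. Its shape must match the input layout of the wrapped model.

Default settings (channels-first input):
input_tensor = torch.randn(1, 3, 256, 256)

Alternative settings (channels-last input):
input_tensor = torch.randn(1, 256, 256, 3)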

Export the Model to ONNX

We can export the model using the torch.onnx.export() function. This function performs a single pass through the model and records all operations to generate a TorchScript graph. It then exports this graph to ONNX by decomposing each graph node (which contains a PyTorch operator) into a series of ONNX operators.

If we want the ONNX model to support different input sizes, we must mark the height and width axes of the input as dynamic. These are axes 2 and 3 for channels-first input and axes 1 and 2 for channels-last input.

# Set a filename for the ONNX model
onnx_file_path = f"{checkpoint_dir}/{class_labels_path.stem.removesuffix('-classes')}-{model_type}.onnx"

Default settings (channels-first input):

# Export the PyTorch model to ONNX format
torch.onnx.export(wrapped_model.cpu(),
                  input_tensor.cpu(),
                  onnx_file_path,
                  export_params=True,
                  do_constant_folding=False,
                  input_names = ['input'],
                  output_names = ['output'],
                  dynamic_axes={'input': {2 : 'height', 3 : 'width'}}
                 )

Alternative settings (channels-last input):

# Export the PyTorch model to ONNX format
torch.onnx.export(wrapped_model.cpu(),
                  input_tensor.cpu(),
                  onnx_file_path,
                  export_params=True,
                  do_constant_folding=False,
                  input_names = ['input'],
                  output_names = ['output'],
                  dynamic_axes={'input': {1 : 'height', 2 : 'width'}}
                 )

Simplify the ONNX Model

The ONNX models generated by PyTorch are not always the most concise. We can use the onnx-simplifier package to tidy up the exported model. We then save the updated ONNX model back to disk for later.

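# Load the ONNX model from onnx_file_path
onnx_model = onnx.load(onnx_file_path)

# Simplify the model; `check` reports whether the simplified model passed validation
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"

# Save the simplified model back to onnx_file_path
onnx.save(model_simp, onnx_file_path)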

Performing Inference with ONNX Runtime

Now that we have our ONNX model, it’s time to test it with ONNX Runtime.

Create an Inference Session

We interact with models in ONNX Runtime through an InferenceSession object. Here, we can specify which Execution Providers to use for inference and other configuration information. Execution Providers are the interfaces for hardware-specific inference engines like TensorRT for NVIDIA and OpenVINO for Intel. By default, the InferenceSession uses the generic CPUExecutionProvider.

# Load the model and create an InferenceSession that runs on the CPU
session = ort.InferenceSession(onnx_file_path, providers=['CPUExecutionProvider'])
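The session above hardcodes the CPU provider. A hedged sketch for preferring hardware-accelerated providers when the installed ONNX Runtime build offers them; the select_providers helper and its preference order are assumptions for illustration, and the input list below is a hypothetical example:

```python
def select_providers(available, preferred=("TensorrtExecutionProvider",
                                           "CUDAExecutionProvider",
                                           "CPUExecutionProvider")):
    """Return the preferred providers that are available, in priority order."""
    chosen = [p for p in preferred if p in available]
    # Always keep the CPU provider as a fallback
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen

# Example with a hypothetical list of available providers
print(select_providers(["CUDAExecutionProvider", "CPUExecutionProvider"]))
# ['CUDAExecutionProvider', 'CPUExecutionProvider']
```

In the notebook, the result could be passed as `providers=select_providers(ort.get_available_providers())` when creating the InferenceSession. GPU providers only appear with a matching ONNX Runtime build, such as the onnxruntime-gpu package.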

With our inference session initialized, we can select an image to test our ONNX model.

Select a Test Image

Let’s use the same test image and input size from the previous tutorial to compare the results with the PyTorch model.

test_img_name = 'pexels-elina-volkova-16191659.jpg'
test_img_url = f"https://huggingface.co/datasets/cj-mills/pexel-hand-gesture-test-images/resolve/main/{test_img_name}"

download_file(test_img_url, './', False)

test_img = Image.open(test_img_name)
display(test_img)

pd.Series({
    "Test Image Size:": test_img.size, 
}).to_frame().style.hide(axis='columns')

Test Image Size: (637, 960)

Prepare the Test Image

# Set test image size
test_sz = 288

# Resize image without cropping
input_img = resize_img(test_img.copy(), target_sz=test_sz)

display(input_img)

pd.Series({
    "Input Image Size:": input_img.size
}).to_frame().style.hide(axis='columns')

Input Image Size: (288, 416)
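The reported sizes are consistent with scaling the shorter side to test_sz and rounding the longer side down to a multiple of 32: (637, 960) becomes (288, 434), then (288, 416). We have not checked resize_img's source, so treat this as an assumption. The hypothetical target_size helper below only reproduces that size arithmetic (sizes are PIL-style (width, height)):

```python
def target_size(size, target_sz=288, divisor=32):
    """Hypothetical helper: scale the shorter side to target_sz and floor the
    longer side to a multiple of divisor. Sizes are (width, height)."""
    w, h = size
    if w <= h:
        new_h = int(h * target_sz / w)
        return target_sz, new_h - new_h % divisor
    new_w = int(w * target_sz / h)
    return new_w - new_w % divisor, target_sz

print(target_size((637, 960)))  # (288, 416)
```

Sides divisible by 32 line up with ResNet's total stride of 32. Since the ONNX model has dynamic height and width axes and ResNet ends with global pooling, other sizes should also run.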

Prepare the Input Tensor

When we convert the PIL input image to a NumPy array, we need to reorder the array values to channels-first format, scale the values from [0,255] to [0,1], and add a batch dimension. When we enable the scale_inp and channels_last options, we only need to add a batch dimension.

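Default settings (scale_inp=False, channels_last=False):

# Convert the input image to a channels-first NumPy array, scale it to [0,1], and add a batch dimension
input_tensor_np = np.array(input_img, dtype=np.float32).transpose((2, 0, 1))[None]/255

Alternative settings (scale_inp=True, channels_last=True):

# Convert the input image to a NumPy array and add a batch dimension
input_tensor_np = np.array(input_img, dtype=np.float32)[None]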
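Both conversions can be folded into one small helper. A sketch with a hypothetical to_model_input function that takes an H x W x 3 uint8 array (what np.array returns for an RGB PIL image) plus the wrapper settings. It uses a synthetic all-white array so it runs without the test image. Note that the PIL size (288, 416) is width x height, so the array is 416 x 288 x 3.

```python
import numpy as np

def to_model_input(img_array, scale_inp=False, channels_last=False):
    """Hypothetical helper: prepare an H x W x 3 uint8 image for the exported model."""
    x = img_array.astype(np.float32)
    if not scale_inp:
        # The wrapper will not scale, so map [0, 255] to [0, 1] here
        x = x / 255
    if not channels_last:
        # The wrapper will not permute, so reorder to [C, H, W] here
        x = x.transpose((2, 0, 1))
    # Add the batch dimension
    return x[None]

# Synthetic all-white image: height x width x channels
img_array = np.full((416, 288, 3), 255, dtype=np.uint8)

default_input = to_model_input(img_array)
alt_input = to_model_input(img_array, scale_inp=True, channels_last=True)
print(default_input.shape, default_input.dtype)  # (1, 3, 416, 288) float32
print(alt_input.shape, alt_input.dtype)          # (1, 416, 288, 3) float32
```

Keeping the array in float32 matters because the exported model declares a float input, and ONNX Runtime rejects inputs of a different element type.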

Compute the Predictions

Now, we can finally perform inference with our ONNX model.

# Run inference
outputs = session.run(None, {"input": input_tensor_np})[0]

# Get the highest confidence score
confidence_score = outputs.max()

# Get the class index with the highest confidence score and convert it to the class name
pred_class = class_names[outputs.argmax()]

# Display the image
display(test_img)

# Store the prediction data in a Pandas Series for easy formatting
pd.Series({
    "Input Size:": input_img.size,
    "Predicted Class:": pred_class,
    "Confidence Score:": f"{confidence_score*100:.2f}%"
}).to_frame().style.hide(axis='columns')

Input Size: (288, 416)
Predicted Class: mute
Confidence Score: 100.00%
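Beyond the single best class, the Softmax output also makes it easy to list the runner-up classes. A NumPy sketch with a hypothetical top_k helper; the probabilities below are made up so the block runs on its own. In the notebook you would pass outputs and class_names instead.

```python
import numpy as np

def top_k(probs, class_names, k=3):
    """Return the k most likely (class name, probability) pairs from a [1, num_classes] array."""
    scores = np.asarray(probs).ravel()
    # argsort is ascending, so reverse it and keep the first k indices
    top_idx = np.argsort(scores)[::-1][:k]
    return [(class_names[i], float(scores[i])) for i in top_idx]

# Made-up probabilities for illustration only
demo_names = ["call", "dislike", "fist", "mute"]
demo_probs = np.array([[0.05, 0.10, 0.15, 0.70]], dtype=np.float32)
for name, p in top_k(demo_probs, demo_names):
    print(f"{name}: {p*100:.2f}%")
```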

The model predictions should be virtually identical to the PyTorch model, but the confidence scores can sometimes vary slightly.
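To check this rather than eyeball it, one option is to run the same input through the wrapped PyTorch model and the ONNX session, then compare the arrays with a tolerance. A minimal NumPy sketch; the compare_predictions helper and its tolerance are assumptions, and the arrays below are synthetic stand-ins for the two models' outputs:

```python
import numpy as np

def compare_predictions(ref, test, atol=1e-4):
    """Report whether two probability arrays agree on the top class and within a tolerance."""
    ref, test = np.asarray(ref), np.asarray(test)
    return {
        "same_top_class": bool(ref.argmax() == test.argmax()),
        "max_abs_diff": float(np.abs(ref - test).max()),
        "within_tol": bool(np.allclose(ref, test, rtol=0.0, atol=atol)),
    }

# Synthetic stand-ins with small floating-point differences
torch_probs = np.array([[0.01, 0.02, 0.97]], dtype=np.float32)
onnx_probs = torch_probs + np.array([[1e-6, -2e-6, 1e-6]], dtype=np.float32)
print(compare_predictions(torch_probs, onnx_probs))
```

In the notebook, ref could come from `wrapped_model(torch.from_numpy(input_tensor_np)).detach().numpy()` and test from `outputs`, provided both use the same wrapper settings.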

Google Colab Users: Don't forget to download the ONNX model from the Colab environment's file browser.

Conclusion

Congratulations on reaching the end of this tutorial! We previously fine-tuned a model from the timm library in PyTorch for hand gesture classification and have now exported that model to ONNX. With this, we can streamline our deployment process and leverage platform-specific hardware optimizations through ONNX Runtime.

As you move forward, consider exploring more about ONNX and its ecosystem. Check out the available Execution Providers that provide flexible interfaces to different hardware acceleration libraries.