Exporting timm Image Classifiers from Fastai to TorchScript
- Introduction
- Getting Started with the Code
- Importing the Required Dependencies
- Loading the Checkpoint Data
- Exporting the Model to TorchScript
- Performing Inference with the TorchScript Module
- Conclusion
Introduction
Welcome back to this series on fine-tuning image classifiers with fastai and the timm library! Previously, we demonstrated how to fine-tune a ResNet18 model from the timm library with fastai by creating a hand gesture classifier. This tutorial builds on that by showing how to export the underlying PyTorch model to TorchScript for seamless deployment across various platforms and optimized inference.
Exporting a PyTorch model to TorchScript serializes it into a format that can run without a Python interpreter, for example from C++ through LibTorch. That makes deployment more flexible in environments where Python isn’t available and helps keep behavior consistent across platforms. TorchScript’s JIT compiler can also apply graph-level optimizations that may speed up inference, and its static typing surfaces some errors at export time rather than at runtime.
We’ll also wrap the PyTorch model with the required preprocessing and post-processing steps so that they become part of the TorchScript module. By the end of this tutorial, you’ll have a deployable TorchScript ResNet18 model that includes all the processing steps needed for real-world use.
Getting Started with the Code
As with the previous tutorial, the code is available as a Jupyter Notebook.
Jupyter Notebook | Google Colab |
---|---|
GitHub Repository | Open In Colab |
Importing the Required Dependencies
First, we will import the necessary Python dependencies into our Jupyter Notebook.
# Import Python Standard Library dependencies
import json
from pathlib import Path
import random
# Import utility functions
from cjm_psl_utils.core import download_file
from cjm_pil_utils.core import resize_img
# Import numpy
import numpy as np
# Import the pandas package
import pandas as pd
# Import PIL for image manipulation
from PIL import Image
# Import PyTorch dependencies
import torch
from torch import nn
# Import fastai function to load a saved learner object
from fastai.learner import load_learner
Setting Up the Project
Next, we will set the folder locations for our project and training session with the exported Learner object.
Set the Directory Paths
# The name for the project
project_name = f"fastai-timm-image-classifier"
# The path for the project folder
project_dir = Path(f"./{project_name}/")
# Create the project directory if it does not already exist
project_dir.mkdir(parents=True, exist_ok=True)
# The path to the checkpoint folder
checkpoint_dir = Path(project_dir/f"2023-10-06_11-17-36")
pd.Series({
    "Project Directory:": project_dir,
    "Checkpoint Directory:": checkpoint_dir,
}).to_frame().style.hide(axis='columns')
Project Directory: | fastai-timm-image-classifier |
---|---|
Checkpoint Directory: | fastai-timm-image-classifier/2023-10-06_11-17-36 |
Loading the Checkpoint Data
Now, we can load the class labels and normalization stats and get the fine-tuned ResNet18-D model from the saved Learner object.
Load the Class Labels
# The class labels path
class_labels_path = list(checkpoint_dir.glob('*classes.json'))[0]
# Load the JSON class labels data
with open(class_labels_path, 'r') as file:
    class_labels_json = json.load(file)
# Get the list of classes
class_names = class_labels_json['classes']
# Print the list of classes
pd.DataFrame(class_names)
 | 0 |
---|---|
0 | call |
1 | dislike |
2 | fist |
3 | four |
4 | like |
5 | mute |
6 | no_gesture |
7 | ok |
8 | one |
9 | palm |
10 | peace |
11 | peace_inverted |
12 | rock |
13 | stop |
14 | stop_inverted |
15 | three |
16 | three2 |
17 | two_up |
18 | two_up_inverted |
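As a minimal sketch of what the loading step expects, the snippet below writes a small stand-in file and reads it back the same way. The file name `demo-dataset-classes.json` and the shortened class list are hypothetical; the only details taken from the code above are the `*classes.json` glob pattern and the top-level `classes` key. It also raises a clearer error when no matching file exists, which indexing the glob result with `[0]` would otherwise report as an `IndexError`.

```python
import json
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_dir = Path(tmp_dir)

    # Hypothetical file name; only the "classes.json" suffix matters for the glob
    sample_path = checkpoint_dir / "demo-dataset-classes.json"
    sample_path.write_text(json.dumps({"classes": ["call", "dislike", "fist"]}))

    # Find the class labels file and fail with a readable message if it is missing
    matches = sorted(checkpoint_dir.glob("*classes.json"))
    if not matches:
        raise FileNotFoundError(f"No '*classes.json' file found in {checkpoint_dir}")

    # Load the JSON data and pull out the list of class names
    with open(matches[0], "r") as file:
        class_names = json.load(file)["classes"]

print(class_names)
```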
Load the Normalization Statistics
# The normalization stats path
norm_stats_path = checkpoint_dir/'norm_stats.json'
# Read the normalization stats from the JSON file
with open(norm_stats_path, "r") as f:
    norm_stats_dict = json.load(f)
# Convert the dictionary to a tuple
norm_stats = (norm_stats_dict["mean"], norm_stats_dict["std_dev"])
# Print the mean and standard deviation
pd.DataFrame(norm_stats)
 | 0 | 1 | 2 |
---|---|---|---|
0 | 0.485 | 0.456 | 0.406 |
1 | 0.229 | 0.224 | 0.225 |
Load the Learner Checkpoint
# The learner checkpoint path
checkpoint_path = list(checkpoint_dir.glob('*.pkl'))[0]
# Load the learner checkpoint onto the CPU
learner = load_learner(checkpoint_path)
Get the Finetuned Model
# Load the fine-tuned model from the exported Learner object
model = load_learner(checkpoint_path).model
Exporting the Model to TorchScript
Before exporting the model, we will wrap it with the preprocessing and post-processing steps. These steps will be included in the TorchScript module, reducing the code we need to write when deploying the model to other platforms.
Prepare the Model for Inference
Whenever we make predictions with the model, we must normalize the input data and pass the model output through a Softmax function. We can define a wrapper class that automatically performs these steps.
Additionally, we can include options to scale pixel data from the range [0,255] to [0,1] and set the model to expect input tensors in channels-last format. These settings can be helpful when deploying to platforms where tensor operations are less convenient.
Define model export wrapper
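class InferenceWrapper(nn.Module):
    def __init__(self, model, normalize_mean, normalize_std, scale_inp=False, channels_last=False):
        super().__init__()
        self.model = model
        # Register the normalization stats as buffers so they are saved with the module
        self.register_buffer("normalize_mean", normalize_mean)
        self.register_buffer("normalize_std", normalize_std)
        self.scale_inp = scale_inp
        self.channels_last = channels_last
        self.softmax = nn.Softmax(dim=1)

    def preprocess_input(self, x):
        # Optionally scale pixel values from [0,255] to [0,1]
        if self.scale_inp:
            x = x / 255.0
        # Optionally convert channels-last input (N, H, W, C) to channels-first (N, C, H, W)
        if self.channels_last:
            x = x.permute(0, 3, 1, 2)
        # Normalize the input with the stored mean and standard deviation
        x = (x - self.normalize_mean) / self.normalize_std
        return x

    def forward(self, x):
        x = self.preprocess_input(x)
        x = self.model(x)
        # Convert the raw model output to class probabilities
        x = self.softmax(x)
        return x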
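To make the wrapper's processing steps concrete, here is a small sketch that applies the same operations by hand to a tiny tensor. The 4x4 channels-last batch and the three-value `logits` tensor are made-up stand-ins for a real image and a real model output; the mean and standard deviation values are the ones loaded earlier.

```python
import torch

# Normalization stats shaped (1, 3, 1, 1) so they broadcast over (N, C, H, W)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

# Made-up channels-last batch (N, H, W, C) with pixel values in [0, 255]
x = torch.randint(0, 256, (1, 4, 4, 3)).float()

# Step 1 (scale_inp=True): scale pixel values to [0, 1]
x = x / 255.0

# Step 2 (channels_last=True): reorder to channels-first (N, C, H, W)
x = x.permute(0, 3, 1, 2)

# Step 3: normalize each channel with its own mean and standard deviation
x = (x - mean) / std

# Post-processing: softmax over the class dimension of a made-up model output
logits = torch.tensor([[2.0, 0.5, -1.0]])
probs = torch.softmax(logits, dim=1)

print(tuple(x.shape))
print(probs.sum().item())
```

The permute has to happen before the normalization because the stats are shaped for channels-first input, which is the order the wrapper uses.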
Wrap model with preprocessing and post-processing steps
# Define the normalization mean and standard deviation
mean_tensor = torch.tensor(norm_stats[0]).view(1, 3, 1, 1)
std_tensor = torch.tensor(norm_stats[1]).view(1, 3, 1, 1)
# Set the model to evaluation mode
model.eval();
# Wrap the model with preprocessing and post-processing steps
wrapped_model = InferenceWrapper(model,
                                 mean_tensor,
                                 std_tensor,
                                 scale_inp=False, # Scale input values from the range [0,255] to [0,1]
                                 channels_last=False, # Have the model expect input in channels-last format
                                )
Prepare the Input Tensor
input_tensor = torch.randn(1, 3, 256, 256)
Export the Model to TorchScript
We can export the model using the torch.jit.trace() function. This function runs the model once on an example input and records the operations it executes, producing a static computation graph that represents the model’s forward pass.
# Set a filename for the TorchScript module
torchscript_file_path = f"{checkpoint_dir}/{class_labels_path.stem.removesuffix('-classes')}-{checkpoint_path.stem}.pt"
# Trace the wrapped model on the CPU and save the TorchScript module
traced_script_module = torch.jit.trace(wrapped_model.cpu(), input_tensor)
traced_script_module.save(torchscript_file_path)
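One consequence of tracing is that Python-level branches, such as the `scale_inp` and `channels_last` checks in the wrapper, are resolved at trace time, and only the branch that actually ran ends up in the graph. The sketch below illustrates this with a hypothetical toy module (`Scaler` is not part of the tutorial code): flipping the flag after tracing changes the eager module but not the traced one.

```python
import torch
from torch import nn


class Scaler(nn.Module):
    """Hypothetical toy module with a Python-level flag."""

    def __init__(self, double=False):
        super().__init__()
        self.double = double

    def forward(self, x):
        # Plain Python branch: the tracer records only the path taken
        if self.double:
            x = x * 2
        return x + 1


module = Scaler(double=True).eval()
example = torch.ones(1, 3)

# Trace while the flag is True, so the multiplication is recorded
traced = torch.jit.trace(module, example)

# Flip the flag on the original module after tracing
module.double = False

eager_out = module(example)   # skips the multiplication
traced_out = traced(example)  # still applies the recorded multiplication

print(eager_out, traced_out)
```

In practice, this means the wrapper options need to be set to their final values before calling `torch.jit.trace()`.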
Performing Inference with the TorchScript Module
Now that we have our TorchScript module, it’s time to compare its performance with the original PyTorch model.
Load the TorchScript Module
We can load the saved TorchScript module using the torch.jit.load() function.
# Load the TorchScript module
traced_script_module = torch.jit.load(torchscript_file_path)
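For a quick sanity check of the save and load steps in isolation, the following sketch round-trips a small stand-in model through a temporary file. The `toy_model` and the `toy-model.pt` file name are hypothetical; passing `map_location="cpu"` is an optional choice here that keeps the reloaded module on the CPU regardless of the device it was saved from.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the wrapped classifier
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(12, 5), nn.Softmax(dim=1)).eval()
example_input = torch.rand(1, 3, 2, 2)

# Trace the stand-in model with an example input
traced = torch.jit.trace(toy_model, example_input)

with tempfile.TemporaryDirectory() as tmp_dir:
    file_path = Path(tmp_dir) / "toy-model.pt"
    # Save the traced module, then load it back from disk
    traced.save(str(file_path))
    reloaded = torch.jit.load(str(file_path), map_location="cpu")

# The reloaded module should reproduce the in-memory traced module's output
with torch.no_grad():
    outputs_match = torch.allclose(traced(example_input), reloaded(example_input))

print(outputs_match)
```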
Select a Test Image
Let’s use the same test image and input size from the previous tutorial to compare the results with the PyTorch model.
test_img_name = "pexels-elina-volkova-16191659.jpg"
test_img_url = f"https://huggingface.co/datasets/cj-mills/pexel-hand-gesture-test-images/resolve/main/{test_img_name}"
download_file(test_img_url, './', False)
test_img = Image.open(test_img_name)
display(test_img)
target_cls = "mute"
pd.Series({
    "Test Image Size:": test_img.size,
    "Target Class:": target_cls
}).to_frame().style.hide(axis='columns')
Test Image Size: | (637, 960) |
---|---|
Target Class: | mute |
Prepare the Test Image
# Set test image size
test_sz = 288
# Resize image without cropping to multiple of the max stride
input_img = resize_img(test_img.copy(), target_sz=test_sz)
display(input_img)
pd.Series({
    "Input Image Size:": input_img.size
}).to_frame().style.hide(axis='columns')
Input Image Size: | (288, 416) |
---|---|
Prepare the Input Tensor
# Convert the existing input image to a PyTorch tensor
input_tensor = torch.Tensor(np.array(input_img, dtype=np.float32)).permute(2,0,1)[None]/255
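The conversion above packs several steps into one line, so here is a sketch that spells them out on a random stand-in array (a made-up replacement for `np.array(input_img)`). Note that PIL reports image size as (width, height), while the NumPy array is laid out as (height, width, channels), so a 288x416 image becomes a tensor of shape (1, 3, 416, 288).

```python
import numpy as np
import torch

# Made-up stand-in for np.array(input_img): height 416, width 288, 3 color channels
img_array = np.random.randint(0, 256, size=(416, 288, 3), dtype=np.uint8)

# Convert to float32 and wrap in a tensor: (H, W, C)
tensor = torch.from_numpy(img_array.astype(np.float32))

# Move the channel dimension first: (C, H, W)
tensor = tensor.permute(2, 0, 1)

# Add a batch dimension: (1, C, H, W)
tensor = tensor[None]

# Scale pixel values from [0, 255] to [0, 1]
tensor = tensor / 255

print(tuple(tensor.shape), tensor.dtype)
```

Dividing by 255 here matches wrapping the model with `scale_inp=False`; with `scale_inp=True`, the wrapper would perform that scaling itself.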
Compute the Predictions
Now, we can see how the TorchScript module compares with the PyTorch model.
# Run inference
with torch.no_grad():
    outputs = traced_script_module(input_tensor)
# Get the highest confidence score
confidence_score = outputs.max()
# Get the class index with the highest confidence score and convert it to the class name
pred_class = class_names[outputs.argmax()]
# Display the image
display(test_img)
# Store the prediction data in a Pandas Series for easy formatting
pd.Series({
    "Input Size:": input_img.size,
    "Target Class:": target_cls,
    "Predicted Class:": pred_class,
    "Confidence Score:": f"{confidence_score*100:.2f}%"
}).to_frame().style.hide(axis='columns')
Input Size: | (288, 416) |
---|---|
Target Class: | mute |
Predicted Class: | mute |
Confidence Score: | 99.99% |
The TorchScript module’s predictions should be virtually identical to those of the original PyTorch model.
- Don’t forget to download the TorchScript module from the Colab Environment’s file browser. (tutorial link)
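One way to check that the traced module matches the eager model is to run both on the same input and compare the outputs numerically. The sketch below does this with a small hypothetical convolutional model rather than the fine-tuned ResNet; it traces at 256x256 and tests at a different resolution, mirroring the tutorial's workflow, and assumes a tolerance of `1e-5` is tight enough to count as "virtually identical".

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the wrapped classifier; adaptive pooling lets it
# accept inputs of different spatial sizes, as the ResNet does
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),
    nn.Softmax(dim=1),
).eval()

# Trace with one input size
traced = torch.jit.trace(toy_model, torch.randn(1, 3, 256, 256))

# Compare eager and traced outputs on a different input size
test_input = torch.rand(1, 3, 416, 288)
with torch.no_grad():
    eager_out = toy_model(test_input)
    traced_out = traced(test_input)

# Largest element-wise difference between the two sets of probabilities
max_abs_diff = (eager_out - traced_out).abs().max().item()
outputs_match = torch.allclose(eager_out, traced_out, atol=1e-5)

print(f"Max absolute difference: {max_abs_diff:.2e}")
```

The same comparison can be applied to `wrapped_model` and `traced_script_module` from this tutorial by swapping them in for the toy model and its traced copy.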
Conclusion
Congratulations on reaching the end of this tutorial! We previously fine-tuned a model from the timm library with fastai for hand gesture classification and now exported the underlying PyTorch model to TorchScript. With this conversion, you can deploy your model seamlessly across diverse platforms, ensuring optimized performance and enhanced portability for real-world applications.
If you found this guide helpful, consider sharing it with others.
- I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.
- I help clients leverage cutting-edge AI technologies to solve real-world problems.
- Learn more about me or reach out via email to discuss your project.