How to Create a LibTorch Plugin for Unity on Windows Pt. 1
- Introduction
- Overview
- Install Dependencies
- Select a Model
- Modify Transforms
- Define Learner
- Export the Model
- Summary
Introduction
The previous fastai-to-unity tutorial series implemented a ResNet-based image classifier in Unity with the Barracuda inference library. The Barracuda library works well with the older ResNet architecture but does not support more recent ones like ConvNeXt and MobileViT at the time of writing.
This follow-up series covers using LibTorch, the C++ distribution of PyTorch, to perform inference with these newer model architectures. We’ll modify the original tutorial code and create a dynamic link library (DLL) file to access the LibTorch functionality in Unity.
Overview
This post covers the required modifications to the original training code. We’ll finetune models from the Timm library on the same ASL dataset as the original tutorial. The Timm library provides access to a wide range of pretrained computer vision models and integrates with the fastai library. Below is a link to the complete modified training code, along with links for running the notebook on Google Colab and Kaggle.
GitHub Repository | Colab | Kaggle |
---|---|---|
Jupyter Notebook | Open in Colab | Open in Kaggle |
Install Dependencies
The pip package for the Timm library is more stable than the GitHub repository but has fewer model types and pretrained weights. For example, the pip package has pretrained ConvNeXt models but no MobileViT models. However, the latest GitHub version had some issues running the MobileNetV3 models at the time of writing.
Recent updates to the fastai library resolve some performance issues with PyTorch, so let’s update that too. They also provide a new ChannelsLast (beta) callback that further improves performance on modern GPUs.
Uncomment the cell below if running on Google Colab or Kaggle
# %%capture
# !pip3 install -U torch torchvision torchaudio
# !pip3 install -U fastai==2.7.2
# !pip3 install -U kaggle==1.5.12
# !pip3 install -U Pillow==9.1.0
# !pip3 install -U timm==0.5.4 # more stable, fewer models
# !pip3 install -U git+https://github.com/rwightman/pytorch-image-models.git # more models, less stable
Note for Colab: You must restart the runtime in order to use the newly installed version of Pillow.
Import all fastai computer vision functionality
from fastai.vision.all import *
Disable max rows and columns for pandas
import pandas as pd
pd.set_option('max_colwidth', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
Select a Model
Let’s start by selecting a model from the Timm library to finetune. The available pretrained models depend on the version of the Timm library installed.
Import the Timm library
import timm
timm.__version__
'0.6.2.dev0'
Check available pretrained model types
We can check which model types have pretrained weights using the timm.list_models() function.
model_types = list(set([model.split('_')[0] for model in timm.list_models(pretrained=True)]))
model_types.sort()
pd.DataFrame(model_types)
 | 0 |
---|---|
0 | adv |
1 | bat |
2 | beit |
3 | botnet26t |
4 | cait |
5 | coat |
6 | convit |
7 | convmixer |
8 | convnext |
9 | crossvit |
10 | cspdarknet53 |
11 | cspresnet50 |
12 | cspresnext50 |
13 | deit |
14 | densenet121 |
15 | densenet161 |
16 | densenet169 |
17 | densenet201 |
18 | densenetblur121d |
19 | dla102 |
20 | dla102x |
21 | dla102x2 |
22 | dla169 |
23 | dla34 |
24 | dla46 |
25 | dla46x |
26 | dla60 |
27 | dla60x |
28 | dm |
29 | dpn107 |
30 | dpn131 |
31 | dpn68 |
32 | dpn68b |
33 | dpn92 |
34 | dpn98 |
35 | eca |
36 | ecaresnet101d |
37 | ecaresnet269d |
38 | ecaresnet26t |
39 | ecaresnet50d |
40 | ecaresnet50t |
41 | ecaresnetlight |
42 | efficientnet |
43 | efficientnetv2 |
44 | ens |
45 | ese |
46 | fbnetc |
47 | fbnetv3 |
48 | gc |
49 | gcresnet33ts |
50 | gcresnet50t |
51 | gcresnext26ts |
52 | gcresnext50ts |
53 | gernet |
54 | ghostnet |
55 | gluon |
56 | gmixer |
57 | gmlp |
58 | halo2botnet50ts |
59 | halonet26t |
60 | halonet50ts |
61 | haloregnetz |
62 | hardcorenas |
63 | hrnet |
64 | ig |
65 | inception |
66 | jx |
67 | lambda |
68 | lamhalobotnet50ts |
69 | lcnet |
70 | legacy |
71 | levit |
72 | mixer |
73 | mixnet |
74 | mnasnet |
75 | mobilenetv2 |
76 | mobilenetv3 |
77 | mobilevit |
78 | nasnetalarge |
79 | nf |
80 | nfnet |
81 | pit |
82 | pnasnet5large |
83 | poolformer |
84 | regnetv |
85 | regnetx |
86 | regnety |
87 | regnetz |
88 | repvgg |
89 | res2net101 |
90 | res2net50 |
91 | res2next50 |
92 | resmlp |
93 | resnest101e |
94 | resnest14d |
95 | resnest200e |
96 | resnest269e |
97 | resnest26d |
98 | resnest50d |
99 | resnet101 |
100 | resnet101d |
101 | resnet152 |
102 | resnet152d |
103 | resnet18 |
104 | resnet18d |
105 | resnet200d |
106 | resnet26 |
107 | resnet26d |
108 | resnet26t |
109 | resnet32ts |
110 | resnet33ts |
111 | resnet34 |
112 | resnet34d |
113 | resnet50 |
114 | resnet50d |
115 | resnet51q |
116 | resnet61q |
117 | resnetblur50 |
118 | resnetrs101 |
119 | resnetrs152 |
120 | resnetrs200 |
121 | resnetrs270 |
122 | resnetrs350 |
123 | resnetrs420 |
124 | resnetrs50 |
125 | resnetv2 |
126 | resnext101 |
127 | resnext26ts |
128 | resnext50 |
129 | resnext50d |
130 | rexnet |
131 | sebotnet33ts |
132 | sehalonet33ts |
133 | selecsls42b |
134 | selecsls60 |
135 | selecsls60b |
136 | semnasnet |
137 | sequencer2d |
138 | seresnet152d |
139 | seresnet33ts |
140 | seresnet50 |
141 | seresnext101 |
142 | seresnext101d |
143 | seresnext26d |
144 | seresnext26t |
145 | seresnext26ts |
146 | seresnext50 |
147 | seresnextaa101d |
148 | skresnet18 |
149 | skresnet34 |
150 | skresnext50 |
151 | spnasnet |
152 | ssl |
153 | swin |
154 | swinv2 |
155 | swsl |
156 | tf |
157 | tinynet |
158 | tnt |
159 | tresnet |
160 | tv |
161 | twins |
162 | vgg11 |
163 | vgg13 |
164 | vgg16 |
165 | vgg19 |
166 | visformer |
167 | vit |
168 | volo |
169 | wide |
170 | xception |
171 | xception41 |
172 | xception41p |
173 | xception65 |
174 | xception65p |
175 | xception71 |
176 | xcit |
Timm provides many pretrained models, but not all of them are fast enough for real-time applications. We can filter the results by providing a full or partial model name.
Check available pretrained ConvNeXt models
pd.DataFrame(timm.list_models('convnext*', pretrained=True))
 | 0 |
---|---|
0 | convnext_base |
1 | convnext_base_384_in22ft1k |
2 | convnext_base_in22ft1k |
3 | convnext_base_in22k |
4 | convnext_large |
5 | convnext_large_384_in22ft1k |
6 | convnext_large_in22ft1k |
7 | convnext_large_in22k |
8 | convnext_small |
9 | convnext_small_384_in22ft1k |
10 | convnext_small_in22ft1k |
11 | convnext_small_in22k |
12 | convnext_tiny |
13 | convnext_tiny_384_in22ft1k |
14 | convnext_tiny_hnf |
15 | convnext_tiny_in22ft1k |
16 | convnext_tiny_in22k |
17 | convnext_xlarge_384_in22ft1k |
18 | convnext_xlarge_in22ft1k |
19 | convnext_xlarge_in22k |
Let’s go with the convnext_tiny model since we want higher framerates. Each model comes with a set of default configuration parameters. We must keep track of the mean and std values used to normalize the model input. Many pretrained models use the ImageNet normalization stats, but others like MobileViT do not.
Inspect the default configuration for the convnext_tiny model
from timm.models import convnext
convnext_model = 'convnext_tiny'
pd.DataFrame.from_dict(convnext.default_cfgs[convnext_model], orient='index')
 | 0 |
---|---|
url | https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth |
num_classes | 1000 |
input_size | (3, 224, 224) |
pool_size | (7, 7) |
crop_pct | 0.875 |
interpolation | bicubic |
mean | (0.485, 0.456, 0.406) |
std | (0.229, 0.224, 0.225) |
first_conv | stem.0 |
classifier | head.fc |
Check available pretrained MobileNetV2 models
pd.DataFrame(timm.list_models('mobilenetv2*', pretrained=True))
 | 0 |
---|---|
0 | mobilenetv2_050 |
1 | mobilenetv2_100 |
2 | mobilenetv2_110d |
3 | mobilenetv2_120d |
4 | mobilenetv2_140 |
Inspect the default configuration for the mobilenetv2_050 model
from timm.models import efficientnet
mobilenetv2_model = 'mobilenetv2_050'
pd.DataFrame.from_dict(efficientnet.default_cfgs[mobilenetv2_model], orient='index')
 | 0 |
---|---|
url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth |
num_classes | 1000 |
input_size | (3, 224, 224) |
pool_size | (7, 7) |
crop_pct | 0.875 |
interpolation | bicubic |
mean | (0.485, 0.456, 0.406) |
std | (0.229, 0.224, 0.225) |
first_conv | conv_stem |
classifier | classifier |
Check available pretrained MobileNetV3 models
pd.DataFrame(timm.list_models('mobilenetv3*', pretrained=True))
 | 0 |
---|---|
0 | mobilenetv3_large_100 |
1 | mobilenetv3_large_100_miil |
2 | mobilenetv3_large_100_miil_in21k |
3 | mobilenetv3_rw |
4 | mobilenetv3_small_050 |
5 | mobilenetv3_small_075 |
6 | mobilenetv3_small_100 |
Inspect the default configuration for the mobilenetv3_small_050 model
from timm.models import mobilenetv3
mobilenetv3_model = 'mobilenetv3_small_050'
pd.DataFrame.from_dict(mobilenetv3.default_cfgs[mobilenetv3_model], orient='index')
 | 0 |
---|---|
url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth |
num_classes | 1000 |
input_size | (3, 224, 224) |
pool_size | (7, 7) |
crop_pct | 0.875 |
interpolation | bicubic |
mean | (0.485, 0.456, 0.406) |
std | (0.229, 0.224, 0.225) |
first_conv | conv_stem |
classifier | classifier |
Check available pretrained MobileViT models
* Note: MobileViT models are not available in timm 0.5.4
pd.DataFrame(timm.list_models('mobilevit*', pretrained=True))
 | 0 |
---|---|
0 | mobilevit_s |
1 | mobilevit_xs |
2 | mobilevit_xxs |
Inspect the default configuration for the mobilevit_xxs model
from timm.models import mobilevit
mobilevit_model = 'mobilevit_xxs'
pd.DataFrame.from_dict(mobilevit.default_cfgs[mobilevit_model], orient='index')
 | 0 |
---|---|
url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth |
num_classes | 1000 |
input_size | (3, 256, 256) |
pool_size | (8, 8) |
crop_pct | 0.9 |
interpolation | bicubic |
mean | (0, 0, 0) |
std | (1, 1, 1) |
first_conv | stem.conv |
classifier | head.fc |
fixed_input_size | False |
Select a model
model_type = convnext
model_name = convnext_model

# model_type = efficientnet
# model_name = mobilenetv2_model

# model_type = mobilenetv3
# model_name = mobilenetv3_model

# model_type = mobilevit
# model_name = mobilevit_model
After picking a model, we’ll store the related normalization stats for future use.
Store normalization stats
mean = model_type.default_cfgs[model_name]['mean']
std = model_type.default_cfgs[model_name]['std']
mean, std
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
Define target input dimensions
# size_1_1 = (224, 224)
# size_3_2 = (224, 336)
# size_4_3 = (216, 288)
size_16_9 = (216, 384)
# size_16_9_l = (288, 512)

input_dims = size_16_9
Modify Transforms
We can apply the normalization stats at the end of the batch transforms.
item_tfms = [FlipItem(p=1.0), Resize(input_dims, method=ResizeMethod.Pad, pad_mode=PadMode.Border)]

batch_tfms = [
    Contrast(max_lighting=0.25),
    Saturation(max_lighting=0.25),
    Hue(max_hue=0.05),
    *aug_transforms(
        size=input_dims,
        mult=1.0,
        do_flip=False,
        flip_vert=False,
        max_rotate=0.0,
        min_zoom=0.5,
        max_zoom=1.5,
        max_lighting=0.5,
        max_warp=0.2,
        p_affine=0.0,
        pad_mode=PadMode.Border),
    Normalize.from_stats(mean=mean, std=std)
]
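The dls object used in the next section comes from the original tutorial's data loading step and is not repeated in this post. Purely as a hedged reminder, it might look something like the sketch below; the local dataset path, validation split, and batch size are illustrative placeholders rather than the exact values from the original code.

```python
# Illustrative sketch (relies on the earlier `from fastai.vision.all import *`).
# Assumption: the ASL dataset has been downloaded and extracted to `dataset_path`,
# with one subfolder per class label.
dataset_path = Path('./asl-and-some-words')  # hypothetical local path

dls = ImageDataLoaders.from_folder(
    dataset_path,           # root folder containing one subfolder per class
    valid_pct=0.2,          # hold out a validation split
    bs=32,                  # illustrative batch size
    item_tfms=item_tfms,    # per-item transforms defined above
    batch_tfms=batch_tfms,  # per-batch transforms defined above
)
```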
Define Learner
The training process is identical to the one in the original tutorial; we only need to pass the name of the Timm model to the vision_learner function.
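The dls and metrics arguments come from the original tutorial's setup. As a hedged sketch, metrics could be defined with standard fastai classification metrics like this (the exact list in the original may differ):

```python
# Illustrative metrics list; the original tutorial may use a different selection.
metrics = [error_rate, accuracy]
```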
learn = vision_learner(dls, model_name, metrics=metrics).to_fp16()
# learn = vision_learner(dls, model_name, metrics=metrics, cbs=[ChannelsLast]).to_fp16()
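The training loop itself is unchanged from the original tutorial. A minimal sketch, assuming the standard fastai fine_tune workflow; the epoch count and learning rate below are illustrative placeholders:

```python
# learn.lr_find()              # optionally pick a learning rate first
learn.fine_tune(3, base_lr=1e-3)  # illustrative epoch count and learning rate
```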
Export the Model
Once training completes, we need to convert our trained PyTorch model to a TorchScript module for use in LibTorch. We do so using the torch.jit.trace() method.
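The batched_tensor referenced below is the normalized test-image tensor created during the original tutorial's inference check. A hedged sketch of how it could be built from a sample image; the file path is a placeholder, and the manual normalization reuses the stats stored earlier:

```python
from PIL import Image
import numpy as np
import torch

# Hypothetical sample image path
test_img = Image.open('./test_images/J.jpg').convert('RGB')
# PIL expects (width, height), while input_dims is (height, width)
test_img = test_img.resize(input_dims[::-1])

# Convert to a float tensor with shape (3, H, W) scaled to [0, 1]
img_tensor = torch.Tensor(np.array(test_img)).permute(2, 0, 1) / 255.0
# Apply the normalization stats stored earlier
normalized_tensor = (img_tensor - torch.Tensor(mean).view(3, 1, 1)) / torch.Tensor(std).view(3, 1, 1)
# Add a batch dimension: (1, 3, H, W)
batched_tensor = normalized_tensor.unsqueeze(0)
batched_tensor.shape
```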
Generate a TorchScript module using the test image
traced_script_module = torch.jit.trace(learn.model.cpu(), batched_tensor)
We can perform inference with the TorchScript module the same way we would a PyTorch model.
Verify the TorchScript module’s accuracy
with torch.no_grad():
    torchscript_preds = traced_script_module(batched_tensor)
learn.dls.vocab[torch.nn.functional.softmax(torchscript_preds, dim=1).argmax()]
'J'
Define TorchScript file name
= f"{dataset_path.name}-{learn.arch}.pt"
module_file_name module_file_name
'asl-and-some-words-convnext_tiny.pt'
Some models, like MobileViT, require the exact input dimensions in LibTorch that were used in the torch.jit.trace() call. Therefore, we’ll trace the PyTorch model again using the training dimensions before saving the TorchScript module to a file.
Generate a TorchScript module using the target input dimensions and save it to a file
torch.randn(1, 3, *input_dims).shape
torch.Size([1, 3, 216, 384])
traced_script_module = torch.jit.trace(learn.model.cpu(), torch.randn(1, 3, *input_dims))
traced_script_module.save(module_file_name)
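As an optional sanity check (not in the original post), we can reload the saved file with torch.jit.load and confirm it accepts an input with the target dimensions:

```python
# Reload the saved TorchScript module and run a dummy input through it
reloaded_module = torch.jit.load(module_file_name)
with torch.no_grad():
    preds = reloaded_module(torch.randn(1, 3, *input_dims))
preds.shape
```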
We can export the normalization stats to a JSON file using the same method we used for the class labels. We’ll load the stats in Unity and pass them to the LibTorch plugin.
Export model normalization stats
= {"mean": list(mean), "std": list(std)}
normalization_stats = f"{learn.arch}-normalization_stats.json"
normalization_stats_file_name
with open(normalization_stats_file_name, "w") as write_file:
json.dump(normalization_stats, write_file)
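For reference, the class-label export mentioned above followed the same pattern in the original tutorial. A hedged sketch, where the key name and file name are illustrative rather than necessarily the originals:

```python
# Illustrative class-label export mirroring the normalization-stats export
class_labels = {"classes": list(learn.dls.vocab)}
class_labels_file_name = f"{dataset_path.name}-classes.json"

with open(class_labels_file_name, "w") as write_file:
    json.dump(class_labels, write_file)
```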
Summary
This post covered how to modify the training code from the fastai-to-unity tutorial to finetune models from the Timm library and export them as TorchScript modules. Part 2 will cover creating a dynamic link library (DLL) file in Visual Studio to perform inference with these TorchScript modules using LibTorch.
Previous: Fastai to Unity Tutorial Pt. 3
Next: How to Create a LibTorch Plugin for Unity on Windows Pt. 2
Project Resources: GitHub Repository