How to Create a LibTorch Plugin for Unity on Windows Pt. 1

fastai
libtorch
unity
tutorial
Modify the training code from the fastai-to-unity tutorial to export the model to a TorchScript module.
Author

Christian Mills

Published

June 28, 2022

Introduction

The previous fastai-to-unity tutorial series implemented a ResNet-based image classifier in Unity with the Barracuda inference library. The Barracuda library works well with the older ResNet architecture but, at the time of writing, does not support more recent architectures like ConvNeXt and MobileViT.

This follow-up series covers using LibTorch, the C++ distribution of PyTorch, to perform inference with these newer model architectures. We’ll modify the original tutorial code and create a dynamic link library (DLL) file to access the LibTorch functionality in Unity.

Overview

This post covers the required modifications to the original training code. We’ll finetune models from the Timm library on the same ASL dataset as the original tutorial. The Timm library provides access to a wide range of pretrained computer vision models and integrates with the fastai library. Below is a link to the complete modified training code, along with links for running the notebook on Google Colab and Kaggle.

GitHub Repository: Jupyter Notebook
Colab: Open in Colab
Kaggle: Open in Kaggle

Install Dependencies

The pip package for the Timm library is more stable than the GitHub repository but has fewer model types and pretrained weights. For example, the pip package has pretrained ConvNeXt models but no MobileViT models. However, the latest GitHub version had some issues running the MobileNetV3 models at the time of writing.

Recent updates to the fastai library resolve some performance issues with PyTorch, so let’s update that too. The updates also add a new ChannelsLast (beta) callback that further improves performance on modern GPUs.
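The ChannelsLast callback switches the model and its inputs to PyTorch's channels_last memory format. The tensor shape stays (N, C, H, W); only the order of the values in memory changes, so each pixel's channel values sit next to each other (NHWC) instead of each channel being stored as a separate plane (NCHW). The stdlib-only sketch below uses hypothetical helper names to compute both sets of strides (elements skipped per step along each dimension) for a batch at the 216x384 input size used later in this post.

```python
def contiguous_strides(shape):
    """Strides (in elements) for a row-major tensor of the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def nchw_strides(n, c, h, w):
    # Default PyTorch layout: each channel is a full H x W plane.
    return contiguous_strides((n, c, h, w))


def channels_last_strides(n, c, h, w):
    # Memory is laid out as N, H, W, C, but the logical shape stays (N, C, H, W),
    # so reorder the NHWC strides back into NCHW dimension order.
    sn, sh, sw, sc = contiguous_strides((n, h, w, c))
    return (sn, sc, sh, sw)


shape = (1, 3, 216, 384)
print(nchw_strides(*shape))           # (248832, 82944, 384, 1)
print(channels_last_strides(*shape))  # (248832, 1, 1152, 3)
```

A channel stride of 1 means the red, green, and blue values of a pixel are adjacent in memory, which is the layout that many modern GPU convolution kernels favor.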

Uncomment the cell below if running on Google Colab or Kaggle (install only one of the two timm versions)

# %%capture
# !pip3 install -U torch torchvision torchaudio
# !pip3 install -U fastai==2.7.2
# !pip3 install -U kaggle==1.5.12
# !pip3 install -U Pillow==9.1.0
# !pip3 install -U timm==0.5.4 # more stable, fewer models
# !pip3 install -U git+https://github.com/rwightman/pytorch-image-models.git # more models, less stable

Note for Colab: You must restart the runtime to use the newly installed version of Pillow.

Import all fastai computer vision functionality

from fastai.vision.all import *

Disable max rows and columns for pandas

import pandas as pd
pd.set_option('max_colwidth', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

Select a Model

Let’s start by selecting a model from the Timm library to finetune. The available pretrained models depend on the version of the Timm library installed.

Import the Timm library

import timm
timm.__version__
'0.6.2.dev0'
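Since MobileViT models are missing from the 0.5.4 pip release but present in the 0.6.2.dev0 build shown above, it can help to check the installed version before picking a model. The sketch below is a simplified, stdlib-only comparison that looks only at the numeric release segment; the helper names are hypothetical, and `packaging.version.Version` implements the full version-string rules if you need them. In the notebook you would pass `timm.__version__` to it.

```python
def parse_release(version):
    """Return the numeric release part of a version string, e.g. '0.6.2.dev0' -> (0, 6, 2)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break  # stop at tags such as 'dev0'
        parts.append(int(piece))
    return tuple(parts)


def newer_than_pip_release(version, pip_release="0.5.4"):
    """True if `version` has a higher release number than the pip release used in this post."""
    return parse_release(version) > parse_release(pip_release)


print(newer_than_pip_release("0.6.2.dev0"))  # True
print(newer_than_pip_release("0.5.4"))       # False
```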

Check available pretrained model types

We can check which model types have pretrained weights using the timm.list_models() function.

model_types = list(set([model.split('_')[0] for model in timm.list_models(pretrained=True)]))
model_types.sort()
pd.DataFrame(model_types)
0
0 adv
1 bat
2 beit
3 botnet26t
4 cait
5 coat
6 convit
7 convmixer
8 convnext
9 crossvit
10 cspdarknet53
11 cspresnet50
12 cspresnext50
13 deit
14 densenet121
15 densenet161
16 densenet169
17 densenet201
18 densenetblur121d
19 dla102
20 dla102x
21 dla102x2
22 dla169
23 dla34
24 dla46
25 dla46x
26 dla60
27 dla60x
28 dm
29 dpn107
30 dpn131
31 dpn68
32 dpn68b
33 dpn92
34 dpn98
35 eca
36 ecaresnet101d
37 ecaresnet269d
38 ecaresnet26t
39 ecaresnet50d
40 ecaresnet50t
41 ecaresnetlight
42 efficientnet
43 efficientnetv2
44 ens
45 ese
46 fbnetc
47 fbnetv3
48 gc
49 gcresnet33ts
50 gcresnet50t
51 gcresnext26ts
52 gcresnext50ts
53 gernet
54 ghostnet
55 gluon
56 gmixer
57 gmlp
58 halo2botnet50ts
59 halonet26t
60 halonet50ts
61 haloregnetz
62 hardcorenas
63 hrnet
64 ig
65 inception
66 jx
67 lambda
68 lamhalobotnet50ts
69 lcnet
70 legacy
71 levit
72 mixer
73 mixnet
74 mnasnet
75 mobilenetv2
76 mobilenetv3
77 mobilevit
78 nasnetalarge
79 nf
80 nfnet
81 pit
82 pnasnet5large
83 poolformer
84 regnetv
85 regnetx
86 regnety
87 regnetz
88 repvgg
89 res2net101
90 res2net50
91 res2next50
92 resmlp
93 resnest101e
94 resnest14d
95 resnest200e
96 resnest269e
97 resnest26d
98 resnest50d
99 resnet101
100 resnet101d
101 resnet152
102 resnet152d
103 resnet18
104 resnet18d
105 resnet200d
106 resnet26
107 resnet26d
108 resnet26t
109 resnet32ts
110 resnet33ts
111 resnet34
112 resnet34d
113 resnet50
114 resnet50d
115 resnet51q
116 resnet61q
117 resnetblur50
118 resnetrs101
119 resnetrs152
120 resnetrs200
121 resnetrs270
122 resnetrs350
123 resnetrs420
124 resnetrs50
125 resnetv2
126 resnext101
127 resnext26ts
128 resnext50
129 resnext50d
130 rexnet
131 sebotnet33ts
132 sehalonet33ts
133 selecsls42b
134 selecsls60
135 selecsls60b
136 semnasnet
137 sequencer2d
138 seresnet152d
139 seresnet33ts
140 seresnet50
141 seresnext101
142 seresnext101d
143 seresnext26d
144 seresnext26t
145 seresnext26ts
146 seresnext50
147 seresnextaa101d
148 skresnet18
149 skresnet34
150 skresnext50
151 spnasnet
152 ssl
153 swin
154 swinv2
155 swsl
156 tf
157 tinynet
158 tnt
159 tresnet
160 tv
161 twins
162 vgg11
163 vgg13
164 vgg16
165 vgg19
166 visformer
167 vit
168 volo
169 wide
170 xception
171 xception41
172 xception41p
173 xception65
174 xception65p
175 xception71
176 xcit
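Splitting on the first underscore groups most models by architecture, but it also produces entries like adv, ig, tf, and tv in the list above, because some model names start with a prefix describing where the weights came from rather than the architecture. Names without an underscore, such as resnet50, stay whole, which is why many ResNet variants appear as separate entries. The sketch below applies the same expression to a small hand-picked list of names (an illustrative assumption, not the full timm registry).

```python
model_names = [
    "adv_inception_v3",
    "convnext_base",
    "convnext_tiny",
    "ig_resnext101_32x8d",
    "mobilevit_xxs",
    "resnet50",
    "tf_efficientnet_b0",
    "tv_resnet50",
]

# Same logic as the notebook cell: keep the text before the first underscore,
# deduplicate with a set, then sort alphabetically.
model_types = sorted({name.split("_")[0] for name in model_names})
print(model_types)
# ['adv', 'convnext', 'ig', 'mobilevit', 'resnet50', 'tf', 'tv']
```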

Timm provides many pretrained models, but not all of them are fast enough for real-time applications. We can filter the results by providing a full or partial model name.
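timm documents the filter argument of list_models() as a wildcard string that works with Python's fnmatch module: `*` matches any run of characters, so 'convnext*' keeps every name starting with convnext. The sketch below runs the standard `fnmatch.filter` function on a few names taken from the listings in this post to show how such patterns behave; it illustrates the pattern syntax and is not timm's own code.

```python
import fnmatch

# A few model names taken from the listings in this post.
model_names = [
    "convnext_tiny",
    "convnext_tiny_in22k",
    "mobilenetv2_050",
    "mobilenetv3_small_050",
    "mobilevit_xxs",
    "resnet50",
]

print(fnmatch.filter(model_names, "convnext*"))    # ['convnext_tiny', 'convnext_tiny_in22k']
print(fnmatch.filter(model_names, "mobilenetv*"))  # ['mobilenetv2_050', 'mobilenetv3_small_050']
print(fnmatch.filter(model_names, "*_050"))        # ['mobilenetv2_050', 'mobilenetv3_small_050']
```

Note that 'mobilenetv*' does not match mobilevit_xxs, so it is a convenient way to list both MobileNet generations at once.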

Check available pretrained ConvNeXt models

pd.DataFrame(timm.list_models('convnext*', pretrained=True))
0
0 convnext_base
1 convnext_base_384_in22ft1k
2 convnext_base_in22ft1k
3 convnext_base_in22k
4 convnext_large
5 convnext_large_384_in22ft1k
6 convnext_large_in22ft1k
7 convnext_large_in22k
8 convnext_small
9 convnext_small_384_in22ft1k
10 convnext_small_in22ft1k
11 convnext_small_in22k
12 convnext_tiny
13 convnext_tiny_384_in22ft1k
14 convnext_tiny_hnf
15 convnext_tiny_in22ft1k
16 convnext_tiny_in22k
17 convnext_xlarge_384_in22ft1k
18 convnext_xlarge_in22ft1k
19 convnext_xlarge_in22k

Let’s go with the convnext_tiny model since we want higher framerates. Each model comes with a set of default configuration parameters. We must keep track of the mean and std values used to normalize the model input. Many pretrained models use the ImageNet normalization stats, but others like MobileViT do not.
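Normalization subtracts the per-channel mean and divides by the per-channel standard deviation, applied to pixel values already scaled to [0, 1]. The stdlib sketch below (the helper name is hypothetical) applies the ImageNet stats and the MobileViT stats from the default configurations in this section to the same pixel. With a mean of 0 and a std of 1, the pixel passes through unchanged, so feeding MobileViT ImageNet-normalized input would shift every value the model sees.

```python
# Normalization stats copied from the default configs shown in this post.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
MOBILEVIT_MEAN = (0.0, 0.0, 0.0)
MOBILEVIT_STD = (1.0, 1.0, 1.0)


def normalize_pixel(rgb, mean, std):
    """Normalize one RGB pixel in [0, 1]: (x - mean) / std for each channel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


pixel = (0.5, 0.5, 0.5)
print(normalize_pixel(pixel, IMAGENET_MEAN, IMAGENET_STD))
print(normalize_pixel(pixel, MOBILEVIT_MEAN, MOBILEVIT_STD))  # (0.5, 0.5, 0.5), unchanged
```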

Inspect the default configuration for the convnext_tiny model

from timm.models import convnext
convnext_model = 'convnext_tiny'
pd.DataFrame.from_dict(convnext.default_cfgs[convnext_model], orient='index')
0
url https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
num_classes 1000
input_size (3, 224, 224)
pool_size (7, 7)
crop_pct 0.875
interpolation bicubic
mean (0.485, 0.456, 0.406)
std (0.229, 0.224, 0.225)
first_conv stem.0
classifier head.fc

Check available pretrained MobileNetV2 models

pd.DataFrame(timm.list_models('mobilenetv2*', pretrained=True))
0
0 mobilenetv2_050
1 mobilenetv2_100
2 mobilenetv2_110d
3 mobilenetv2_120d
4 mobilenetv2_140

Inspect the default configuration for the mobilenetv2_050 model

from timm.models import efficientnet
mobilenetv2_model = 'mobilenetv2_050'
pd.DataFrame.from_dict(efficientnet.default_cfgs[mobilenetv2_model], orient='index')
0
url https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth
num_classes 1000
input_size (3, 224, 224)
pool_size (7, 7)
crop_pct 0.875
interpolation bicubic
mean (0.485, 0.456, 0.406)
std (0.229, 0.224, 0.225)
first_conv conv_stem
classifier classifier

Check available pretrained MobileNetV3 models

pd.DataFrame(timm.list_models('mobilenetv3*', pretrained=True))
0
0 mobilenetv3_large_100
1 mobilenetv3_large_100_miil
2 mobilenetv3_large_100_miil_in21k
3 mobilenetv3_rw
4 mobilenetv3_small_050
5 mobilenetv3_small_075
6 mobilenetv3_small_100

Inspect the default configuration for the mobilenetv3_small_050 model

from timm.models import mobilenetv3
mobilenetv3_model = 'mobilenetv3_small_050'
pd.DataFrame.from_dict(mobilenetv3.default_cfgs[mobilenetv3_model], orient='index')
0
url https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth
num_classes 1000
input_size (3, 224, 224)
pool_size (7, 7)
crop_pct 0.875
interpolation bicubic
mean (0.485, 0.456, 0.406)
std (0.229, 0.224, 0.225)
first_conv conv_stem
classifier classifier

Check available pretrained MobileViT models (Note: MobileViT models are not available in timm 0.5.4.)

pd.DataFrame(timm.list_models('mobilevit*', pretrained=True))
0
0 mobilevit_s
1 mobilevit_xs
2 mobilevit_xxs

Inspect the default configuration for the mobilevit_xxs model

from timm.models import mobilevit
mobilevit_model = 'mobilevit_xxs'
pd.DataFrame.from_dict(mobilevit.default_cfgs[mobilevit_model], orient='index')
0
url https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth
num_classes 1000
input_size (3, 256, 256)
pool_size (8, 8)
crop_pct 0.9
interpolation bicubic
mean (0, 0, 0)
std (1, 1, 1)
first_conv stem.conv
classifier head.fc
fixed_input_size False

Select a model

# Keep only one model_type/model_name pair uncommented
model_type = convnext
model_name = convnext_model
# model_type = efficientnet
# model_name = mobilenetv2_model
# model_type = mobilenetv3
# model_name = mobilenetv3_model
# model_type = mobilevit
# model_name = mobilevit_model

After picking a model, we’ll store the related normalization stats for future use.

Store normalization stats

mean = model_type.default_cfgs[model_name]['mean']
std = model_type.default_cfgs[model_name]['std']
mean, std
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

Define target input dimensions

# Sizes are (height, width)
# size_1_1 = (224, 224)
# size_3_2 = (224, 336)
# size_4_3 = (216, 288)
size_16_9 = (216, 384)
# size_16_9_l = (288, 512)
input_dims = size_16_9
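The tuples use (height, width) order, matching the torch.Size([1, 3, 216, 384]) shape printed later, so size_16_9 is 216 pixels tall and 384 wide. The stdlib sketch below (the helper name is hypothetical) derives the width from a height and an aspect ratio and reproduces the sizes listed above, which is a quick way to pick a size for a different camera aspect ratio.

```python
from fractions import Fraction


def dims_for_aspect(height, aspect_w, aspect_h):
    """Return (height, width) for a landscape aspect ratio such as 16:9.

    Raises ValueError if the width would not be a whole number of pixels.
    """
    width = Fraction(height) * aspect_w / aspect_h
    if width.denominator != 1:
        raise ValueError(f"height {height} gives a fractional width for {aspect_w}:{aspect_h}")
    return (height, int(width))


print(dims_for_aspect(216, 16, 9))  # (216, 384)
print(dims_for_aspect(216, 4, 3))   # (216, 288)
print(dims_for_aspect(224, 3, 2))   # (224, 336)
print(dims_for_aspect(288, 16, 9))  # (288, 512)
```

Using Fraction instead of float division makes the whole-pixel check exact.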

Modify Transforms

We can apply the normalization stats at the end of the batch transforms.

item_tfms = [FlipItem(p=1.0), Resize(input_dims, method=ResizeMethod.Pad, pad_mode=PadMode.Border)]

batch_tfms = [
    Contrast(max_lighting=0.25),
    Saturation(max_lighting=0.25),
    Hue(max_hue=0.05),
    *aug_transforms(
        size=input_dims, 
        mult=1.0,
        do_flip=False,
        flip_vert=False,
        max_rotate=0.0,
        min_zoom=0.5,
        max_zoom=1.5,
        max_lighting=0.5,
        max_warp=0.2, 
        p_affine=0.0,
        pad_mode=PadMode.Border),
    Normalize.from_stats(mean=mean, std=std)
]
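ResizeMethod.Pad fits the whole image inside the target size and pads the leftover area instead of cropping, and PadMode.Border fills that padding by repeating the edge pixels rather than adding black bars. The stdlib sketch below (hypothetical helper name) computes a letterbox-style plan for that step. It is an approximation: fastai pads in the original image's coordinates before resizing, so its rounding may differ by a pixel.

```python
def pad_resize_plan(src_w, src_h, dst_hw):
    """Plan a fit-and-pad resize into dst_hw = (height, width).

    The source size is (width, height), like PIL's Image.size.
    Returns (scaled_w, scaled_h, pad_left, pad_top, pad_right, pad_bottom).
    """
    dst_h, dst_w = dst_hw
    scale = min(dst_w / src_w, dst_h / src_h)  # fit the whole image, never crop
    scaled_w = round(src_w * scale)
    scaled_h = round(src_h * scale)
    pad_x = dst_w - scaled_w
    pad_y = dst_h - scaled_h
    # Split the padding evenly, putting any odd pixel on the right or bottom edge.
    return (scaled_w, scaled_h, pad_x // 2, pad_y // 2, pad_x - pad_x // 2, pad_y - pad_y // 2)


# A square 512x512 photo resized into the 16:9 input (height 216, width 384).
print(pad_resize_plan(512, 512, (216, 384)))  # (216, 216, 84, 0, 84, 0)
```

A square training image therefore ends up with wide side borders, while a 16:9 webcam frame needs no padding at all.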

Define Learner

The training process is identical to the original tutorial, and we only need to pass the name of the Timm model to the vision_learner function.

learn = vision_learner(dls, model_name, metrics=metrics).to_fp16()
# Alternative: add the ChannelsLast (beta) callback for better performance on modern GPUs
# learn = vision_learner(dls, model_name, metrics=metrics, cbs=[ChannelsLast]).to_fp16()

Export the Model

Once training completes, we need to convert our trained PyTorch model to a TorchScript module for use in LibTorch. We do so using the torch.jit.trace() function, which runs the model on an example input and records the operations it performs.

Generate a TorchScript module using the test image

# Switch to evaluation mode and move the model to the CPU before tracing
traced_script_module = torch.jit.trace(learn.model.eval().cpu(), batched_tensor)

We can perform inference with the TorchScript module the same way we would a PyTorch model.

Verify the TorchScript module’s accuracy

with torch.no_grad():
    torchscript_preds = traced_script_module(batched_tensor)
learn.dls.vocab[torch.nn.functional.softmax(torchscript_preds, dim=1).argmax()]
'J'
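The one-liner above converts the raw model outputs (logits) into probabilities with softmax, takes the index of the largest one, and uses it to look up the class label in the vocabulary. Because softmax preserves ordering, the argmax is the same with or without it; softmax matters only when you also want a confidence score. The stdlib sketch below spells out those steps on made-up scores and a hypothetical three-letter vocabulary.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # subtract the max to avoid overflow
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, vocab):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return vocab[best], probs[best]


# Hypothetical scores for a three-class vocabulary.
label, confidence = predict_label([0.3, 2.5, -1.0], ["I", "J", "K"])
print(label, round(confidence, 3))  # label is 'J'
```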

Define TorchScript file name

module_file_name = f"{dataset_path.name}-{learn.arch}.pt"
module_file_name
'asl-and-some-words-convnext_tiny.pt'

Some models, like MobileViT, require LibTorch inputs with the exact dimensions used when calling torch.jit.trace(). Therefore, we’ll trace the PyTorch model again using the target input dimensions from training before saving the TorchScript module to a file.

Generate a TorchScript module using the target input dimensions and save it to a file

torch.randn(1, 3, *input_dims).shape
torch.Size([1, 3, 216, 384])
traced_script_module = torch.jit.trace(learn.model.eval().cpu(), torch.randn(1, 3, *input_dims))
traced_script_module.save(module_file_name)

We can export the normalization stats to a JSON file using the same approach as for the class labels. We’ll load the stats in Unity and pass them to the LibTorch plugin.

Export model normalization stats

import json

normalization_stats = {"mean": list(mean), "std": list(std)}
normalization_stats_file_name = f"{learn.arch}-normalization_stats.json"

with open(normalization_stats_file_name, "w") as write_file:
    json.dump(normalization_stats, write_file)
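The exported file is a small JSON object holding two three-element arrays. The stdlib sketch below writes the ConvNeXt stats from this post and reads them back, showing the exact file contents that the Unity side has to parse; the temporary directory is only there so the demo leaves no files behind.

```python
import json
import os
import tempfile

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# JSON has no tuple type; converting to lists makes the stored format explicit.
normalization_stats = {"mean": list(mean), "std": list(std)}
print(json.dumps(normalization_stats))
# {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

# Round-trip through a temporary file to confirm the stats survive unchanged.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "convnext_tiny-normalization_stats.json")
    with open(path, "w") as write_file:
        json.dump(normalization_stats, write_file)
    with open(path) as read_file:
        loaded = json.load(read_file)

print(loaded == normalization_stats)  # True
```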

Summary

This post covered how to modify the training code from the fastai-to-unity tutorial to finetune models from the Timm library and export them as TorchScript modules. Part 2 will cover creating a dynamic link library (DLL) file in Visual Studio to perform inference with these TorchScript modules using LibTorch.

Previous: Fastai to Unity Tutorial Pt. 3

Next: How to Create a LibTorch Plugin for Unity on Windows Pt. 2

Project Resources: GitHub Repository