# How to Create a LibTorch Plugin for Unity on Windows Pt. 1

- Introduction
- Overview
- Install Dependencies
- Select a Model
- Modify Transforms
- Define Learner
- Export the Model
- Summary

## Introduction

The previous fastai-to-unity tutorial series implemented a ResNet-based image classifier in Unity with the Barracuda inference library. The Barracuda library works well with the older ResNet architecture but does not support more recent ones like ConvNeXt and MobileViT at the time of writing.

This follow-up series covers using LibTorch, the C++ distribution of PyTorch, to perform inference with these newer model architectures. We’ll modify the original tutorial code and create a dynamic link library (DLL) file to access the LibTorch functionality in Unity.

## Overview

This post covers the required modifications to the original training code. We’ll finetune models from the Timm library on the same ASL dataset as the original tutorial. The Timm library provides access to a wide range of pretrained computer vision models and integrates with the fastai library. Below is a link to the complete modified training code, along with links for running the notebook on Google Colab and Kaggle.

| GitHub Repository | Colab | Kaggle |
|---|---|---|
| Jupyter Notebook | Open in Colab | Open in Kaggle |

## Install Dependencies

The pip package for the Timm library is more stable than the GitHub repository but has fewer model types and pretrained weights. For example, the pip package has pretrained ConvNeXt models but no MobileViT models. However, the latest GitHub version had some issues running the MobileNetV3 models at the time of writing.

Recent updates to the fastai library resolve some performance issues with PyTorch, so let’s update that too. They also provide a new `ChannelsLast` (beta) callback that further improves performance on modern GPUs.
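To make the memory-layout idea behind `ChannelsLast` more concrete, here is a minimal pure-Python sketch of how element strides differ between the default contiguous (NCHW) layout and a channels-last layout for a 4D batch. The helper names are hypothetical, and the code only illustrates the indexing arithmetic; it does not call fastai or PyTorch.

```python
def contiguous_strides(n, c, h, w):
    """Element strides for the default NCHW layout (width varies fastest)."""
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n, c, h, w):
    """Element strides when the same logical NCHW tensor is stored as NHWC."""
    return (h * w * c, 1, w * c, c)


def offset(index, strides):
    """Flat memory offset of a logical (n, c, h, w) index."""
    return sum(i * s for i, s in zip(index, strides))


shape = (1, 3, 216, 384)  # batch, channels, height, width
nchw = contiguous_strides(*shape)
nhwc = channels_last_strides(*shape)

# Distance in memory between the first and second channel of the same pixel.
pixel = (0, 0, 10, 20)
next_channel = (0, 1, 10, 20)
gap_nchw = offset(next_channel, nchw) - offset(pixel, nchw)
gap_nhwc = offset(next_channel, nhwc) - offset(pixel, nhwc)
print(gap_nchw, gap_nhwc)
```

In the channels-last layout, all channel values for a pixel sit next to each other in memory, which is the access pattern the callback aims to exploit.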

**Uncomment the cell below if running on Google Colab or Kaggle**

```
# %%capture
# !pip3 install -U torch torchvision torchaudio
# !pip3 install -U fastai==2.7.2
# !pip3 install -U kaggle==1.5.12
# !pip3 install -U Pillow==9.1.0
# !pip3 install -U timm==0.5.4 # more stable fewer models
# !pip3 install -U git+https://github.com/rwightman/pytorch-image-models.git # more models less stable
```

**Note for Colab:** You must restart the runtime to use the newly installed version of Pillow.
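After installing, it can help to confirm which package versions are present. The sketch below uses the standard-library `importlib.metadata` module; the `installed_versions` helper is a hypothetical name introduced here. It reports the version installed on disk, so a module that was already imported in the running session may still be the old one until the runtime restarts.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(["torch", "fastai", "timm", "Pillow"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```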

**Import all fastai computer vision functionality**

`from fastai.vision.all import *`

**Disable max rows and columns for pandas**

```
import pandas as pd
pd.set_option('max_colwidth', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
```

## Select a Model

Let’s start by selecting a model from the Timm library to finetune. The available pretrained models depend on the version of the Timm library installed.

**Import the Timm library**

`import timm`

`timm.__version__`

`'0.6.2.dev0'`

**Check available pretrained model types**

We can check which model types have pretrained weights using the `timm.list_models()` function.

```
model_types = list(set([model.split('_')[0] for model in timm.list_models(pretrained=True)]))
model_types.sort()
pd.DataFrame(model_types)
```

| | 0 |
|---|---|
| 0 | adv |
| 1 | bat |
| 2 | beit |
| 3 | botnet26t |
| 4 | cait |
| 5 | coat |
| 6 | convit |
| 7 | convmixer |
| 8 | convnext |
| 9 | crossvit |
| 10 | cspdarknet53 |
| 11 | cspresnet50 |
| 12 | cspresnext50 |
| 13 | deit |
| 14 | densenet121 |
| 15 | densenet161 |
| 16 | densenet169 |
| 17 | densenet201 |
| 18 | densenetblur121d |
| 19 | dla102 |
| 20 | dla102x |
| 21 | dla102x2 |
| 22 | dla169 |
| 23 | dla34 |
| 24 | dla46 |
| 25 | dla46x |
| 26 | dla60 |
| 27 | dla60x |
| 28 | dm |
| 29 | dpn107 |
| 30 | dpn131 |
| 31 | dpn68 |
| 32 | dpn68b |
| 33 | dpn92 |
| 34 | dpn98 |
| 35 | eca |
| 36 | ecaresnet101d |
| 37 | ecaresnet269d |
| 38 | ecaresnet26t |
| 39 | ecaresnet50d |
| 40 | ecaresnet50t |
| 41 | ecaresnetlight |
| 42 | efficientnet |
| 43 | efficientnetv2 |
| 44 | ens |
| 45 | ese |
| 46 | fbnetc |
| 47 | fbnetv3 |
| 48 | gc |
| 49 | gcresnet33ts |
| 50 | gcresnet50t |
| 51 | gcresnext26ts |
| 52 | gcresnext50ts |
| 53 | gernet |
| 54 | ghostnet |
| 55 | gluon |
| 56 | gmixer |
| 57 | gmlp |
| 58 | halo2botnet50ts |
| 59 | halonet26t |
| 60 | halonet50ts |
| 61 | haloregnetz |
| 62 | hardcorenas |
| 63 | hrnet |
| 64 | ig |
| 65 | inception |
| 66 | jx |
| 67 | lambda |
| 68 | lamhalobotnet50ts |
| 69 | lcnet |
| 70 | legacy |
| 71 | levit |
| 72 | mixer |
| 73 | mixnet |
| 74 | mnasnet |
| 75 | mobilenetv2 |
| 76 | mobilenetv3 |
| 77 | mobilevit |
| 78 | nasnetalarge |
| 79 | nf |
| 80 | nfnet |
| 81 | pit |
| 82 | pnasnet5large |
| 83 | poolformer |
| 84 | regnetv |
| 85 | regnetx |
| 86 | regnety |
| 87 | regnetz |
| 88 | repvgg |
| 89 | res2net101 |
| 90 | res2net50 |
| 91 | res2next50 |
| 92 | resmlp |
| 93 | resnest101e |
| 94 | resnest14d |
| 95 | resnest200e |
| 96 | resnest269e |
| 97 | resnest26d |
| 98 | resnest50d |
| 99 | resnet101 |
| 100 | resnet101d |
| 101 | resnet152 |
| 102 | resnet152d |
| 103 | resnet18 |
| 104 | resnet18d |
| 105 | resnet200d |
| 106 | resnet26 |
| 107 | resnet26d |
| 108 | resnet26t |
| 109 | resnet32ts |
| 110 | resnet33ts |
| 111 | resnet34 |
| 112 | resnet34d |
| 113 | resnet50 |
| 114 | resnet50d |
| 115 | resnet51q |
| 116 | resnet61q |
| 117 | resnetblur50 |
| 118 | resnetrs101 |
| 119 | resnetrs152 |
| 120 | resnetrs200 |
| 121 | resnetrs270 |
| 122 | resnetrs350 |
| 123 | resnetrs420 |
| 124 | resnetrs50 |
| 125 | resnetv2 |
| 126 | resnext101 |
| 127 | resnext26ts |
| 128 | resnext50 |
| 129 | resnext50d |
| 130 | rexnet |
| 131 | sebotnet33ts |
| 132 | sehalonet33ts |
| 133 | selecsls42b |
| 134 | selecsls60 |
| 135 | selecsls60b |
| 136 | semnasnet |
| 137 | sequencer2d |
| 138 | seresnet152d |
| 139 | seresnet33ts |
| 140 | seresnet50 |
| 141 | seresnext101 |
| 142 | seresnext101d |
| 143 | seresnext26d |
| 144 | seresnext26t |
| 145 | seresnext26ts |
| 146 | seresnext50 |
| 147 | seresnextaa101d |
| 148 | skresnet18 |
| 149 | skresnet34 |
| 150 | skresnext50 |
| 151 | spnasnet |
| 152 | ssl |
| 153 | swin |
| 154 | swinv2 |
| 155 | swsl |
| 156 | tf |
| 157 | tinynet |
| 158 | tnt |
| 159 | tresnet |
| 160 | tv |
| 161 | twins |
| 162 | vgg11 |
| 163 | vgg13 |
| 164 | vgg16 |
| 165 | vgg19 |
| 166 | visformer |
| 167 | vit |
| 168 | volo |
| 169 | wide |
| 170 | xception |
| 171 | xception41 |
| 172 | xception41p |
| 173 | xception65 |
| 174 | xception65p |
| 175 | xception71 |
| 176 | xcit |
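Some entries in this list, such as `tf`, `ig`, or `gluon`, are not architectures on their own. They appear because the grouping keeps only the text before the first underscore in each model name. A minimal sketch of that grouping on a small, hand-picked sample of names (illustrative only, not the full output of `timm.list_models()`):

```python
# Hand-picked sample of model names for illustration; the real list is much longer.
sample_names = [
    "convnext_tiny",
    "convnext_base_in22k",
    "mobilevit_xxs",
    "resnet50",
    "resnet50d",
    "tf_efficientnet_b0",
    "gluon_resnet50_v1b",
]

# Keep the text before the first underscore, then deduplicate and sort.
prefixes = sorted({name.split("_")[0] for name in sample_names})
print(prefixes)
```

Names without an underscore, like `resnet50d`, become their own entry, which is why the table mixes family names with individual model names.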

Timm provides many pretrained models, but not all of them are fast enough for real-time applications. We can filter the results by providing a full or partial model name.
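As a rough illustration of how a partial name with a wildcard narrows the list, the sketch below uses the standard-library `fnmatch` module on a small sample of names. It assumes the filter behaves like shell-style wildcard matching; the `filter_names` helper and the sample list are hypothetical stand-ins and do not call Timm.

```python
from fnmatch import fnmatchcase

# Small illustrative sample; not the full list of pretrained models.
sample_names = [
    "convnext_tiny",
    "convnext_small",
    "mobilenetv2_050",
    "mobilevit_xxs",
    "resnet50",
]


def filter_names(names, pattern):
    """Return the sorted names that match a shell-style wildcard pattern."""
    return sorted(name for name in names if fnmatchcase(name, pattern))


convnext_only = filter_names(sample_names, "convnext*")
mobile_any = filter_names(sample_names, "mobile*")
print(convnext_only)
print(mobile_any)
```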

**Check available pretrained ConvNeXt models**

`pd.DataFrame(timm.list_models('convnext*', pretrained=True))`

| | 0 |
|---|---|
| 0 | convnext_base |
| 1 | convnext_base_384_in22ft1k |
| 2 | convnext_base_in22ft1k |
| 3 | convnext_base_in22k |
| 4 | convnext_large |
| 5 | convnext_large_384_in22ft1k |
| 6 | convnext_large_in22ft1k |
| 7 | convnext_large_in22k |
| 8 | convnext_small |
| 9 | convnext_small_384_in22ft1k |
| 10 | convnext_small_in22ft1k |
| 11 | convnext_small_in22k |
| 12 | convnext_tiny |
| 13 | convnext_tiny_384_in22ft1k |
| 14 | convnext_tiny_hnf |
| 15 | convnext_tiny_in22ft1k |
| 16 | convnext_tiny_in22k |
| 17 | convnext_xlarge_384_in22ft1k |
| 18 | convnext_xlarge_in22ft1k |
| 19 | convnext_xlarge_in22k |

Let’s go with the `convnext_tiny` model since we want higher framerates. Each model comes with a set of default configuration parameters. We must keep track of the mean and std values used to normalize the model input. Many pretrained models use the ImageNet normalization stats, but others like MobileViT do not.

**Inspect the default configuration for the convnext_tiny model**

```
from timm.models import convnext
convnext_model = 'convnext_tiny'
pd.DataFrame.from_dict(convnext.default_cfgs[convnext_model], orient='index')
```

| | 0 |
|---|---|
| url | https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth |
| num_classes | 1000 |
| input_size | (3, 224, 224) |
| pool_size | (7, 7) |
| crop_pct | 0.875 |
| interpolation | bicubic |
| mean | (0.485, 0.456, 0.406) |
| std | (0.229, 0.224, 0.225) |
| first_conv | stem.0 |
| classifier | head.fc |

**Check available pretrained MobileNetV2 models**

`pd.DataFrame(timm.list_models('mobilenetv2*', pretrained=True))`

| | 0 |
|---|---|
| 0 | mobilenetv2_050 |
| 1 | mobilenetv2_100 |
| 2 | mobilenetv2_110d |
| 3 | mobilenetv2_120d |
| 4 | mobilenetv2_140 |

**Inspect the default configuration for the mobilenetv2_050 model**

```
from timm.models import efficientnet
mobilenetv2_model = 'mobilenetv2_050'
pd.DataFrame.from_dict(efficientnet.default_cfgs[mobilenetv2_model], orient='index')
```

| | 0 |
|---|---|
| url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth |
| num_classes | 1000 |
| input_size | (3, 224, 224) |
| pool_size | (7, 7) |
| crop_pct | 0.875 |
| interpolation | bicubic |
| mean | (0.485, 0.456, 0.406) |
| std | (0.229, 0.224, 0.225) |
| first_conv | conv_stem |
| classifier | classifier |

**Check available pretrained MobileNetV3 models**

`pd.DataFrame(timm.list_models('mobilenetv3*', pretrained=True))`

| | 0 |
|---|---|
| 0 | mobilenetv3_large_100 |
| 1 | mobilenetv3_large_100_miil |
| 2 | mobilenetv3_large_100_miil_in21k |
| 3 | mobilenetv3_rw |
| 4 | mobilenetv3_small_050 |
| 5 | mobilenetv3_small_075 |
| 6 | mobilenetv3_small_100 |

**Inspect the default configuration for the mobilenetv3_small_050 model**

```
from timm.models import mobilenetv3
mobilenetv3_model = 'mobilenetv3_small_050'
pd.DataFrame.from_dict(mobilenetv3.default_cfgs[mobilenetv3_model], orient='index')
```

| | 0 |
|---|---|
| url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth |
| num_classes | 1000 |
| input_size | (3, 224, 224) |
| pool_size | (7, 7) |
| crop_pct | 0.875 |
| interpolation | bicubic |
| mean | (0.485, 0.456, 0.406) |
| std | (0.229, 0.224, 0.225) |
| first_conv | conv_stem |
| classifier | classifier |

**Check available pretrained MobileViT models**

**Note:** MobileViT models are not available in timm `0.5.4`.

`pd.DataFrame(timm.list_models('mobilevit*', pretrained=True))`

| | 0 |
|---|---|
| 0 | mobilevit_s |
| 1 | mobilevit_xs |
| 2 | mobilevit_xxs |

**Inspect the default configuration for the mobilevit_xxs model**

```
from timm.models import mobilevit
mobilevit_model = 'mobilevit_xxs'
pd.DataFrame.from_dict(mobilevit.default_cfgs[mobilevit_model], orient='index')
```

| | 0 |
|---|---|
| url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth |
| num_classes | 1000 |
| input_size | (3, 256, 256) |
| pool_size | (8, 8) |
| crop_pct | 0.9 |
| interpolation | bicubic |
| mean | (0, 0, 0) |
| std | (1, 1, 1) |
| first_conv | stem.conv |
| classifier | head.fc |
| fixed_input_size | False |

**Select a model**

```
model_type = convnext
model_name = convnext_model
# model_type = efficientnet
# model_name = mobilenetv2_model
# model_type = mobilenetv3
# model_name = mobilenetv3_model
# model_type = mobilevit
# model_name = mobilevit_model
```

After picking a model, we’ll store the related normalization stats for future use.

**Store normalization stats**

```
mean = model_type.default_cfgs[model_name]['mean']
std = model_type.default_cfgs[model_name]['std']
mean, std
```

`((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))`
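To show what these stats do to the input, here is a minimal pure-Python sketch of per-channel normalization, `(value - mean) / std`, applied to a single pixel whose channels are already scaled to the `[0, 1]` range. The `normalize_pixel` helper and the sample pixel are hypothetical; the two sets of stats are the ones listed in the configuration tables above.

```python
def normalize_pixel(rgb, mean, std):
    """Apply (value - mean) / std to each channel of a pixel scaled to [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))


imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
mobilevit_mean, mobilevit_std = (0, 0, 0), (1, 1, 1)

pixel = (0.485, 0.680, 0.181)  # hypothetical pixel value

# With the ImageNet stats, each channel is shifted and rescaled.
imagenet_out = normalize_pixel(pixel, imagenet_mean, imagenet_std)

# With a mean of 0 and std of 1, the pixel passes through unchanged.
mobilevit_out = normalize_pixel(pixel, mobilevit_mean, mobilevit_std)

print(imagenet_out)
print(mobilevit_out)
```

Using the wrong stats at inference time would feed the model values on a different scale than it saw during training, which is why the stats need to travel with the exported model.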

**Define target input dimensions**

```
# size_1_1 = (224, 224)
# size_3_2 = (224, 336)
# size_4_3 = (216, 288)
size_16_9 = (216, 384)
# size_16_9_l = (288, 512)
```

`input_dims = size_16_9`
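The variable names encode aspect ratios, with each tuple read as `(height, width)`. A quick check with the standard-library `fractions` module confirms the ratios and shows how the pixel count grows with each option; the `aspect_ratio` helper is a hypothetical name used only for this sketch.

```python
from fractions import Fraction


def aspect_ratio(dims):
    """Reduce (height, width) to a (width, height) ratio in lowest terms."""
    height, width = dims
    ratio = Fraction(width, height)
    return ratio.numerator, ratio.denominator


candidates = {
    "size_1_1": (224, 224),
    "size_3_2": (224, 336),
    "size_4_3": (216, 288),
    "size_16_9": (216, 384),
    "size_16_9_l": (288, 512),
}

for name, dims in candidates.items():
    w, h = aspect_ratio(dims)
    pixels = dims[0] * dims[1]
    print(f"{name}: {w}:{h}, {pixels} pixels")
```

Larger inputs mean more pixels per frame for the model to process, so the choice trades detail against framerate.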

## Modify Transforms

We can apply the normalization stats at the end of the batch transforms.

```
item_tfms = [FlipItem(p=1.0), Resize(input_dims, method=ResizeMethod.Pad, pad_mode=PadMode.Border)]

batch_tfms = [
    Contrast(max_lighting=0.25),
    Saturation(max_lighting=0.25),
    Hue(max_hue=0.05),
    *aug_transforms(
        size=input_dims,
        mult=1.0,
        do_flip=False,
        flip_vert=False,
        max_rotate=0.0,
        min_zoom=0.5,
        max_zoom=1.5,
        max_lighting=0.5,
        max_warp=0.2,
        p_affine=0.0,
        pad_mode=PadMode.Border),
    Normalize.from_stats(mean=mean, std=std)
]
```
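The item transform resizes with padding rather than cropping, so the whole image stays visible inside the 16:9 input. The sketch below illustrates the general fit-then-pad arithmetic for a hypothetical 200x200 source image. It is not fastai's implementation of `ResizeMethod.Pad`, only the idea behind it, and `fit_and_pad` is a name made up for this example.

```python
def fit_and_pad(src_hw, target_hw):
    """Scale an image to fit inside the target while keeping its aspect ratio.

    Returns the resized (height, width) and the total padding needed along
    each axis to reach the target size.
    """
    src_h, src_w = src_hw
    tgt_h, tgt_w = target_hw
    # The smaller scale factor guarantees both sides fit inside the target.
    scale = min(tgt_h / src_h, tgt_w / src_w)
    new_h, new_w = round(src_h * scale), round(src_w * scale)
    return (new_h, new_w), (tgt_h - new_h, tgt_w - new_w)


# Hypothetical square source image placed into the (216, 384) input.
resized, padding = fit_and_pad((200, 200), (216, 384))
print(resized, padding)
```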

## Define Learner

The training process is identical to the original tutorial, and we only need to pass the name of the Timm model to the `vision_learner` function.

```
learn = vision_learner(dls, model_name, metrics=metrics).to_fp16()
# learn = vision_learner(dls, model_name, metrics=metrics, cbs=[ChannelsLast]).to_fp16()
```

## Export the Model

Once training completes, we need to convert our trained PyTorch model to a TorchScript module for use in LibTorch. We do so using the `torch.jit.trace()` method.

**Generate a TorchScript module using the test image**

`traced_script_module = torch.jit.trace(learn.model.cpu(), batched_tensor)`

We can perform inference with the TorchScript module the same way we would a PyTorch model.

**Verify the TorchScript module’s accuracy**

```
with torch.no_grad():
    torchscript_preds = traced_script_module(batched_tensor)
learn.dls.vocab[torch.nn.functional.softmax(torchscript_preds, dim=1).argmax()]
```

`'J'`
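The lookup above turns raw model outputs into a class label by applying softmax and taking the index of the largest probability. A minimal pure-Python sketch of that step, using a hypothetical three-class vocabulary and made-up output values:

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


vocab = ["A", "B", "J"]    # hypothetical class labels
logits = [0.3, -1.2, 2.5]  # hypothetical raw model outputs for one image

probs = softmax(logits)
best = max(range(len(probs)), key=probs.__getitem__)
predicted = vocab[best]
print(predicted)
```

Softmax preserves the ordering of the scores, so the predicted label is the same whether the largest value is picked before or after applying it. Softmax only matters when the confidence value itself is needed.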

**Define TorchScript file name**

```
module_file_name = f"{dataset_path.name}-{learn.arch}.pt"
module_file_name
```

`'asl-and-some-words-convnext_tiny.pt'`

Some models, like MobileViT, require the same input dimensions in LibTorch as those used with the `torch.jit.trace()` method. Therefore, we’ll convert the PyTorch model again using the training dimensions before saving the TorchScript module to a file.

**Generate a TorchScript module using the target input dimensions and save it to a file**

`torch.randn(1, 3, *input_dims).shape`

`torch.Size([1, 3, 216, 384])`

```
traced_script_module = torch.jit.trace(learn.model.cpu(), torch.randn(1, 3, *input_dims))
traced_script_module.save(module_file_name)
```
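As a sanity check on the export step, the saved file can be loaded back with `torch.jit.load()` and compared against the original model. The sketch below runs the full trace, save, load, and compare cycle on a tiny stand-in network rather than the trained learner; the layer choices and file name are assumptions made for illustration.

```python
import os
import tempfile

import torch
from torch import nn

# Tiny stand-in for the trained model (hypothetical, for illustration only).
tiny_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 2),
).eval()

# Trace with a random tensor that has the target input dimensions.
example_input = torch.randn(1, 3, 216, 384)
traced = torch.jit.trace(tiny_model, example_input)

# Save the TorchScript module to a temporary file and load it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny-model.pt")
    traced.save(path)
    reloaded = torch.jit.load(path)

# The reloaded module should reproduce the original model's output.
with torch.no_grad():
    outputs_match = torch.allclose(tiny_model(example_input), reloaded(example_input))
print(outputs_match)
```

The same `.pt` file is what a C++ application reads through LibTorch's `torch::jit::load`.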

We can export the normalization stats to a JSON file using the same method we used for the class labels. We’ll load the stats in Unity and pass them to the LibTorch plugin.

**Export model normalization stats**

```
import json

normalization_stats = {"mean": list(mean), "std": list(std)}
normalization_stats_file_name = f"{learn.arch}-normalization_stats.json"

with open(normalization_stats_file_name, "w") as write_file:
    json.dump(normalization_stats, write_file)
```
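For reference, this is roughly what the exported JSON looks like and how it reads back. The sketch serializes the ImageNet stats shown earlier to a string instead of a file, so it runs without the trained learner.

```python
import json

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Serialize to a string to inspect the exact text a consumer would parse.
json_text = json.dumps({"mean": list(mean), "std": list(std)})
print(json_text)

# Reading it back yields plain lists, which convert cleanly to tuples.
loaded = json.loads(json_text)
restored_mean, restored_std = tuple(loaded["mean"]), tuple(loaded["std"])
```

The file holds two three-element arrays, one value per color channel, which keeps parsing on the Unity side simple.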

## Summary

This post covered how to modify the training code from the fastai-to-unity tutorial to finetune models from the Timm library and export them as TorchScript modules. Part 2 will cover creating a dynamic link library (DLL) file in Visual Studio to perform inference with these TorchScript modules using LibTorch.

**Previous:** Fastai to Unity Tutorial Pt. 3

**Next:** How to Create a LibTorch Plugin for Unity on Windows Pt. 2

**Project Resources:** GitHub Repository