Using Fastai and TensorFlow.js for Hand Gesture Recognition in Unity Pt. 1
- Introduction
- Overview
- Install Dependencies
- Import Dependencies
- Select a Model
- Download the Dataset
- Inspect the Dataset
- Define DataLoaders
- Finetune the Model
- Clean Dataset (Optional)
- Test the Model
- Export the Model
- Export Inference Data
- Summary
Tutorial Links
- Part 1: Train a hand gesture classifier using fastai and export it to TensorFlow.js.
- Part 2: Create a TensorFlow.js plugin for the Unity game engine.
- Part 3: Build a Unity project as a WebGL application and host it using GitHub Pages.
- GitHub Repository
Introduction
In this three-part tutorial series, we will use fastai and TensorFlow.js to create an in-browser hand gesture recognition system in Unity. In Part 1, we will train a hand gesture classifier using fastai and export it to TensorFlow.js. In Part 2, we will create a TensorFlow.js plugin for the Unity game engine. Finally, in Part 3, we will host the Unity project as a live demo on GitHub Pages. By the end of this tutorial series, you will have a hand gesture recognition system that you can use in your Unity projects.
In-Browser Demo: Hand Gesture Classifier
The tutorial uses a downscaled version of HaGRID (HAnd Gesture Recognition Image Dataset) that I modified for image classification tasks. The dataset contains images for 18 distinct hand gestures and an additional no_gesture class to account for idle hands.
Reference Images
Class | Image |
---|---|
call | |
dislike | |
fist | |
four | |
like | |
mute | |
ok | |
one | |
palm | |
peace | |
peace_inverted | |
rock | |
stop | |
stop_inverted | |
three | |
three2 | |
two_up | |
two_up_inverted |
We can use a model trained on this dataset to map hand gestures to user input via a webcam in Unity.
Overview
In Part 1 of this tutorial series, we finetune an image classifier from the timm library using fastai and export it to TensorFlow.js. We will start by installing and importing the necessary dependencies. Then, we will select a model to use and download a dataset to train it. After inspecting the dataset, we will define data loaders to use for training. Finally, we will finetune and export the model. We also demonstrate how to clean the dataset to improve training. By the end of this post, you will have a trained hand gesture classifier that you can use in web applications.
You can find links to view the training code and run it on Google Colab and Kaggle below.
Jupyter Notebook | Colab | Kaggle |
---|---|---|
GitHub Repository | Open In Colab | Open in Kaggle |
Install Dependencies
The training code requires PyTorch as the backend for the fastai library, the fastai library itself for training, and the Kaggle API Python package for downloading the dataset. The timm library provides access to a wide range of pretrained image models.
Install model training dependencies
# %%capture
# !pip install -U torch torchvision torchaudio
# !pip install -U fastai==2.7.9
# !pip install -U kaggle
# !pip install -U Pillow
# !pip install -U timm
The model conversion process involves exporting the PyTorch model to ONNX, converting the ONNX model to a TensorFlow SavedModel, then converting the SavedModel to TensorFlow.js web format.
Install Tensorflow.js conversion dependencies
# %%capture
# !pip install -U onnxruntime
# !pip install onnx-tf
# !pip install tensorflowjs
# !pip install tensorflow_probability
# !pip install onnx-simplifier
# !pip install simple-onnx-processing-tools
# !pip install -U onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com
Import Dependencies
Import all fastai computer vision functionality
from fastai.vision.all import *
Import pandas and disable column and row limits
import pandas as pd
pd.set_option('max_colwidth', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
Import timm library
import timm
timm.__version__
'0.6.7'
Select a Model
Timm provides many pretrained models, but not all are fast enough for real-time applications. We can filter through the available models using the timm.list_models() function.
View available ResNet models
pd.DataFrame(timm.list_models('resnet*', pretrained=True))
0 | |
---|---|
0 | resnet10t |
1 | resnet14t |
2 | resnet18 |
3 | resnet18d |
4 | resnet26 |
5 | resnet26d |
6 | resnet26t |
7 | resnet32ts |
8 | resnet33ts |
9 | resnet34 |
10 | resnet34d |
11 | resnet50 |
12 | resnet50_gn |
13 | resnet50d |
14 | resnet51q |
15 | resnet61q |
16 | resnet101 |
17 | resnet101d |
18 | resnet152 |
19 | resnet152d |
20 | resnet200d |
21 | resnetaa50 |
22 | resnetblur50 |
23 | resnetrs50 |
24 | resnetrs101 |
25 | resnetrs152 |
26 | resnetrs200 |
27 | resnetrs270 |
28 | resnetrs350 |
29 | resnetrs420 |
30 | resnetv2_50 |
31 | resnetv2_50d_evos |
32 | resnetv2_50d_gn |
33 | resnetv2_50x1_bit_distilled |
34 | resnetv2_50x1_bitm |
35 | resnetv2_50x1_bitm_in21k |
36 | resnetv2_50x3_bitm |
37 | resnetv2_50x3_bitm_in21k |
38 | resnetv2_101 |
39 | resnetv2_101x1_bitm |
40 | resnetv2_101x1_bitm_in21k |
41 | resnetv2_101x3_bitm |
42 | resnetv2_101x3_bitm_in21k |
43 | resnetv2_152x2_bit_teacher |
44 | resnetv2_152x2_bit_teacher_384 |
45 | resnetv2_152x2_bitm |
46 | resnetv2_152x2_bitm_in21k |
47 | resnetv2_152x4_bitm |
48 | resnetv2_152x4_bitm_in21k |
The smaller ResNet models are both fast and sufficiently accurate in most settings. Unfortunately, the resnet10t and resnet14t models contain operations unsupported by the TensorFlow.js conversion script. We’ll instead use the resnet18 model for our lightweight option.
Inspect config for specific model
Each model comes with a set of default configuration parameters. We must keep track of the mean and std values used to normalize the model input. Many pretrained models use the ImageNet normalization stats, but others, like MobileViT, do not.
from timm.models import resnet

resnet_model = 'resnet18'
pd.DataFrame.from_dict(resnet.default_cfgs[resnet_model], orient='index')
0 | |
---|---|
url | https://download.pytorch.org/models/resnet18-5c106cde.pth |
num_classes | 1000 |
input_size | (3, 224, 224) |
pool_size | (7, 7) |
crop_pct | 0.875 |
interpolation | bilinear |
mean | (0.485, 0.456, 0.406) |
std | (0.229, 0.224, 0.225) |
first_conv | conv1 |
classifier | fc |
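For comparison, we can run the same check on a model family that does not use the ImageNet stats. The snippet below is a hedged example; it assumes this timm version exposes the mobilevit module and a mobilevit_xxs config entry.

```python
# Inspect the default config for a MobileViT variant to confirm that its
# normalization stats differ from the ImageNet values shown above.
from timm.models import mobilevit

pd.DataFrame.from_dict(mobilevit.default_cfgs['mobilevit_xxs'], orient='index')
```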
View available ConvNeXt models
pd.DataFrame(timm.list_models('convnext*', pretrained=True))
0 | |
---|---|
0 | convnext_base |
1 | convnext_base_384_in22ft1k |
2 | convnext_base_in22ft1k |
3 | convnext_base_in22k |
4 | convnext_large |
5 | convnext_large_384_in22ft1k |
6 | convnext_large_in22ft1k |
7 | convnext_large_in22k |
8 | convnext_nano |
9 | convnext_small |
10 | convnext_small_384_in22ft1k |
11 | convnext_small_in22ft1k |
12 | convnext_small_in22k |
13 | convnext_tiny |
14 | convnext_tiny_384_in22ft1k |
15 | convnext_tiny_hnf |
16 | convnext_tiny_in22ft1k |
17 | convnext_tiny_in22k |
18 | convnext_xlarge_384_in22ft1k |
19 | convnext_xlarge_in22ft1k |
20 | convnext_xlarge_in22k |
The convnext_nano model is highly accurate for its size and is a good choice when compute power is less constrained.
Inspect config for specific model
from timm.models import convnext

convnext_model = 'convnext_nano'
pd.DataFrame.from_dict(convnext.default_cfgs[convnext_model], orient='index')
0 | |
---|---|
url | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth |
num_classes | 1000 |
input_size | (3, 224, 224) |
pool_size | (7, 7) |
crop_pct | 0.95 |
interpolation | bicubic |
mean | (0.485, 0.456, 0.406) |
std | (0.229, 0.224, 0.225) |
first_conv | stem.0 |
classifier | head.fc |
test_input_size | (3, 288, 288) |
test_crop_pct | 1.0 |
Select a model
# model_type = resnet
# model_name = resnet_model
model_type = convnext
model_name = convnext_model
Extract normalization stats from model config
mean = model_type.default_cfgs[model_name]['mean']
std = model_type.default_cfgs[model_name]['std']
mean, std
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
Download the Dataset
The Kaggle API tool requires an API Key for a Kaggle account. Sign in or create a Kaggle account using the link below, then click the Create New API Token button.
- Kaggle Account Settings: https://www.kaggle.com/me/account
Kaggle will generate and download a kaggle.json file containing your username and new API token. Initialize the creds variable with the values for each.
Enter Kaggle username and API token
= '{"username":"","key":""}' creds
Save Kaggle credentials if none are present
cred_path = Path('~/.kaggle/kaggle.json').expanduser()
# Save API key to a json file if it does not already exist
if not cred_path.exists():
    cred_path.parent.mkdir(exist_ok=True)
    cred_path.write_text(creds)
    cred_path.chmod(0o600)
Import Kaggle API
from kaggle import api
Select a dataset
Now that we have our Kaggle credentials set, we need to define the dataset and where to store it. I made three different-sized versions of the dataset available on Kaggle.
- HaGRID Classification 512p no_gesture 150k
- Images: 154,816
- Size: 4 GB
- HaGRID Classification 512p no_gesture 300k
- Images: 309,632
- Size: 8 GB
- HaGRID Classification 512p no_gesture
- Images: 619,264
- Size: 15 GB
dataset_name = 'hagrid-classification-512p-no-gesture-150k'
# dataset_name = 'hagrid-classification-512p-no-gesture-300k'
# dataset_name = 'hagrid-classification-512p-no-gesture'
kaggle_dataset = f'innominate817/{dataset_name}'
Define path to dataset
We will use the default archive and data folders for the fastai library to store the compressed and uncompressed datasets.
archive_dir = URLs.path()
dataset_dir = archive_dir/'../data'
archive_path = Path(f'{archive_dir}/{dataset_name}.zip')
dataset_path = Path(f'{dataset_dir}/{dataset_name}')
Define method to extract the dataset from an archive file
def file_extract(fname, dest=None):
    "Extract `fname` to `dest` using `tarfile` or `zipfile`."
    if dest is None: dest = Path(fname).parent
    fname = str(fname)
    if fname.endswith('gz'): tarfile.open(fname, 'r:gz').extractall(dest)
    elif fname.endswith('zip'): zipfile.ZipFile(fname).extractall(dest)
    else: raise Exception(f'Unrecognized archive: {fname}')
Download the dataset if it is not present
if not archive_path.exists():
    api.dataset_download_cli(kaggle_dataset, path=archive_dir)
    file_extract(fname=archive_path, dest=dataset_path)
Inspect the Dataset
We can start inspecting the dataset once it finishes downloading.
Inspect the dataset path
pd.DataFrame(list(dataset_path.ls()))
0 | |
---|---|
0 | /home/innom-dt/.fastai/archive/../data/hagrid-classification-512p-no-gesture-150k/hagrid-classification-512p-no-gesture-150k |
Get image file paths
img_dir = dataset_path/dataset_name

files = get_image_files(img_dir)
len(files)
154816
Inspect files
The dataset indicates the image class in the parent folder names.
pd.DataFrame([files[0], files[-1]])
0 | |
---|---|
0 | /home/innom-dt/.fastai/archive/../data/hagrid-classification-512p-no-gesture-150k/hagrid-classification-512p-no-gesture-150k/stop/3ac51cf4-cd81-4803-a608-76a55b36df26.jpeg |
1 | /home/innom-dt/.fastai/archive/../data/hagrid-classification-512p-no-gesture-150k/hagrid-classification-512p-no-gesture-150k/two_up/d5a0a30d-92aa-4a7c-9621-1fed0e8f0b66.jpeg |
Inspect class folder names
There are 19 class folders, and the dataset does not predefine a training-validation split.
folder_names = [path.name for path in Path(img_dir).ls()]
if 'models' in folder_names:
    os.removedirs(img_dir/'models')
    folder_names.remove('models')
folder_names.sort()
print(f"Num classes: {len(folder_names)}")
pd.DataFrame(folder_names)
Num classes: 19
0 | |
---|---|
0 | call |
1 | dislike |
2 | fist |
3 | four |
4 | like |
5 | mute |
6 | no_gesture |
7 | ok |
8 | one |
9 | palm |
10 | peace |
11 | peace_inverted |
12 | rock |
13 | stop |
14 | stop_inverted |
15 | three |
16 | three2 |
17 | two_up |
18 | two_up_inverted |
Inspect one of the training images
import PIL

img = PIL.Image.open(files[0])
print(f"Class: {files[0].parent.name}")
print(f"Image Dims: {img.shape}")
img
Class: stop
Image Dims: (512, 512)
Define DataLoaders
Next, we need to define the Transforms for the DataLoaders object.
Define target input dimensions
The Unity project will take input from a webcam, which will likely not have a square aspect ratio. We can still train the models with a square aspect ratio, and training at 256x256 (65,536 pixels) is more efficient than training at 384x216 (82,944 pixels) for a 16:9 aspect ratio.
The ResNet and ConvNeXt models handle arbitrary input dimensions well. However, we must export some models like MobileViT with the exact input dimensions used for inference.
size_1_1 = (256, 256)
size_16_9 = (216, 384)
size = size_1_1
Define Transforms
We can leave most of the options in aug_transforms at their default values. The HaGRID dataset is diverse enough that we don’t need to add too much augmentation. However, we will disable the max_rotate option as orientation is relevant for gesture recognition.
item_tfms = [Resize(size, method=ResizeMethod.Pad, pad_mode=PadMode.Border)]

batch_tfms = [
    *aug_transforms(
        size=size,
        mult=1.0,
        do_flip=True,
        flip_vert=False,
        max_rotate=0.0,
        min_zoom=1.0,
        max_zoom=1.1,
        max_lighting=0.2,
        max_warp=0.2,
        p_affine=0.75,
        pad_mode=PadMode.Border)
]
Note: The fastai library automatically normalizes input for timm models as of version 2.7.5.
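If you are running an older fastai release, a minimal sketch of adding the normalization step yourself, using the mean and std values extracted earlier, could look like this (the manual_batch_tfms name is just for illustration):

```python
# Hedged sketch for fastai versions older than 2.7.5, where timm model stats
# are not applied automatically: append a Normalize step to the batch transforms.
manual_batch_tfms = batch_tfms + [Normalize.from_stats(mean, std)]
```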
Define batch size
bs = 32
Define DataLoaders object
We can use the from_folder method to instantiate the DataLoaders object.
dls = ImageDataLoaders.from_folder(
    path=img_dir,
    valid_pct=0.2,
    bs=bs,
    item_tfms=item_tfms,
    batch_tfms=batch_tfms
)
Verify DataLoaders object
Let’s verify the DataLoaders object works as expected before training a model.
dls.train.show_batch()
Finetune the Model
Now we can define the Learner object and finetune the selected model.
Define metrics
metrics = [error_rate, accuracy]
Define Learner object
learn = vision_learner(dls, model_name, metrics=metrics)
Define model checkpoint file path
= Path(f"{dataset_path.name}-{model_name}.pth")
checkpoint_path checkpoint_path
Path('hagrid-classification-512p-no-gesture-150k-convnext_nano.pth')
Load existing checkpoint (Optional)
# if checkpoint_path.exists():
#     print("Loading checkpoint...")
#     checkpoint = torch.load(checkpoint_path)
#     learn.model.load_state_dict(checkpoint)
Find learning rate
suggested_lrs = learn.lr_find()
Define learning rate
We can use a slightly higher learning rate than the learning rate finder recommends to speed up training.
lr = suggested_lrs.valley*3
lr
0.003606793354265392
Define number of epochs
We should not need to train for more than a few epochs.
epochs = 4
Define callbacks
Training with mixed precision can significantly reduce training time on modern GPUs. However, the older GPUs on the free tiers for Google Colab and Kaggle will likely not benefit from it.
cbs = [MixedPrecision()]
Fine tune model
learn.fine_tune(epochs, base_lr=lr, cbs=cbs)
epoch | train_loss | valid_loss | error_rate | accuracy | time |
---|---|---|---|---|---|
0 | 0.372442 | 0.189683 | 0.061299 | 0.938701 | 04:22 |
epoch | train_loss | valid_loss | error_rate | accuracy | time |
---|---|---|---|---|---|
0 | 0.096614 | 0.054749 | 0.017214 | 0.982786 | 05:29 |
1 | 0.048555 | 0.033598 | 0.010012 | 0.989988 | 05:34 |
2 | 0.030899 | 0.018264 | 0.005555 | 0.994445 | 05:32 |
3 | 0.018128 | 0.015447 | 0.004877 | 0.995123 | 05:30 |
Save model checkpoint
torch.save(learn.model.state_dict(), checkpoint_path)
Make predictions for a group of images
learn.show_results()
Define an Interpretation object
Once the model finishes training, we can create an Interpretation object to see where it struggles. An Interpretation object is also helpful to see if there are any mislabeled/low-quality training images.
interp = Interpretation.from_learner(learn)
Plot top losses
interp.plot_top_losses(k=9, figsize=(15,10))
Clean Dataset (Optional)
I spent some time cleaning the modified HaGRID datasets, but there are likely some training images that should still be moved or deleted. If the model accuracy is insufficient, consider using the ImageClassifierCleaner widget to edit the dataset.
Import fastai ImageClassifierCleaner
# from fastai.vision.widgets import ImageClassifierCleaner
Mark images to delete or move
# cleaner = ImageClassifierCleaner(learn)
# cleaner
Inspect samples to change
# cleaner.change()
Move selected samples to target class folder
# for idx,cat in cleaner.change(): shutil.move(str(cleaner.fns[idx]), img_dir/cat)
Inspect samples to delete
# cleaner.delete()
Delete selected samples
# for idx in cleaner.delete(): cleaner.fns[idx].unlink()
Note: Restart the notebook and finetune the trained model after cleaning the dataset.
Test the Model
Next, we will test the model on a single image.
Select a test image
import PIL
test_file = files[0]
test_file.parent.name, test_file.name
('stop', '3ac51cf4-cd81-4803-a608-76a55b36df26.jpeg')
test_img = PIL.Image.open(test_file)
print(f"Source image size: {test_img.size}")
test_img
Source image size: (512, 512)
Set target size
We will test the model with the target inference resolution to verify it performs as desired.
size = size_16_9
size
(216, 384)
Option 1: Pad to target input dims
padded_img = PIL.ImageOps.pad(test_img, [size[1], size[0]], method=PIL.Image.Resampling.BICUBIC)
print(f"Padded image size: {padded_img.size}")
padded_img
Padded image size: (384, 216)
Option 2: Stretch to target input dims
stretched_img = test_img.resize([size[1], size[0]])
print(f"Stretched image size: {stretched_img.size}")
stretched_img
Stretched image size: (384, 216)
Make a prediction on padded image using a fastai.vision.core.PILImage
pred = learn.predict(PILImage(padded_img))
pred, pred[2].max()
(('stop',
TensorBase(13),
TensorBase([5.6481e-08, 2.9167e-09, 1.7013e-08, 1.0619e-07, 7.0342e-09, 4.3362e-10,
1.9056e-08, 7.8481e-07, 1.9450e-08, 5.0064e-06, 3.8272e-09, 8.2519e-11,
4.3597e-08, 9.9999e-01, 2.3984e-08, 1.0935e-08, 2.4180e-09, 2.1497e-08,
5.9654e-10])),
TensorBase(1.0000))
Make a prediction on stretched image using a fastai.vision.core.PILImage
pred = learn.predict(PILImage(stretched_img))
pred, pred[2].max()
(('stop',
TensorBase(13),
TensorBase([1.3940e-06, 4.6373e-07, 1.3970e-04, 6.4621e-07, 6.8510e-08, 9.9468e-09,
1.1748e-07, 1.3881e-06, 1.1672e-06, 3.1076e-04, 9.0491e-09, 7.7261e-10,
8.4245e-08, 9.9954e-01, 1.7898e-07, 2.6569e-08, 3.4712e-08, 2.2750e-08,
6.5716e-09])),
TensorBase(0.9995))
When we are satisfied with the model, we can start preparing for implementing it in TensorFlow.js. We will need to apply some of the preprocessing and post-processing that fastai applies automatically.
Inspect the after_item pipeline
We do not need to worry about padding the input image as both the ResNet and ConvNeXt models handle arbitrary input dimensions.
learn.dls.after_item
Pipeline: Resize -- {'size': (256, 256), 'method': 'pad', 'pad_mode': 'border', 'resamples': (<Resampling.BILINEAR: 2>, <Resampling.NEAREST: 0>), 'p': 1.0} -> ToTensor
Inspect the after_batch pipeline
The after_batch pipeline first scales the image color channel values from [0,255] to [0,1]. We will need to do the same for the TensorFlow.js plugin. We will also need to normalize the input image with the relevant normalization stats.
learn.dls.after_batch
Pipeline: IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} -> Flip -- {'size': (256, 256), 'mode': 'bilinear', 'pad_mode': 'border', 'mode_mask': 'nearest', 'align_corners': True, 'p': 0.5} -> Brightness -- {'max_lighting': 0.2, 'p': 1.0, 'draw': None, 'batch': False} -> Normalize -- {'mean': tensor([[[[0.4850]],
[[0.4560]],
[[0.4060]]]], device='cuda:0'), 'std': tensor([[[[0.2290]],
[[0.2240]],
[[0.2250]]]], device='cuda:0'), 'axes': (0, 2, 3)}
Convert image to tensor
We’ll first prepare the input image by converting it to a tensor, batching it, and moving it to the GPU.
img_tensor = tensor(padded_img).permute(2, 0, 1).float().unsqueeze(dim=0).cuda()
img_tensor.shape, img_tensor
(torch.Size([1, 3, 216, 384]),
tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]], device='cuda:0'))
Scale tensor values
We’ll then scale the values from [0, 255] to [0, 1].
scaled_tensor = img_tensor / 255
Next, we’ll prepare the normalization values.
Prepare mean values
mean_tensor = tensor(mean).view(1,1,-1).permute(2, 0, 1).unsqueeze(dim=0).cuda()
mean_tensor.shape, mean_tensor
(torch.Size([1, 3, 1, 1]),
tensor([[[[0.4850]],
[[0.4560]],
[[0.4060]]]], device='cuda:0'))
Prepare std_dev values
std_tensor = tensor(std).view(1,1,-1).permute(2, 0, 1).unsqueeze(dim=0).cuda()
std_tensor.shape, std_tensor
(torch.Size([1, 3, 1, 1]),
tensor([[[[0.2290]],
[[0.2240]],
[[0.2250]]]], device='cuda:0'))
We can integrate the normalization preprocessing step and the Softmax postprocessing function into the exported model by creating a custom forward method. This way, we don’t need to store the normalization stats for each model in a separate file.
Create a backup of the default model forward function
We first need to create a backup of the current forward method.
origin_forward = learn.model.forward
Define custom forward function for exporting the model
The custom forward method will normalize the input tensor, feed it to the original forward method and pass the raw output through a Softmax function.
def forward_export(self, x):
    # Normalize input
    normalized_tensor = (x - mean_tensor) / std_tensor
    # Get raw model output
    preds = origin_forward(normalized_tensor)
    # Apply SoftMax function
    return torch.nn.functional.softmax(preds, dim=1)
Note: This custom forward method might also be a suitable spot to resize input images for models like MobileViT with fixed input dimensions.
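As a hedged sketch of that idea, a variant of the custom forward method for a fixed-resolution model might resize the input before normalizing it (the 256x256 target and the forward_export_fixed name below are assumptions for illustration, not settings used in this tutorial):

```python
# Hedged sketch: resize inside the custom forward method for models that only
# accept a fixed input resolution (the 256x256 target is an assumed example).
def forward_export_fixed(self, x):
    # Resize the input tensor to the fixed resolution
    x = torch.nn.functional.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
    # Normalize input
    normalized_tensor = (x - mean_tensor) / std_tensor
    # Get raw model output
    preds = origin_forward(normalized_tensor)
    # Apply SoftMax function
    return torch.nn.functional.softmax(preds, dim=1)
```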
Add custom forward function to model
We then add the custom forward method to the model using monkey patching.
learn.model.forward_export = forward_export.__get__(learn.model)
Test custom forward function
Now we can test the custom forward method to verify it returns the expected prediction.
with torch.no_grad():
    preds = learn.model.forward_export(scaled_tensor)
preds.cpu().argmax(), preds.cpu()
(TensorBase(13),
TensorBase([[1.8443e-08, 1.8163e-09, 1.3866e-08, 2.6368e-08, 1.0109e-09, 6.3904e-10,
2.6506e-09, 7.7717e-09, 3.7365e-10, 1.0260e-06, 9.1487e-11, 4.4600e-11,
4.3488e-10, 1.0000e+00, 1.8129e-08, 3.4815e-09, 3.7684e-10, 1.1454e-08,
1.1459e-10]]))
Get the class labels
learn.dls.vocab
['call', 'dislike', 'fist', 'four', 'like', 'mute', 'no_gesture', 'ok', 'one', 'palm', 'peace', 'peace_inverted', 'rock', 'stop', 'stop_inverted', 'three', 'three2', 'two_up', 'two_up_inverted']
Get the predicted class label
learn.dls.vocab[preds.cpu().argmax()]
'stop'
Replace model forward function with custom function
Lastly, we need to replace the current forward method with the custom one before exporting the model to ONNX.
learn.model.forward = learn.model.forward_export
Export the Model
Now we can begin the process of converting the PyTorch model to TensorFlow.js.
Define ONNX opset version
opset = 15
Define ONNX file name
= f"{dataset_path.name}-{learn.arch}-opset{opset}.onnx"
onnx_file_name print(onnx_file_name)
hagrid-classification-512p-no-gesture-150k-convnext_nano-opset15.onnx
Export trained model to ONNX
We will also unlock the input dimensions for the model to give ourselves more flexibility in Unity. This setting will not matter for models like MobileViT that require exact input dimensions.
torch.onnx.export(learn.model.cpu(),
                  # Use the scaled input tensor (moved to the CPU to match the model) as the example input
                  scaled_tensor.cpu(),
                  onnx_file_name,
                  export_params=True,
                  opset_version=opset,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {2: 'height', 3: 'width'}}
                  )
Import dependencies for TensorFlow.js conversion
import onnx
from scc4onnx import order_conversion
from onnxsim import simplify
from onnx_tf.backend import prepare
Load ONNX model
onnx_model = onnx.load(onnx_file_name)
Inspect model input
Looking at the input layer for the ONNX model, we can see that it is channels-first. However, the channels-last format is more straightforward for preparing model input in JavaScript. We can switch the model input to channels-last format using the scc4onnx package.
onnx_model.graph.input[0]
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_param: "height"
}
dim {
dim_param: "width"
}
}
}
}
Get input name
input_name = onnx_model.graph.input[0].name
input_name
'input'
Convert model input to channels-last format
onnx_model = order_conversion(
    onnx_graph=onnx_model,
    input_op_names_and_order_dims={f"{input_name}": [0,2,3,1]},
    non_verbose=True
)
Inspect updated model input
If we look at the input layer again, we can see it is now in channels-last format.
onnx_model.graph.input[0]
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_param: "height"
}
dim {
dim_param: "width"
}
dim {
dim_value: 3
}
}
}
}
Simplify ONNX model
The ONNX models generated by PyTorch are not always the most concise. We can use the onnx-simplifier package to tidy up the exported model. This step is entirely optional.
onnx_model, check = simplify(onnx_model)
check
True
Prepare ONNX model for Tensorflow Backend
Next, we need to convert the ONNX model to an internal representation of the computational graph.
tf_rep = prepare(onnx_model)
Define path for TensorFlow saved model directory
= f"./{onnx_file_name.split('.')[0]}"
tf_model_dir tf_model_dir
'./hagrid-classification-512p-no-gesture-150k-convnext_nano-opset15'
We can now export the internal representation to a Tensorflow proto file.
Export backend representation to a Tensorflow proto file
tf_rep.export_graph(tf_model_dir)
Define directory path to store tfjs model files
= f"{tf_model_dir}-tfjs-uint8"
tfjs_model_dir tfjs_model_dir
'./hagrid-classification-512p-no-gesture-150k-convnext_nano-opset15-tfjs-uint8'
Define arguments for tfjs converter script
The TensorFlow.js conversion script provides a few quantization options. These can significantly reduce the model file size. The file size matters since users download the models when loading the web demo. However, using the quantization options on small models like MobileNet can hurt accuracy.
from IPython.display import Markdown, display

tfjs_convert_command = f"""tensorflowjs_converter
                 --input_format=tf_saved_model
                 --output_format=tfjs_graph_model
                 --signature_name=serving_default
                 --saved_model_tags=serve
                 "{tf_model_dir}"
                 "{tfjs_model_dir}"
                 "--quantize_uint8"
                 """
tfjs_convert_command = " ".join(tfjs_convert_command.split())
display(Markdown(f"```bash\n{tfjs_convert_command}\n```"))
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve "./hagrid-classification-512p-no-gesture-150k-convnext_nano-opset15" "./hagrid-classification-512p-no-gesture-150k-convnext_nano-opset15-tfjs-uint8" "--quantize_uint8"
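If uint8 quantization hurts accuracy for a particular model, the converter also accepts 16-bit quantization flags. Below is a hedged variant of the command above that swaps only the quantization option (the -tfjs-float16 output folder name is an assumption, not part of the original notebook):

```python
# Hedged alternative: build a float16-quantized conversion command instead.
tfjs_convert_command_f16 = " ".join(f"""tensorflowjs_converter
                 --input_format=tf_saved_model
                 --output_format=tfjs_graph_model
                 --signature_name=serving_default
                 --saved_model_tags=serve
                 "{tf_model_dir}"
                 "{tf_model_dir}-tfjs-float16"
                 --quantize_float16
                 """.split())
```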
Export SavedModel to TFJS format
The conversion script will create a new folder containing a model.json file describing the model architecture and some BIN files storing the model weights.
print("Exporting TensorFlow SavedModel to TensorFlow.js Graph model...")
= %sx $tfjs_convert_command
conversion_result print("\n".join(conversion_result))
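As a quick sanity check (an addition, not part of the original notebook), we can list the files the converter produced. We expect a model.json file plus one or more .bin weight shards:

```python
# List the converted TensorFlow.js model files.
list(Path(tfjs_model_dir).iterdir())
```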
Export Inference Data
We can export the list of class labels to a JSON file and import it into the Unity project. That way, we don’t have to hardcode them, and we can easily swap in models trained on different datasets.
import json
Export class labels
labels = list(learn.dls.vocab)
labels
['call',
'dislike',
'fist',
'four',
'like',
'mute',
'no_gesture',
'ok',
'one',
'palm',
'peace',
'peace_inverted',
'rock',
'stop',
'stop_inverted',
'three',
'three2',
'two_up',
'two_up_inverted']
= {"classes": labels}
class_labels = f"{dataset_path.name}-classes.json"
class_labels_file_name
with open(class_labels_file_name, "w") as write_file:
json.dump(class_labels, write_file)
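As an optional check (again, an addition rather than part of the original notebook), we can read the file back to confirm the exported labels:

```python
# Read the exported class labels back in to verify the JSON file contents.
with open(class_labels_file_name) as read_file:
    print(json.load(read_file))
```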
Summary
In this post, we finetuned an image classifier from the timm library using fastai and exported it to TensorFlow.js. We started by installing and importing the necessary dependencies, then selected a model to use and downloaded a dataset to train it. After inspecting the dataset, we defined data loaders to use for training. Finally, we finetuned and exported the model. We also demonstrated how to clean the dataset to improve training. With this completed, we are ready to move on to Part 2 of this tutorial series, where we will create a TensorFlow.js plugin for the Unity game engine.
Next: In-Browser Hand Gesture Recognition for Unity with Fastai and TensorFlow.js Pt. 2
Project Resources: GitHub Repository
I’m Christian Mills, a deep learning consultant specializing in practical AI implementations. I help clients leverage cutting-edge AI technologies to solve real-world problems.
Interested in working together? Fill out my Quick AI Project Assessment form or learn more about me.