Barracuda PoseNet Tutorial 2nd Edition Pt. 6
Overview
In this post, we will cover how to implement the post-processing steps for multi-pose estimation. This method is more complex than what is required for single-pose estimation. However, it can produce more reliable results.
Note: The original JavaScript code for decoding multiple poses can be found in the official tfjs-models repository on GitHub. The code has been modified for this tutorial to better take advantage of functionality provided by Unity and .NET.
Update Utils Script
There are a couple of new variables and several methods that we will need to add to decode multiple poses from the model output.
Add Required Namespace
First, we need to add the System namespace to access the Tuple class. We also need to access the System.Linq namespace to access classes and interfaces for querying data structures.
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Barracuda;
using System;
using System.Linq;
Add Public Variables
When iterating through the heatmaps from the model output, we will only be considering heatmap indices with the highest confidence score within a local radius called kLocalMaximumRadius. Naturally, setting this value to be larger than the dimensions of the heatmap would not do any good. The original code sets this radius to a constant value of 1, so we will do the same.
When decoding the key point locations for a single body, we will need to traverse from the current key point to its neighboring key point. For example, the nose neighbors both the left eye and right eye. We will keep track of which key points neighbor each other in a Tuple<int, int> array, where the values are the key point id numbers.
/// <summary>
/// Defines the size of the local window in the heatmap to look for
/// confidence scores higher than the one at the current heatmap coordinate
/// </summary>
const int kLocalMaximumRadius = 1;
/// <summary>
/// Defines the parent->child relationships used for multipose detection.
/// </summary>
public static Tuple<int, int>[] parentChildrenTuples = new Tuple<int, int>[]{
    // Nose to Left Eye
    Tuple.Create(0, 1),
    // Left Eye to Left Ear
    Tuple.Create(1, 3),
    // Nose to Right Eye
    Tuple.Create(0, 2),
    // Right Eye to Right Ear
    Tuple.Create(2, 4),
    // Nose to Left Shoulder
    Tuple.Create(0, 5),
    // Left Shoulder to Left Elbow
    Tuple.Create(5, 7),
    // Left Elbow to Left Wrist
    Tuple.Create(7, 9),
    // Left Shoulder to Left Hip
    Tuple.Create(5, 11),
    // Left Hip to Left Knee
    Tuple.Create(11, 13),
    // Left Knee to Left Ankle
    Tuple.Create(13, 15),
    // Nose to Right Shoulder
    Tuple.Create(0, 6),
    // Right Shoulder to Right Elbow
    Tuple.Create(6, 8),
    // Right Elbow to Right Wrist
    Tuple.Create(8, 10),
    // Right Shoulder to Right Hip
    Tuple.Create(6, 12),
    // Right Hip to Right Knee
    Tuple.Create(12, 14),
    // Right Knee to Right Ankle
    Tuple.Create(14, 16)
};
Create GetStridedIndexNearPoint Method
In order to traverse from a key point to its neighboring key point, we will need to downscale the key point position back down to the heatmap resolution. We can calculate the nearest heatmap indices by dividing the position by the stride value for the model and clamping the result.
/// <summary>
/// Calculate the heatmap indices closest to the provided point
/// </summary>
/// <param name="point"></param>
/// <param name="stride"></param>
/// <param name="height"></param>
/// <param name="width"></param>
/// <returns>A vector with the nearest heatmap coordinates</returns>
static Vector2Int GetStridedIndexNearPoint(Vector2 point, int stride, int height, int width)
{
    // Downscale the point coordinates to the heatmap dimensions
    return new Vector2Int(
        (int)Mathf.Clamp(Mathf.Round(point.x / stride), 0, width - 1),
        (int)Mathf.Clamp(Mathf.Round(point.y / stride), 0, height - 1)
    );
}
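As a quick sanity check on the rounding and clamping, here is a minimal sketch in plain .NET with no Unity types. The class name StridedIndexSketch, the method name NearestHeatmapIndex, and the 17x17 heatmap with a stride of 16 are assumptions chosen for illustration, not part of the project code.

```csharp
using System;

public static class StridedIndexSketch
{
    // Hypothetical stand-in for GetStridedIndexNearPoint that uses plain numbers.
    public static (int x, int y) NearestHeatmapIndex(
        float pointX, float pointY, int stride, int height, int width)
    {
        // Divide by the stride, round to the nearest index, then clamp to the heatmap bounds
        int x = Math.Max(0, Math.Min((int)Math.Round(pointX / stride), width - 1));
        int y = Math.Max(0, Math.Min((int)Math.Round(pointY / stride), height - 1));
        return (x, y);
    }
}
```

With these assumed values, the point (100, 45) maps to heatmap indices (6, 3), while the out-of-range point (300, -5) is clamped to (16, 0).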
Create GetDisplacement Method
The displacement layers from the model output are used to find the location of the nearest neighboring key point. Much like the offset layer, they provide vectors that we then add to the current key point position.
/// <summary>
/// Retrieve the displacement values for the provided point
/// </summary>
/// <param name="edgeId"></param>
/// <param name="point"></param>
/// <param name="displacements"></param>
/// <returns>A vector with the displacement values for the provided point</returns>
static Vector2 GetDisplacement(int edgeId, Vector2Int point, Tensor displacements)
{
    // Calculate the number of edges for the pose skeleton
    int numEdges = (int)(displacements.channels / 2);
    // Get the displacement values for the provided heatmap coordinates
    return new Vector2(
        displacements[0, point.y, point.x, numEdges + edgeId],
        displacements[0, point.y, point.x, edgeId]
    );
}
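To make the channel layout concrete, the following sketch computes which channels hold the x and y components for a given edge. It assumes a displacement tensor with 32 channels, which would correspond to two values for each of the 16 edges in parentChildrenTuples. The class and method names are hypothetical.

```csharp
public static class DisplacementLayoutSketch
{
    // Returns the channel indices for the x and y displacement components of one edge.
    // Mirrors the indexing in GetDisplacement: y values occupy the first half of the
    // channels and x values occupy the second half.
    public static (int xChannel, int yChannel) ChannelsForEdge(int edgeId, int channels)
    {
        int numEdges = channels / 2;
        return (numEdges + edgeId, edgeId);
    }
}
```

Under that assumption, edge 5 (Left Shoulder to Left Elbow) would read its x component from channel 21 and its y component from channel 5.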
Create TraverseToTargetKeypoint Method
We can use the GetStridedIndexNearPoint and GetDisplacement methods to find the location of the neighboring key point for a given Keypoint.
Method Steps
- Get the nearest heatmap indices for the current key point position
- Get the displacement vector for the nearest heatmap indices
- Calculate the position for a neighboring key point using the displacement vector
- Get the nearest heatmap indices for the displaced point
- Refine the key point location with the associated offset vector
- Get the confidence score for the neighboring key point
- Return the neighboring Keypoint
/// <summary>
/// Get a new keypoint along the provided edgeId for the pose instance.
/// </summary>
/// <param name="edgeId"></param>
/// <param name="sourceKeypoint"></param>
/// <param name="targetKeypointId"></param>
/// <param name="scores"></param>
/// <param name="offsets"></param>
/// <param name="stride"></param>
/// <param name="displacements"></param>
/// <returns>A new keypoint with the displaced coordinates</returns>
static Keypoint TraverseToTargetKeypoint(
    int edgeId, Keypoint sourceKeypoint, int targetKeypointId,
    Tensor scores, Tensor offsets, int stride,
    Tensor displacements)
{
    // Get heatmap dimensions
    int height = scores.height;
    int width = scores.width;

    // Get nearest heatmap indices for source keypoint
    Vector2Int sourceKeypointIndices = GetStridedIndexNearPoint(
        sourceKeypoint.position, stride, height, width);
    // Retrieve the displacement values for the current indices
    Vector2 displacement = GetDisplacement(edgeId, sourceKeypointIndices, displacements);
    // Add the displacement values to the keypoint position
    Vector2 displacedPoint = sourceKeypoint.position + displacement;
    // Get nearest heatmap indices for displaced keypoint
    Vector2Int displacedPointIndices =
        GetStridedIndexNearPoint(displacedPoint, stride, height, width);
    // Get the offset vector for the displaced keypoint indices
    Vector2 offsetVector = GetOffsetVector(
        displacedPointIndices.y, displacedPointIndices.x, targetKeypointId,
        offsets);
    // Get the heatmap value at the displaced keypoint location
    float score = scores[0, displacedPointIndices.y, displacedPointIndices.x, targetKeypointId];
    // Calculate the position for the displaced keypoint
    Vector2 targetKeypoint = (displacedPointIndices * stride) + offsetVector;

    return new Keypoint(score, targetKeypoint, targetKeypointId);
}
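The arithmetic in this method can be followed with plain numbers. The sketch below is a simplified illustration, not the project code: the displacement and offset values are passed in directly instead of being read from tensors, and the names TraverseSketch and TargetPosition are hypothetical.

```csharp
using System;

public static class TraverseSketch
{
    // Nearest heatmap index along one axis, clamped to [0, size - 1].
    static int NearestIndex(float value, int stride, int size) =>
        Math.Max(0, Math.Min((int)Math.Round(value / stride), size - 1));

    // Computes the target key point position from a source position,
    // a displacement vector, and an offset vector.
    public static (float x, float y) TargetPosition(
        float sourceX, float sourceY,
        float displacementX, float displacementY,
        float offsetX, float offsetY,
        int stride, int height, int width)
    {
        // Add the displacement to the source key point position
        float displacedX = sourceX + displacementX;
        float displacedY = sourceY + displacementY;
        // Snap the displaced point to the nearest heatmap indices
        int indexX = NearestIndex(displacedX, stride, width);
        int indexY = NearestIndex(displacedY, stride, height);
        // Scale back up to image space and refine with the offset vector
        return (indexX * stride + offsetX, indexY * stride + offsetY);
    }
}
```

For example, with an assumed stride of 16 and a 17x17 heatmap, a source position of (120, 80) with displacement (30, -12) lands on heatmap indices (9, 4). An offset of (3.5, -2) then gives a target position of (147.5, 62).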
Create DecodePose Method
We don’t know which key point (e.g. nose, left shoulder, right wrist) we will start from when decoding a single pose. Therefore, we will need to traverse the list of neighboring key points both forwards and backwards to get all 17 of the key points for an individual in the input image.
Method Steps
- Initialize a new Keypoint array
- Get the input image coordinates for the starting key point
- Store the starting key point in the Keypoint array
- Iterate upwards through the list of neighboring key points
  - Confirm that the current child key point has already been found and that the parent key point has not already been found.
  - Call the TraverseToTargetKeypoint method to obtain neighboring key points
  - Store each neighboring key point in the Keypoint array according to its id number
- Iterate downwards through the list of neighboring key points
  - Confirm that the current parent key point has already been found and that the child key point has not already been found.
  - Call the TraverseToTargetKeypoint method to obtain neighboring key points
  - Store each neighboring key point in the Keypoint array according to its id number
- Return the Keypoint array
/// <summary>
/// Follows the displacement fields to decode the full pose of the object
/// instance given the position of a part that acts as root.
/// </summary>
/// <param name="root"></param>
/// <param name="scores"></param>
/// <param name="offsets"></param>
/// <param name="stride"></param>
/// <param name="displacementsFwd"></param>
/// <param name="displacementsBwd"></param>
/// <returns>An array of keypoints for a single pose</returns>
static Keypoint[] DecodePose(Keypoint root, Tensor scores, Tensor offsets,
    int stride, Tensor displacementsFwd, Tensor displacementsBwd)
{
    Keypoint[] instanceKeypoints = new Keypoint[scores.channels];

    // Start a new detection instance at the position of the root.
    Vector2 rootPoint = GetImageCoords(root, stride, offsets);

    instanceKeypoints[root.id] = new Keypoint(root.score, rootPoint, root.id);

    int numEdges = parentChildrenTuples.Length;

    // Decode the part positions upwards in the tree, following the backward
    // displacements.
    for (int edge = numEdges - 1; edge >= 0; --edge)
    {
        int sourceKeypointId = parentChildrenTuples[edge].Item2;
        int targetKeypointId = parentChildrenTuples[edge].Item1;
        if (instanceKeypoints[sourceKeypointId].score > 0.0f &&
            instanceKeypoints[targetKeypointId].score == 0.0f)
        {
            instanceKeypoints[targetKeypointId] = TraverseToTargetKeypoint(
                edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores,
                offsets, stride, displacementsBwd);
        }
    }

    // Decode the part positions downwards in the tree, following the forward
    // displacements.
    for (int edge = 0; edge < numEdges; ++edge)
    {
        int sourceKeypointId = parentChildrenTuples[edge].Item1;
        int targetKeypointId = parentChildrenTuples[edge].Item2;
        if (instanceKeypoints[sourceKeypointId].score > 0.0f &&
            instanceKeypoints[targetKeypointId].score == 0.0f)
        {
            instanceKeypoints[targetKeypointId] = TraverseToTargetKeypoint(
                edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores,
                offsets, stride, displacementsFwd);
        }
    }

    return instanceKeypoints;
}
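To see why one backward pass followed by one forward pass reaches every key point, the sketch below replays the two loops on the edge list alone. It is a simplified illustration: a bool array stands in for the score checks, no tensors are involved, and the names TraversalOrderSketch and VisitOrder are hypothetical.

```csharp
using System.Collections.Generic;

public static class TraversalOrderSketch
{
    // Same parent->child pairs as parentChildrenTuples.
    static readonly (int parent, int child)[] Edges = {
        (0, 1), (1, 3), (0, 2), (2, 4), (0, 5), (5, 7), (7, 9), (5, 11),
        (11, 13), (13, 15), (0, 6), (6, 8), (8, 10), (6, 12), (12, 14), (14, 16)
    };

    // Returns the key point ids in the order they would be reached from rootId.
    public static List<int> VisitOrder(int rootId)
    {
        var found = new bool[17];
        var order = new List<int> { rootId };
        found[rootId] = true;

        // Backward pass: move from a found child to its parent, last edge first
        for (int edge = Edges.Length - 1; edge >= 0; --edge)
        {
            var (parent, child) = Edges[edge];
            if (found[child] && !found[parent])
            {
                found[parent] = true;
                order.Add(parent);
            }
        }

        // Forward pass: move from a found parent to its child, first edge first
        for (int edge = 0; edge < Edges.Length; ++edge)
        {
            var (parent, child) = Edges[edge];
            if (found[parent] && !found[child])
            {
                found[child] = true;
                order.Add(child);
            }
        }
        return order;
    }
}
```

Starting from the left wrist (id 9), the backward pass climbs through the left elbow (7) and left shoulder (5) to the nose (0). The forward pass then fans out from the nose to the remaining 13 key points.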
Create ScoreIsMaximumInLocalWindow Method
As mentioned earlier, we only consider key points with the highest confidence score in their local area as potential starting key points. We will determine whether a given key point has the highest score in a new method called ScoreIsMaximumInLocalWindow.
Method Steps
- Calculate the starting and ending indices for the local heatmap window
- Iterate through the heatmap indices within the local window
- Compare each confidence score in the local window to the score for the provided key point
  - Return false if any higher scores are found
/// <summary>
/// Compare the value at the current heatmap location to the surrounding values
/// </summary>
/// <param name="keypointId"></param>
/// <param name="score"></param>
/// <param name="heatmapY"></param>
/// <param name="heatmapX"></param>
/// <param name="localMaximumRadius"></param>
/// <param name="heatmaps"></param>
/// <returns>True if the value is the highest within a given radius</returns>
static bool ScoreIsMaximumInLocalWindow(int keypointId, float score, int heatmapY, int heatmapX,
    int localMaximumRadius, Tensor heatmaps)
{
    bool localMaximum = true;
    // Calculate the starting heatmap row index
    int yStart = Mathf.Max(heatmapY - localMaximumRadius, 0);
    // Calculate the ending heatmap row index
    int yEnd = Mathf.Min(heatmapY + localMaximumRadius + 1, heatmaps.height);

    // Iterate through the calculated range of heatmap rows
    for (int yCurrent = yStart; yCurrent < yEnd; ++yCurrent)
    {
        // Calculate the starting heatmap column index
        int xStart = Mathf.Max(heatmapX - localMaximumRadius, 0);
        // Calculate the ending heatmap column index
        int xEnd = Mathf.Min(heatmapX + localMaximumRadius + 1, heatmaps.width);

        // Iterate through the calculated range of heatmap columns
        for (int xCurrent = xStart; xCurrent < xEnd; ++xCurrent)
        {
            // Check if any score within the specified radius
            // is higher than the provided score
            if (heatmaps[0, yCurrent, xCurrent, keypointId] > score)
            {
                localMaximum = false;
                break;
            }
        }
        if (!localMaximum) break;
    }
    return localMaximum;
}
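The window logic is easier to verify on a small grid. The following sketch applies the same comparison to a 2D float array in place of a Tensor. The names LocalMaximumSketch and IsLocalMaximum are hypothetical, and the grid values used in the note below are made up for illustration.

```csharp
using System;

public static class LocalMaximumSketch
{
    // Returns true if no value within the window around (y, x) exceeds heatmap[y, x].
    public static bool IsLocalMaximum(float[,] heatmap, int y, int x, int radius)
    {
        int height = heatmap.GetLength(0);
        int width = heatmap.GetLength(1);
        float score = heatmap[y, x];

        // Clip the window to the heatmap bounds
        int yStart = Math.Max(y - radius, 0);
        int yEnd = Math.Min(y + radius + 1, height);
        int xStart = Math.Max(x - radius, 0);
        int xEnd = Math.Min(x + radius + 1, width);

        for (int yCurrent = yStart; yCurrent < yEnd; ++yCurrent)
        {
            for (int xCurrent = xStart; xCurrent < xEnd; ++xCurrent)
            {
                // A strictly higher neighbor disqualifies the candidate
                if (heatmap[yCurrent, xCurrent] > score) return false;
            }
        }
        return true;
    }
}
```

Because the comparison is strict, two neighboring cells with equal scores both count as local maxima. For a 3x4 grid with 0.9 at row 1, column 1 and 0.6 at row 2, column 3, both cells pass with a radius of 1, while a radius of 2 lets the 0.9 suppress the 0.6.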
Create BuildPartList Method
This is where we will build the list of potential starting key points that will be passed to the DecodePose method.
Much like the DecodeSinglePose method, we need to iterate through the entire heatmap Tensor. This time, we will only consider heatmap indices with a value at or above the provided score threshold. When we get to an index with a value that meets this threshold, we will call the ScoreIsMaximumInLocalWindow method to confirm that it is the highest score in its local area. The heatmap indices with the highest local score will be added to a Keypoint List.
/// <summary>
/// Iterate through the heatmaps and create a list of indices
/// with the highest values within the provided radius.
/// </summary>
/// <param name="scoreThreshold"></param>
/// <param name="localMaximumRadius"></param>
/// <param name="heatmaps"></param>
/// <returns>A list of keypoints with the highest values in their local area</returns>
static List<Keypoint> BuildPartList(float scoreThreshold, int localMaximumRadius, Tensor heatmaps)
{
    List<Keypoint> list = new List<Keypoint>();

    // Iterate through heatmaps
    for (int c = 0; c < heatmaps.channels; c++)
    {
        // Iterate through heatmap rows
        for (int y = 0; y < heatmaps.height; y++)
        {
            // Iterate through heatmap columns
            for (int x = 0; x < heatmaps.width; x++)
            {
                float score = heatmaps[0, y, x, c];

                // Skip parts with score less than the scoreThreshold
                if (score < scoreThreshold) continue;

                // Only add keypoints with the highest score in a local window.
                if (ScoreIsMaximumInLocalWindow(c, score, y, x, localMaximumRadius, heatmaps))
                {
                    list.Add(new Keypoint(score, new Vector2(x, y), c));
                }
            }
        }
    }
    return list;
}
Create WithinNmsRadiusOfCorrespondingPoint Method
We want to make sure that any key points that have already been assigned to a body do not get used again. We can prevent this by only sending key points to the DecodePose method that are not too close to any key points in an existing Keypoint array.
/// <summary>
/// Check if the provided image coordinates are too close to any keypoints in existing poses
/// </summary>
/// <param name="poses"></param>
/// <param name="squaredNmsRadius"></param>
/// <param name="vec"></param>
/// <param name="keypointId"></param>
/// <returns>True if there are any existing poses too close to the provided coords</returns>
static bool WithinNmsRadiusOfCorrespondingPoint(
    List<Keypoint[]> poses, float squaredNmsRadius, Vector2 vec, int keypointId)
{
    // Compare squared distances to the squared radius
    return poses.Any(pose => (vec - pose[keypointId].position).sqrMagnitude <= squaredNmsRadius);
}
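The squared-distance comparison can be checked with a few hand-picked points. This sketch uses plain tuples in place of Keypoint arrays and only tracks one key point type. The names NmsRadiusSketch and WithinRadius are hypothetical.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class NmsRadiusSketch
{
    // existing: positions of the same key point id in poses that were already decoded.
    public static bool WithinRadius(
        IEnumerable<(float x, float y)> existing, float radius, float x, float y)
    {
        // Comparing squared values avoids taking a square root for every pose
        float squaredRadius = radius * radius;
        return existing.Any(p =>
        {
            float dx = x - p.x;
            float dy = y - p.y;
            return dx * dx + dy * dy <= squaredRadius;
        });
    }
}
```

With a radius of 20 and an existing key point at (100, 100), the point (112, 116) is exactly 20 pixels away and is rejected, while (113, 116) falls just outside the radius and is kept. An empty pose list never rejects a candidate.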
Create DecodeMultiplePoses Method
This is the method that will be called from the PoseEstimator script after executing the model. It will take in all four output Tensors from the model along with the stride value, the maximum number of poses to decode, a minimum confidence score threshold, and the radius for determining if a key point is too close to an existing pose.
Method Steps
- Initialize a new List of Keypoint arrays.
- Square the provided radius value
- Call the BuildPartList method to get the List of potential starting key points
- Sort the List in descending order based on the confidence scores for the key points
- Iterate through the List of starting key points
  - Create a copy of the key point with the highest score
  - Remove the key point from the List
  - Get the input image coordinates for the key point
  - Skip the key point if it is too close to an existing Keypoint array
  - Call the DecodePose method with the key point as the starting key point
  - Add the new Keypoint array to the List
- Return the List of Keypoint arrays as an array.
/// <summary>
/// Detects multiple poses and finds their parts from part scores and displacement vectors.
/// </summary>
/// <param name="heatmaps"></param>
/// <param name="offsets"></param>
/// <param name="displacementsFwd"></param>
/// <param name="displacementBwd"></param>
/// <param name="stride"></param>
/// <param name="maxPoseDetections"></param>
/// <param name="scoreThreshold"></param>
/// <param name="nmsRadius"></param>
/// <returns>An array of poses up to maxPoseDetections in size</returns>
public static Keypoint[][] DecodeMultiplePoses(
    Tensor heatmaps, Tensor offsets,
    Tensor displacementsFwd, Tensor displacementBwd,
    int stride, int maxPoseDetections,
    float scoreThreshold = 0.5f, int nmsRadius = 20)
{
    // Stores the final poses
    List<Keypoint[]> poses = new List<Keypoint[]>();
    // Square the radius so it can be compared against squared distances
    float squaredNmsRadius = (float)nmsRadius * nmsRadius;

    // Get a list of indices with the highest values within the provided radius.
    List<Keypoint> list = BuildPartList(scoreThreshold, kLocalMaximumRadius, heatmaps);
    // Order the list in descending order based on score
    list = list.OrderByDescending(x => x.score).ToList();

    // Decode poses until the max number of poses has been reached or the part list is empty
    while (poses.Count < maxPoseDetections && list.Count > 0)
    {
        // Get the part with the highest score in the list
        Keypoint root = list[0];
        // Remove the keypoint from the list
        list.RemoveAt(0);

        // Calculate the input image coordinates for the current part
        Vector2 rootImageCoords = GetImageCoords(root, stride, offsets);

        // Skip parts that are too close to existing poses
        if (WithinNmsRadiusOfCorrespondingPoint(
            poses, squaredNmsRadius, rootImageCoords, root.id))
        {
            continue;
        }

        // Find the keypoints in the same pose as the root part
        Keypoint[] keypoints = DecodePose(
            root, heatmaps, offsets, stride, displacementsFwd,
            displacementBwd);

        // Add the decoded pose to the list of poses
        poses.Add(keypoints);
    }

    return poses.ToArray();
}
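The outer loop is a greedy selection: take the highest-scoring candidate, reject it if it sits too close to something already accepted, and stop once enough poses have been collected. The sketch below isolates that pattern under simplifying assumptions: it handles a single key point type, works directly in image coordinates, and skips pose decoding entirely. The names GreedyRootSelectionSketch and SelectRoots are hypothetical.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class GreedyRootSelectionSketch
{
    public static List<(float score, float x, float y)> SelectRoots(
        IEnumerable<(float score, float x, float y)> candidates,
        int maxCount, float nmsRadius)
    {
        float squaredRadius = nmsRadius * nmsRadius;
        var selected = new List<(float score, float x, float y)>();

        // Visit candidates from the highest score to the lowest
        foreach (var candidate in candidates.OrderByDescending(c => c.score))
        {
            // Stop once the maximum number of roots has been selected
            if (selected.Count >= maxCount) break;

            // Reject candidates that fall within the radius of a selected root
            bool tooClose = selected.Any(s =>
            {
                float dx = candidate.x - s.x;
                float dy = candidate.y - s.y;
                return dx * dx + dy * dy <= squaredRadius;
            });
            if (!tooClose) selected.Add(candidate);
        }
        return selected;
    }
}
```

Given candidates scored 0.9 at (100, 100), 0.8 at (105, 102), 0.7 at (300, 120), and 0.6 at (500, 90), a radius of 20 and a limit of two keeps the 0.9 and 0.7 candidates. The 0.8 candidate is suppressed by its stronger neighbor, and the 0.6 candidate is never reached.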
Update PoseEstimator Script
Now we can complete the ProcessOutput method in the PoseEstimator script.
Add Public Variables
First, we will add some public variables so that we can adjust the max number of poses to decode, score threshold, and radius for determining whether a key point is too close to an existing pose from the Inspector tab.
[Tooltip("The maximum number of poses to estimate")]
[Range(1, 20)]
public int maxPoses = 20;
[Tooltip("The score threshold for multipose estimation")]
[Range(0, 1.0f)]
public float scoreThreshold = 0.25f;
[Tooltip("Non-maximum suppression part distance")]
public int nmsRadius = 100;
Modify ProcessOutput Method
We will assign the output from the DecodeMultiplePoses method to the poses variable.
// Determine the key point locations
poses = Utils.DecodeMultiplePoses(
    heatmaps, offsets,
    displacementFWD, displacementBWD,
    stride: stride, maxPoseDetections: maxPoses,
    scoreThreshold: scoreThreshold,
    nmsRadius: nmsRadius);
Full Code
/// <summary>
/// Obtains the model output and decodes either a single pose or multiple poses
/// </summary>
/// <param name="engine"></param>
private void ProcessOutput(IWorker engine)
{
    // Get the model output
    Tensor heatmaps = engine.PeekOutput(predictionLayer);
    Tensor offsets = engine.PeekOutput(offsetsLayer);
    Tensor displacementFWD = engine.PeekOutput(displacementFWDLayer);
    Tensor displacementBWD = engine.PeekOutput(displacementBWDLayer);

    // Calculate the stride used to scale down the inputImage
    int stride = (imageDims.y - 1) / (heatmaps.shape.height - 1);
    stride -= (stride % 8);

    if (estimationType == EstimationType.SinglePose)
    {
        poses = new Utils.Keypoint[1][];

        // Determine the key point locations
        poses[0] = Utils.DecodeSinglePose(heatmaps, offsets, stride);
    }
    else
    {
        // Determine the key point locations
        poses = Utils.DecodeMultiplePoses(
            heatmaps, offsets,
            displacementFWD, displacementBWD,
            stride: stride, maxPoseDetections: maxPoses,
            scoreThreshold: scoreThreshold,
            nmsRadius: nmsRadius);
    }

    heatmaps.Dispose();
    offsets.Dispose();
    displacementFWD.Dispose();
    displacementBWD.Dispose();
}
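The stride calculation in ProcessOutput relies on integer division followed by rounding down to a multiple of 8. The sketch below isolates that arithmetic so it can be tried with different dimensions. The names StrideSketch and ComputeStride are hypothetical, and the dimensions in the note are illustrative values rather than measurements from the model.

```csharp
public static class StrideSketch
{
    // Mirrors the two stride lines in ProcessOutput using plain integers.
    public static int ComputeStride(int imageHeight, int heatmapHeight)
    {
        // Integer division truncates any fractional part
        int stride = (imageHeight - 1) / (heatmapHeight - 1);
        // Round down to the nearest multiple of 8
        stride -= stride % 8;
        return stride;
    }
}
```

For an assumed image height of 257 and heatmap height of 17, the division gives exactly 16. For 256 and 16, the division gives 17, and subtracting the remainder brings it back down to 16.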
Summary
We now have everything needed to perform pose estimation. However, we cannot currently gauge the accuracy of the estimated poses. In the next post, we will demonstrate how to add pose skeletons so that we can compare the estimated key point locations to the source video feed.
Previous: Part 5
Next: Part 7
Project Resources: GitHub Repository
- I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.