Notes on fastai Book Ch. 18
Tags: ai, fastai, notes, pytorch
Chapter 18 covers interpreting convolutional neural networks using class-activation maps.
This post is part of my series of notes on the fastai book.
#hide
# !pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
#hide
from fastbook import *
import inspect
def print_source(obj):
    for line in inspect.getsource(obj).split("\n"):
        print(line)
CNN Interpretation with CAM
CAM and Hooks
Class Activation Map (CAM)
- Learning Deep Features for Discriminative Localization
- uses the output of the last convolutional layer together with the predictions to generate a heatmap visualization of why the model made its decision
- at each position of our final convolutional layer, we have as many filters as in the last linear layer
- we can compute the dot product of the activations with the final weights to get, for each location in the feature map, the score of the feature that was used to make the decision (see the sketch after this list)
- need a way to get access to the activations inside the model while it is training
- can use PyTorch hooks
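A minimal sketch of that dot-product step with toy tensors; the shapes mirror the real example later in this post, and the variable names here are made up for illustration:

```python
import torch

# Toy stand-ins: 512-channel 7x7 activations from the last conv layer,
# and a (2, 512) weight matrix from a 2-class linear head.
act = torch.randn(512, 7, 7)
weight = torch.randn(2, 512)

# For every spatial location (i, j), take the dot product of the
# 512-dim activation vector with each class's weight vector.
cam = torch.einsum('ck,kij->cij', weight, act)
print(cam.shape)  # torch.Size([2, 7, 7]) -- one 7x7 heatmap per class
```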
PyTorch Hooks
- PyTorch’s equivalent of fastai’s callbacks
- allow you to inject code into the forward and backward calculations themselves
- can attach a hook to any layer of the model, and it will be executed when we compute the outputs (forward hook) or during backpropagation (backward hook)
- forward hook:
- a function that takes a module, its input, and its output
- can perform any behavior you want (a minimal sketch follows this list)
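A minimal sketch of registering a forward hook on a plain PyTorch module; the toy model and hook function are made up for illustration:

```python
import torch
import torch.nn as nn

# Toy model: any nn.Module works the same way.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                      nn.ReLU(),
                      nn.Conv2d(8, 4, 3, padding=1))

stored = {}

def save_output(module, inputs, output):
    # A forward hook receives the module, its inputs, and its output.
    stored['out'] = output.detach().clone()

# Register the hook on the last convolution; it fires on every forward pass.
handle = model[2].register_forward_hook(save_output)

_ = model(torch.randn(1, 3, 16, 16))
print(stored['out'].shape)  # torch.Size([1, 4, 16, 16])

# Always remove hooks when done to avoid leaks.
handle.remove()
```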
HookCallback
fastai.callback.hook.HookCallback
print_source(HookCallback)
@funcs_kwargs
class HookCallback(Callback):
    "`Callback` that can be used to register hooks on `modules`"
    _methods = ["hook"]
    hook = noops
    def __init__(self, modules=None, every=None, remove_end=True, is_forward=True, detach=True, cpu=True, include_paramless=False, **kwargs):
        store_attr('modules,every,remove_end,is_forward,detach,cpu,include_paramless')
        assert not kwargs

    def before_fit(self):
        "Register the `Hooks` on `self.modules`."
        if self.modules is None: self.modules = [m for m in flatten_model(self.model) if self.include_paramless or has_params(m)]
        if self.every is None: self._register()

    def before_batch(self):
        if self.every is None: return
        if self.training and self.train_iter%self.every==0: self._register()

    def after_batch(self):
        if self.every is None: return
        if self.training and self.train_iter%self.every==0: self._remove()

    def after_fit(self):
        "Remove the `Hooks`."
        if self.remove_end: self._remove()

    def _register(self): self.hooks = Hooks(self.modules, self.hook, self.is_forward, self.detach, self.cpu)

    def _remove(self):
        if getattr(self, 'hooks', None): self.hooks.remove()

    def __del__(self): self._remove()
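A minimal sketch of subclassing `HookCallback`, following the same pattern `ActivationStats` (shown at the end of this post) uses: whatever `hook` returns is collected in `self.hooks.stored` after each forward pass. The callback name and the statistic it records are made up for illustration:

```python
class ActivationMeans(HookCallback):
    "Record the mean activation of each hooked module after every training batch."
    def before_fit(self):
        super().before_fit()
        self.means = L()

    # Called for each hooked module's forward pass; the return value
    # ends up in self.hooks.stored.
    def hook(self, m, i, o):
        return o.float().mean().item()

    def after_batch(self):
        if self.training: self.means.append(self.hooks.stored)
        super().after_batch()

# Hypothetical usage: pass it as a callback when creating a Learner, e.g.
# learn = cnn_learner(dls, resnet34, metrics=error_rate, cbs=ActivationMeans())
```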
path = untar_data(URLs.PETS)/'images'
path
Path('/home/innom-dt/.fastai/data/oxford-iiit-pet/images')
def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=21,
    label_func=is_cat, item_tfms=Resize(224))
# Train a cat classifier with the default settings
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
| epoch | train_loss | valid_loss | error_rate | time |
| --- | --- | --- | --- | --- |
| 0 | 0.127009 | 0.042651 | 0.011502 | 00:15 |

| epoch | train_loss | valid_loss | error_rate | time |
| --- | --- | --- | --- | --- |
| 0 | 0.045776 | 0.023368 | 0.007442 | 00:19 |
print_source(image_cat)
def image_cat (): return BytesIO(pkgutil.get_data('fastbook', 'images/cat.jpg'))
# Load an image of a cat
img = PILImage.create(image_cat())
# Get a batch of data from the test set
x, = first(dls.test_dl([img]))
# Define a hook that stores a copy of the output
class Hook():
    def hook_func(self, m, i, o): self.stored = o.detach().clone()

# Instantiate a hook
hook_output = Hook()
# Attach the hook to the convolutional body so it stores the output of the last convolutional layer
hook = learn.model[0].register_forward_hook(hook_output.hook_func)
# Perform inference on the test batch
with torch.no_grad(): output = learn.model.eval()(x)
# Get the copy of the model activations from the hook
act = hook_output.stored[0]
act.shape
torch.Size([512, 7, 7])
# Check the model predictions
F.softmax(output, dim=-1)
TensorBase([[4.0876e-09, 1.0000e+00]], device='cuda:0')
Note: The model is confident the image contains a cat.
dls.vocab
[False, True]
x.shape
torch.Size([1, 3, 224, 224])
learn.model[1]
Sequential(
(0): AdaptiveConcatPool2d(
(ap): AdaptiveAvgPool2d(output_size=1)
(mp): AdaptiveMaxPool2d(output_size=1)
)
(1): Flatten(full=False)
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Dropout(p=0.25, inplace=False)
(4): Linear(in_features=1024, out_features=512, bias=False)
(5): ReLU(inplace=True)
(6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
(8): Linear(in_features=512, out_features=2, bias=False)
)
learn.model[1][-1]
Linear(in_features=512, out_features=2, bias=False)
# Calculate the dot product of the weight matrix for the last layer with the activations
cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)
cam_map.shape
torch.Size([2, 7, 7])
Note: The result of the dot product is a \(7 \times 7\) feature map for each class, showing where the activations were higher and lower.
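As a sanity check on the `einsum` call, the same result can be computed with an ordinary matrix multiplication over the flattened spatial dimensions; a hedged sketch, assuming `act` still has shape `[512, 7, 7]`:

```python
# Multiply the (2, 512) class-weight matrix by the (512, 49) flattened
# activations, then restore the 7x7 spatial layout.
weight = learn.model[1][-1].weight                   # shape: (2, 512)
cam_map_alt = (weight @ act.view(512, -1)).view(2, 7, 7)

# Should match the einsum result up to floating-point noise.
print(torch.allclose(cam_map, cam_map_alt))
```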
# Get the first image from the test batch
x_dec = TensorImage(dls.train.decode((x,))[0][0])
_,ax = plt.subplots()
x_dec.show(ctx=ax)
# Overlay the feature map for the cat class on top of the cat image
ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,224,224,0),
          interpolation='bilinear', cmap='magma');
Note: The bright yellow spots correspond to high activations. The head and front paw seem to have had the most influence on the model's prediction.
# Remove the hook to avoid memory leaks
hook.remove()
print_source(hook.remove)
def remove(self) -> None:
hooks_dict = self.hooks_dict_ref()
if hooks_dict is not None and self.id in hooks_dict:
del hooks_dict[self.id]
# Update the custom hook to be a context manager
class Hook():
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
    # Automatically register the hook when entering the `with` block
    def __enter__(self, *args): return self
    # Automatically remove the hook when exiting the `with` block
    def __exit__(self, *args): self.hook.remove()

# Register the hook on the body of the model using the context manager
with Hook(learn.model[0]) as hook:
    with torch.no_grad(): output = learn.model.eval()(x.cuda())
    act = hook.stored
Note: This method only works for the last layer, since it relies on multiplying the activations by the final weight matrix.
Gradient CAM
- Grad-CAM: Why did you say that?
- uses the gradients of the final activation for the desired class (see the formula after this list)
- can use Grad-CAM on any layer
- recall that the gradients of the output of the last layer with respect to the input of that layer are equal to the layer weights, since it is a linear layer
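For reference, the Grad-CAM paper computes the class-\(c\) map as a weighted sum of the chosen layer's feature maps \(A^k\), where each weight is the spatial average of the gradients of the class score \(y^c\) (the paper also applies a ReLU to the sum, which the code below omits):

\[
\alpha_k^c = \frac{1}{Z}\sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L_{\mathrm{Grad\text{-}CAM}}^c = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
\]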
# Define a hook that stores a copy of the gradients calculated by PyTorch during the backward pass
class HookBwd():
    def __init__(self, m):
        self.hook = m.register_backward_hook(self.hook_func)
    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()
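Note: on recent PyTorch versions, `register_backward_hook` is deprecated in favor of `register_full_backward_hook`, which takes a hook with the same `(module, grad_input, grad_output)` signature. A hedged variant of the class above using the newer API:

```python
class HookBwd():
    def __init__(self, m):
        # register_full_backward_hook replaces the deprecated register_backward_hook
        self.hook = m.register_full_backward_hook(self.hook_func)
    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()
```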
# The class index for the cat class
cls = 1
with HookBwd(learn.model[0]) as hookg:
    with Hook(learn.model[0]) as hook:
        output = learn.model.eval()(x.cuda())
        act = hook.stored
    # Get the gradients for the cat class for the first image in the test set
    output[0,cls].backward()
    grad = hookg.stored
# Calculate the average of the gradients across the feature map
w = grad[0].mean(dim=[1,2], keepdim=True)
cam_map = (w * act[0]).sum(0)
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
          interpolation='bilinear', cmap='magma');
# Get the gradients for the output of the second to last ResNet group
with HookBwd(learn.model[0][-2]) as hookg:
    with Hook(learn.model[0][-2]) as hook:
        output = learn.model.eval()(x.cuda())
        act = hook.stored
    output[0,cls].backward()
    grad = hookg.stored
w = grad[0].mean(dim=[1,2], keepdim=True)
cam_map = (w * act[0]).sum(0)
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
          interpolation='bilinear', cmap='magma');
print_source(ActivationStats)
@delegates()
class ActivationStats(HookCallback):
    "Callback that record the mean and std of activations."
    order=-20
    def __init__(self, with_hist=False, **kwargs):
        super().__init__(**kwargs)
        self.with_hist = with_hist

    def before_fit(self):
        "Initialize stats."
        super().before_fit()
        self.stats = L()

    def hook(self, m, i, o):
        if isinstance(o, tuple): return self.hook_multi_ouput(o)
        o = o.float()
        res = {'mean': o.mean().item(), 'std': o.std().item(),
               'near_zero': (o<=0.05).long().sum().item()/o.numel()}
        if self.with_hist: res['hist'] = o.histc(40,0,10)
        return res

    def hook_multi_ouput(self,o_tuple):
        "For outputs of RNN which are [nested] tuples of tensors"
        res = []
        for o in self._flatten_tuple(o_tuple):
            if not(isinstance(o, Tensor)): continue
            res.append(self.hook(None, None, o))
        return res

    def _flatten_tuple(self, o_tuple):
        "Recursively flatten a [nested] tuple"
        res = []
        for it in o_tuple:
            if isinstance(it, tuple): res += self._flatten_tuple(it)
            else: res += [it]
        return tuple(res)

    def after_batch(self):
        "Take the stored results and puts it in `self.stats`"
        if self.training and (self.every is None or self.train_iter%self.every == 0):
            self.stats.append(self.hooks.stored)
        super().after_batch()

    def layer_stats(self, idx):
        lstats = self.stats.itemgot(idx)
        return L(lstats.itemgot(o) for o in ('mean','std','near_zero'))

    def hist(self, idx):
        res = self.stats.itemgot(idx).itemgot('hist')
        return torch.stack(tuple(res)).t().float().log1p()

    def color_dim(self, idx, figsize=(10,5), ax=None):
        "The 'colorful dimension' plot"
        res = self.hist(idx)
        if ax is None: ax = subplots(figsize=figsize)[1][0]
        ax.imshow(res, origin='lower')
        ax.axis('off')

    def plot_layer_stats(self, idx):
        _,axs = subplots(1, 3, figsize=(12,3))
        for o,ax,title in zip(self.layer_stats(idx),axs,('mean','std','% near zero')):
            ax.plot(o)
            ax.set_title(title)
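`ActivationStats` is built on `HookCallback`, so it attaches to a `Learner` like any other callback. A hedged usage sketch, adapting how earlier chapters of the book use it; the index passed to the plotting methods selects which hooked module to inspect:

```python
# Record activation statistics (with histograms) while training.
learn = cnn_learner(dls, resnet34, metrics=error_rate,
                    cbs=ActivationStats(with_hist=True))
learn.fit_one_cycle(1)

# Mean/std/percent-near-zero curves for the last hooked module.
learn.activation_stats.plot_layer_stats(-1)

# The "colorful dimension" histogram plot for the same module.
learn.activation_stats.color_dim(-1)
```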
References

- Howard, Jeremy, and Sylvain Gugger. Deep Learning for Coders with fastai & PyTorch. O'Reilly Media, 2020.
Previous: Notes on fastai Book Ch. 17
Next: Notes on fastai Book Ch. 19
About Me:
- I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.
- I help clients leverage cutting-edge AI technologies to solve real-world problems.
- Learn more about me or reach out via email at [email protected] to discuss your project.