Notes on fastai Book Ch. 12
- The Data
- Our First Language Model from Scratch
- Improving the RNN
- Multilayer RNNs
- LSTM
- Regularizing an LSTM
- References
#hide
# !pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
from fastbook import *
import inspect
def print_source(obj):
    for line in inspect.getsource(obj).split("\n"):
        print(line)
A Language Model from Scratch
The Data
- try to think of the simplest usable dataset when starting on a new problem
- the starter dataset should allow you to quickly and easily try out methods and interpret the results
- one of the most common practical mistakes is failing to use appropriate datasets at appropriate times during the analysis process
- most people tend to start with datasets that are too big and too complicated
from fastai.text.all import *
fastai Human Numbers Dataset
- A synthetic dataset consisting of numbers written out in English text, such as one, two, three, four, ...
- Useful for experimenting with Language Models
URLs.HUMAN_NUMBERS
'https://s3.amazonaws.com/fast-ai-sample/human_numbers.tgz'
path = untar_data(URLs.HUMAN_NUMBERS)
path
Path('/home/innom-dt/.fastai/data/human_numbers')
path.ls()
(#2) [Path('/home/innom-dt/.fastai/data/human_numbers/train.txt'),Path('/home/innom-dt/.fastai/data/human_numbers/valid.txt')]
train_file = path.ls()[0]
!cat $train_file | head -5
one
two
three
four
five
valid_file = path.ls()[1]
!cat $valid_file | head -5
eight thousand one
eight thousand two
eight thousand three
eight thousand four
eight thousand five
# Combine the training and validation sets into a single list
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())
lines
(#9998) ['one \n','two \n','three \n','four \n','five \n','six \n','seven \n','eight \n','nine \n','ten \n'...]
# Remove the '\n' newline characters and separate the words with a ' . '
text = ' . '.join([l.strip() for l in lines])
text[:100]
'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'
# Separate the words into a list
tokens = text.split(' ')
tokens[:10]
['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']
# Generate unique vocab
vocab = L(*tokens).unique()
vocab
(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]
pd.DataFrame(list(vocab))
 | 0 |
---|---|
0 | one |
1 | . |
2 | two |
3 | three |
4 | four |
5 | five |
6 | six |
7 | seven |
8 | eight |
9 | nine |
10 | ten |
11 | eleven |
12 | twelve |
13 | thirteen |
14 | fourteen |
15 | fifteen |
16 | sixteen |
17 | seventeen |
18 | eighteen |
19 | nineteen |
20 | twenty |
21 | thirty |
22 | forty |
23 | fifty |
24 | sixty |
25 | seventy |
26 | eighty |
27 | ninety |
28 | hundred |
29 | thousand |
# Map words to their vocab indices
word2idx = {w:i for i,w in enumerate(vocab)}
# Numericalize the dataset
nums = L(word2idx[i] for i in tokens)
nums
(#63095) [0,1,2,1,3,1,4,1,5,1...]
Our First Language Model from Scratch
# Create a list of (input, target) tuples
# input: the previous three words
# target: the next word
L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))
(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]
# Create a list of (input, target) tuples
# input: a tensor containing the numericalized forms of previous three words
# target: the numericalized form of the next word
seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))
seqs
(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]
DataLoaders.from_dsets
<bound method DataLoaders.from_dsets of <class 'fastai.data.core.DataLoaders'>>
print_source(DataLoaders.from_dsets)
@classmethod
def from_dsets(cls, *ds, path='.',  bs=64, device=None, dl_type=TfmdDL, **kwargs):
    default = (True,) + (False,) * (len(ds)-1)
    defaults = {'shuffle': default, 'drop_last': default}
    tfms = {k:tuple(Pipeline(kwargs[k]) for i in range_of(ds)) for k in _batch_tfms if k in kwargs}
    kwargs = merge(defaults, {k: tuplify(v, match=ds) for k,v in kwargs.items() if k not in _batch_tfms}, tfms)
    kwargs = [{k: v[i] for k,v in kwargs.items()} for i in range_of(ds)]
    return cls(*[dl_type(d, bs=bs, **k) for d,k in zip(ds, kwargs)], path=path, device=device)
bs = 64
# Split data between train and valid 80/20
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)
dls.one_batch()[0].shape, dls.one_batch()[1].shape
(torch.Size([64, 3]), torch.Size([64]))
dls.one_batch()[0][0], dls.one_batch()[1][0]
(tensor([0, 1, 2]), tensor(1))
Our Language Model in PyTorch
- Every word is interpreted in the context of the words preceding it
class LMModel1(Module):
    def __init__(self, vocab_sz, n_hidden):
        # Input to hidden
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # Hidden to hidden
        self.h_h = nn.Linear(n_hidden, n_hidden)
        # Hidden to output
        self.h_o = nn.Linear(n_hidden, vocab_sz)

    def forward(self, x):
        # First input word
        # Pass embedding for first word to first linear layer
        h = F.relu(self.h_h(self.i_h(x[:,0])))
        # Second input word
        # Add embedding for second word to previous output
        h = h + self.i_h(x[:,1])
        # Pass to first linear layer
        h = F.relu(self.h_h(h))
        # Third input word
        # Add embedding for third word to previous output
        h = h + self.i_h(x[:,2])
        # Pass to first linear layer
        h = F.relu(self.h_h(h))
        # Pass output to second linear layer
        return self.h_o(h)
learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy,
                metrics=accuracy)
learn.fit_one_cycle(4, 1e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.813438 | 1.944979 | 0.466603 | 00:01 |
1 | 1.405106 | 1.702907 | 0.474447 | 00:01 |
2 | 1.427549 | 1.650981 | 0.489898 | 00:00 |
3 | 1.380016 | 1.685956 | 0.470882 | 00:01 |
range_of
<function fastcore.basics.range_of(a, b=None, step=None)>
print_source(range_of)
def range_of(a, b=None, step=None):
    "All indices of collection `a`, if `a` is a collection, otherwise `range`"
    if is_coll(a): a = len(a)
    return list(range(a,b,step) if step is not None else range(a,b) if b is not None else range(a))
# Get the number of occurrences of each unique vocab item in the validation set
n,counts = 0,torch.zeros(len(vocab))
for x,y in dls.valid:
    n += y.shape[0]
    # Keep track of how often each vocab item appears as a target
    for i in range_of(vocab): counts[i] += (y==i).long().sum()
# Get the index for the most common token in the validation set
idx = torch.argmax(counts)
# Print the most common index, the corresponding word,
# and the likelihood of always picking the most common word
(idx, vocab[idx.item()], counts[idx].item()/n)
(tensor(29), 'thousand', 0.15165200855716662)
Note: Always predicting the most common token ('thousand') would only give ~15% accuracy, so the model's ~47% accuracy is much better than this baseline.
Our First Recurrent Neural Network (a.k.a A Looping Network)
- replace the hardcoded forward method in LMModel1 with a for loop, as sketched below
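These notes don't show the refactored model, so here is a sketch of what that loop looks like, reusing the same layers as LMModel1 (the book calls this version LMModel2); treat it as a reconstruction rather than a verbatim copy:
class LMModel2(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  # input to hidden
        self.h_h = nn.Linear(n_hidden, n_hidden)     # hidden to hidden
        self.h_o = nn.Linear(n_hidden, vocab_sz)     # hidden to output

    def forward(self, x):
        h = 0
        # The same three steps as LMModel1, expressed as a loop
        for i in range(3):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
        return self.h_o(h)
It can be trained exactly like LMModel1 and should reach roughly the same accuracy.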
Improving the RNN
- the above LMModel2 version resets the hidden state for every new input sequence
- throwing away all the information we have about the sentences we have seen so far
- the above LMModel2 version only tries to predict the fourth word
Maintaining the State of an RNN
class LMModel3(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.h_h = nn.Linear(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # Maintain the same hidden state across input sequences
        self.h = 0

    def forward(self, x):
        for i in range(3):
            self.h = self.h + self.i_h(x[:,i])
            self.h = F.relu(self.h_h(self.h))
        out = self.h_o(self.h)
        # Detach the hidden state from the PyTorch computation graph
        self.h = self.h.detach()
        return out

    def reset(self): self.h = 0
Backpropagation Through Time (BPTT)
- Treating a neural net with effectively one layer per time step (usually refactored using a loop) as one big model, and calculating gradients on it in the usual way
- usually use Truncated BPTT which detaches the history of computation steps in the hidden state every few time steps.
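A toy illustration (mine, not from the book) of how detach implements the truncation: once the hidden value is detached, gradients only flow back through the time steps after the cut.
import torch

w = torch.tensor(2.0, requires_grad=True)
h = torch.tensor(0.0)
for t in range(6):
    h = h * w + 1.0      # one "time step"
    if t == 2:
        h = h.detach()   # truncate: steps 0-2 are dropped from the computation graph
h.backward()
print(w.grad)            # gradient only reflects steps 3-5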
m = len(seqs)//bs
m, bs, len(seqs)
(328, 64, 21031)
def group_chunks(ds, bs):
    # Calculate the number of groups
    m = len(ds) // bs
    # Initialize new dataset container
    new_ds = L()
    # Group the dataset into chunks
    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
    return new_ds
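A quick toy check (mine, not from the book) of what group_chunks does: with bs=2, the first position of successive batches walks through the first half of the stream while the second position walks through the second half, so each row of a batch sees a contiguous slice of the corpus across batches.
toy = L(range(10))
group_chunks(toy, 2)
# -> (#10) [0,5,1,6,2,7,3,8,4,9]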
# Split dataset 80/20 into training and validation
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(
    group_chunks(seqs[:cut], bs),
    group_chunks(seqs[cut:], bs),
    bs=bs,
    # Drop the last batch that does not have the shape of bs
    drop_last=True,
    shuffle=False)
ModelResetter
fastai.callback.rnn.ModelResetter
print_source(ModelResetter)
@docs
class ModelResetter(Callback):
    "`Callback` that resets the model at each validation/training step"
    def before_train(self):    self.model.reset()
    def before_validate(self): self.model.reset()
    def after_fit(self):       self.model.reset()
    _docs = dict(before_train="Reset the model before training",
                 before_validate="Reset the model before validation",
                 after_fit="Reset the model after fitting")
fastai Callbacks
- Documentation
- after_create: called after the Learner is created
- before_fit: called before starting training or inference, ideal for initial setup.
- before_epoch: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
- before_train: called at the beginning of the training part of an epoch.
- before_batch: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).
- after_pred: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.
- after_loss: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
- before_backward: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used)
- before_step: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).
- after_step: called after the step and before the gradients are zeroed.
- after_batch: called at the end of a batch, for any clean-up before the next one.
- after_train: called at the end of the training phase of an epoch.
- before_validate: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
- after_validate: called at the end of the validation part of an epoch.
- after_epoch: called at the end of an epoch, for any clean-up before the next one.
- after_fit: called at the end of training, for final clean-up.
Callback
fastai.callback.core.Callback
print_source(Callback)
@funcs_kwargs(as_method=True)
class Callback(Stateful,GetAttr):
    "Basic class handling tweaks of the training loop by changing a `Learner` in various events"
    order,_default,learn,run,run_train,run_valid = 0,'learn',None,True,True,True
    _methods = _events

    def __init__(self, **kwargs): assert not kwargs, f'Passed unknown events: {kwargs}'
    def __repr__(self): return type(self).__name__

    def __call__(self, event_name):
        "Call `self.{event_name}` if it's defined"
        _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
                (self.run_valid and not getattr(self, 'training', False)))
        res = None
        if self.run and _run: res = getattr(self, event_name, noop)()
        if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
        return res

    def __setattr__(self, name, value):
        if hasattr(self.learn,name):
            warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
        super().__setattr__(name, value)

    @property
    def name(self):
        "Name of the `Callback`, camel-cased and with '*Callback*' removed"
        return class2attr(self, 'Callback')
learn = Learner(dls,
                LMModel3(len(vocab), 64),
                loss_func=F.cross_entropy,
                metrics=accuracy,
                # Reset the model at the beginning of each epoch and before each validation phase
                cbs=ModelResetter)
learn.fit_one_cycle(10, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.695570 | 1.837262 | 0.474519 | 00:01 |
1 | 1.316114 | 1.939660 | 0.366346 | 00:01 |
2 | 1.102734 | 1.578932 | 0.469471 | 00:01 |
3 | 1.017313 | 1.470766 | 0.552163 | 00:01 |
4 | 0.961458 | 1.568437 | 0.551923 | 00:01 |
5 | 0.920572 | 1.632755 | 0.574519 | 00:01 |
6 | 0.932616 | 1.634864 | 0.588221 | 00:01 |
7 | 0.848161 | 1.668468 | 0.587500 | 00:01 |
8 | 0.802442 | 1.698610 | 0.591827 | 00:01 |
9 | 0.794550 | 1.716233 | 0.594952 | 00:01 |
Creating More Signal
- we can increase the amount of signal for updating the model weights by predicting the next word after every single word, rather than every three words
# Define the sequence length
sl = 16
# Update the dependent variable to include each of the words
# that follow each of the words in the independent variable
seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))
         for i in range(0,len(nums)-sl-1,sl))
# Define the split for the training and validation set
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),
                             group_chunks(seqs[cut:], bs),
                             bs=bs, drop_last=True, shuffle=False)
[L(vocab[o] for o in s) for s in seqs[0]]
[(#16) ['one','.','two','.','three','.','four','.','five','.'...],
(#16) ['.','two','.','three','.','four','.','five','.','six'...]]
class LMModel4(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.h_h = nn.Linear(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = 0

    def forward(self, x):
        outs = []
        for i in range(sl):
            self.h = self.h + self.i_h(x[:,i])
            self.h = F.relu(self.h_h(self.h))
            # Store the output for each word in the current sequence
            outs.append(self.h_o(self.h))
        self.h = self.h.detach()
        # Stack the outputs for each word in the current sequence
        return torch.stack(outs, dim=1)

    def reset(self): self.h = 0
# Define a custom loss function that flattens the output before calculating cross entropy
def loss_func(inp, targ):
    return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))
learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.229987 | 3.069768 | 0.249756 | 00:00 |
1 | 2.291759 | 1.903835 | 0.468018 | 00:00 |
2 | 1.719411 | 1.769336 | 0.469157 | 00:00 |
3 | 1.444394 | 1.729377 | 0.459554 | 00:00 |
4 | 1.273674 | 1.625678 | 0.531169 | 00:00 |
5 | 1.141202 | 1.762818 | 0.545898 | 00:00 |
6 | 1.037926 | 1.575556 | 0.573812 | 00:00 |
7 | 0.939284 | 1.470020 | 0.614095 | 00:00 |
8 | 0.858596 | 1.532887 | 0.628255 | 00:00 |
9 | 0.784250 | 1.495697 | 0.655843 | 00:00 |
10 | 0.739764 | 1.539676 | 0.666423 | 00:00 |
11 | 0.693413 | 1.550242 | 0.662191 | 00:00 |
12 | 0.661127 | 1.519285 | 0.680908 | 00:00 |
13 | 0.635551 | 1.523878 | 0.676921 | 00:00 |
14 | 0.621697 | 1.531653 | 0.684408 | 00:00 |
Multilayer RNNs
- pass the activations from one RNN into another RNN
The Model
class LMModel5(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(n_layers, bs, n_hidden)

    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(res)

    def reset(self): self.h.zero_()
LMModel5(len(vocab), 64, 2)
LMModel5(
(i_h): Embedding(30, 64)
(rnn): RNN(64, 64, num_layers=2, batch_first=True)
(h_o): Linear(in_features=64, out_features=30, bias=True)
)
learn = Learner(dls, LMModel5(len(vocab), 64, 2),
                loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.070420 | 2.586252 | 0.460775 | 00:00 |
1 | 2.154392 | 1.760734 | 0.471680 | 00:00 |
2 | 1.709090 | 1.851027 | 0.327311 | 00:00 |
3 | 1.523287 | 1.790196 | 0.412028 | 00:00 |
4 | 1.364664 | 1.816422 | 0.468262 | 00:00 |
5 | 1.247051 | 1.796951 | 0.493001 | 00:00 |
6 | 1.156087 | 1.907447 | 0.489095 | 00:00 |
7 | 1.073325 | 2.014389 | 0.499268 | 00:00 |
8 | 0.995001 | 2.056770 | 0.501139 | 00:00 |
9 | 0.927453 | 2.080244 | 0.503743 | 00:00 |
10 | 0.874861 | 2.084781 | 0.502441 | 00:00 |
11 | 0.837194 | 2.102611 | 0.514974 | 00:00 |
12 | 0.812340 | 2.111124 | 0.512126 | 00:00 |
13 | 0.797198 | 2.110253 | 0.513346 | 00:00 |
14 | 0.789102 | 2.108808 | 0.513997 | 00:00 |
Note: The multi-layer RNN performs worse than the single-layer RNN
Exploding or Disappearing Activations
- deeper models are more difficult to train
- performing matrix multiplication so many times can cause numbers to get extremely big or extremely small
- floating point numbers get less accurate the further away they get from zero
What you never wanted to know about floating point but will be forced to find out
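A quick numeric illustration (mine, not from the book) of both problems:
import torch

print(1.01 ** 1000)  # repeatedly multiplying by a number slightly above 1 explodes (~2.1e4)
print(0.99 ** 1000)  # repeatedly multiplying by a number slightly below 1 vanishes (~4.3e-05)
# float32 cannot represent the tiny difference: 1.0 + 1e-8 rounds back to 1.0
print(torch.tensor(1.0) + 1e-8 == torch.tensor(1.0))  # tensor(True)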
Two types of layers are frequently used to avoid exploding activations in RNNs
- Gated Recurrent Units (GRUs)
- Long short-term memory (LSTM)
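The rest of the chapter focuses on the LSTM; as a rough sketch (mine, not from the book), a GRU version of the multilayer model above only needs nn.RNN swapped for nn.GRU, since a GRU also keeps a single hidden state (the class name LMModel5GRU is hypothetical):
class LMModel5GRU(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.GRU(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(n_layers, bs, n_hidden)

    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(res)

    def reset(self): self.h.zero_()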
LSTM
- introduced in 1997 by Jürgen Schmidhuber and Sepp Hochreiter
- Normal RNNs are really bad at retaining memory of what happened much earlier in a sentence
- LSTMs maintain two hidden states to address this
- cell state: responsible for keeping long short-term memory
- hidden state: focuses on the next token to predict
- \(x_{t}\): the current input
- \(h_{t-1}\): the previous hidden state
- \(c_{t-1}\): the previous cell state
- \(\sigma\): sigmoid function
- \(\tanh\): a sigmoid function rescaled to the range \([-1,1]\)
  - \(\tanh(x) = \frac{e^{x}-e^{-x}}{e^{x}+e^{-x}} = 2\sigma(2x)-1\)
- four neural nets (orange) called gates (left to right):
  - forget gate: a linear layer followed by a sigmoid (i.e. output will be scalars in [0,1])
    - multiply the output by the cell state to determine which information to keep
    - gives the LSTM the ability to forget things about its long-term state
  - input gate: works with the third gate (tanh) to update the cell state
    - decides which elements of the cell state to update (values close to 1)
  - cell gate: determines what the updated values are for the cell state
  - output gate: determines which information from the cell state to use to generate the new hidden state
2*torch.sigmoid(2*tensor(0.5)) - 1
tensor(0.4621)
torch.tanh(tensor(0.5))
tensor(0.4621)
Building an LSTM from Scratch
class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h,c = state
        h = torch.cat([h, input], dim=1)
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h,c)
Note: It is better for performance reasons to do one big matrix multiplication than four smaller ones
- launch the special fast kernel on the GPU only once
- give the GPU more work to do in parallel
class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.ih = nn.Linear(ni,4*nh)
        self.hh = nn.Linear(nh,4*nh)

    def forward(self, input, state):
        h,c = state
        # One big multiplication for all the gates is better than 4 smaller ones
        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)
        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])
        cellgate = gates[3].tanh()
        c = (forgetgate*c) + (ingate*cellgate)
        h = outgate * c.tanh()
        return h, (h,c)
torch.chunk
help(torch.chunk)
Help on built-in function chunk:
chunk(...)
chunk(input, chunks, dim=0) -> List of Tensors
Attempts to split a tensor into the specified number of chunks. Each chunk is a view of
the input tensor.
.. note::
This function may return less then the specified number of chunks!
.. seealso::
:func:`torch.tensor_split` a function that always returns exactly the specified number of chunks
If the tensor size along the given dimesion :attr:`dim` is divisible by :attr:`chunks`,
all returned chunks will be the same size.
If the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`chunks`,
all returned chunks will be the same size, except the last one.
If such division is not possible, this function may return less
than the specified number of chunks.
Arguments:
input (Tensor): the tensor to split
chunks (int): number of chunks to return
dim (int): dimension along which to split the tensor
Example::
>>> torch.arange(11).chunk(6)
(tensor([0, 1]),
tensor([2, 3]),
tensor([4, 5]),
tensor([6, 7]),
tensor([8, 9]),
tensor([10]))
>>> torch.arange(12).chunk(6)
(tensor([0, 1]),
tensor([2, 3]),
tensor([4, 5]),
tensor([6, 7]),
tensor([8, 9]),
tensor([10, 11]))
>>> torch.arange(13).chunk(6)
(tensor([0, 1, 2]),
tensor([3, 4, 5]),
tensor([6, 7, 8]),
tensor([ 9, 10, 11]),
tensor([12]))
t = torch.arange(0,10); t
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
t.chunk(2)
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
Training a Language Model Using LSTMs
class LMModel6(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]

    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(res)

    def reset(self):
        for h in self.h: h.zero_()
# Using a two-layer LSTM
learn = Learner(dls, LMModel6(len(vocab), 64, 2),
                loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, cbs=ModelResetter)
learn.model
LMModel6(
(i_h): Embedding(30, 64)
(rnn): LSTM(64, 64, num_layers=2, batch_first=True)
(h_o): Linear(in_features=64, out_features=30, bias=True)
)
learn.fit_one_cycle(15, 1e-2)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.013088 | 2.705310 | 0.417074 | 00:01 |
1 | 2.215323 | 1.904673 | 0.406657 | 00:01 |
2 | 1.622977 | 1.772446 | 0.438232 | 00:01 |
3 | 1.319893 | 1.853711 | 0.519613 | 00:00 |
4 | 1.096065 | 1.868788 | 0.554118 | 00:01 |
5 | 0.872888 | 1.679482 | 0.609375 | 00:01 |
6 | 0.590291 | 1.355017 | 0.661458 | 00:01 |
7 | 0.385917 | 1.319989 | 0.667887 | 00:01 |
8 | 0.284691 | 1.221118 | 0.689290 | 00:01 |
9 | 0.228731 | 1.181922 | 0.730632 | 00:01 |
10 | 0.172228 | 1.250237 | 0.727946 | 00:01 |
11 | 0.124468 | 1.155407 | 0.754720 | 00:01 |
12 | 0.090831 | 1.183195 | 0.749674 | 00:01 |
13 | 0.071399 | 1.179867 | 0.750081 | 00:01 |
14 | 0.061995 | 1.168421 | 0.753499 | 00:01 |
Note: We were able to use a higher learning rate and achieve a much higher accuracy than with the multi-layer RNN, though there is still some overfitting.
Regularizing an LSTM
- Regularizing and Optimizing LSTM Language Models
- used an LSTM with dropout, activation regularization, and temporal activation regularization to beat state-of-the-art results that previously required much more complicated models
- called the combination an AWD-LSTM
Dropout
- Improving neural networks by preventing co-adaptation of feature detectors
- Dropout: A Simple Way to Prevent Neural Networks from Overfitting
- randomly change some activations to zero at training time
- makes activations more noisy
- makes sure all neurons actively work toward the output
- makes the model more robust
- need to rescale activations after applying dropout
- divide the remaining activations by \(1-p\), where p is the probability of dropping an activation
- using dropout before passing the output of our LSTM to the final layer will help reduce overfitting
- make sure to turn off dropout during inference
class Dropout(Module):
    def __init__(self, p): self.p = p
    def forward(self, x):
        if not self.training: return x
        # Keep each activation with probability (1-p), zero it with probability p
        mask = x.new(*x.shape).bernoulli_(1-self.p)
        # Rescale the kept activations by dividing by (1-p)
        return x * mask.div_(1-self.p)
help(torch.bernoulli)
Help on built-in function bernoulli:
bernoulli(...)
bernoulli(input, *, generator=None, out=None) -> Tensor
Draws binary random numbers (0 or 1) from a Bernoulli distribution.
The :attr:`input` tensor should be a tensor containing probabilities
to be used for drawing the binary random number.
Hence, all values in :attr:`input` have to be in the range:
:math:`0 \leq \text{input}_i \leq 1`.
The :math:`\text{i}^{th}` element of the output tensor will draw a
value :math:`1` according to the :math:`\text{i}^{th}` probability value given
in :attr:`input`.
.. math::
\text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i})
The returned :attr:`out` tensor only has values 0 or 1 and is of the same
shape as :attr:`input`.
:attr:`out` can have integral ``dtype``, but :attr:`input` must have floating
point ``dtype``.
Args:
input (Tensor): the input tensor of probability values for the Bernoulli distribution
Keyword args:
generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling
out (Tensor, optional): the output tensor.
Example::
>>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1]
>>> a
tensor([[ 0.1737, 0.0950, 0.3609],
[ 0.7148, 0.0289, 0.2676],
[ 0.9456, 0.8937, 0.7202]])
>>> torch.bernoulli(a)
tensor([[ 1., 0., 0.],
[ 0., 0., 0.],
[ 1., 1., 1.]])
>>> a = torch.ones(3, 3) # probability of drawing "1" is 1
>>> torch.bernoulli(a)
tensor([[ 1., 1., 1.],
[ 1., 1., 1.],
[ 1., 1., 1.]])
>>> a = torch.zeros(3, 3) # probability of drawing "1" is 0
>>> torch.bernoulli(a)
tensor([[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.]])
Activation Regularization and Temporal Activation Regularization
- both are similar to weight decay
- activation regularization (AR): try to make the final activations produced by the LSTM as small as possible
loss += alpha * activations.pow(2).mean()
- often applied on dropped-out activations to not penalize the activations set to zero
- temporal activation regularization (TAR)
- linked to the fact we are predicting tokens in a sentence
- the outputs of our LSTMs should somewhat make sense when we read them in order
- TAR encourages that behavior by adding a penalty to the loss to make the difference between two consecutive activations as small as possible
loss += beta * (activations[:,1:] - activations[:,:-1]).pow(2).mean()
- alpha and beta are tunable hyperparameters
- applied to non-dropped-out activations (because the zeros in the dropped-out activations create big differences)
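A minimal sketch (mine, not from the book) of how the two penalties get added to the main loss, using dummy activations of shape (batch size, sequence length, hidden size); out stands in for the dropped-out activations and raw for the non-dropped-out ones:
import torch

alpha, beta = 2.0, 1.0                        # tunable hyperparameters
raw = torch.randn(64, 16, 64)                 # stand-in for raw LSTM activations
out = torch.nn.functional.dropout(raw, 0.4)   # stand-in for dropped-out activations
loss = torch.tensor(0.0)                      # stand-in for the cross-entropy loss

loss = loss + alpha * out.pow(2).mean()                      # AR, on dropped-out activations
loss = loss + beta * (raw[:,1:] - raw[:,:-1]).pow(2).mean()  # TAR, on raw activations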
Training a Weight-Tied Regularized LSTM
- the model needs to return the normal output, along with the raw activations from the LSTM and the dropped-out activations
Weight Tying
- in a language model, the input embeddings represent a mapping from English words to activations and the output hidden layer represents a mapping from activations to English words
- these mappings could be the same
- introduced in the AWD-LSTM paper
self.h_o.weight = self.i_h.weight
class LMModel7(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers, p):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.drop = nn.Dropout(p)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h_o.weight = self.i_h.weight
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]

    def forward(self, x):
        raw,h = self.rnn(self.i_h(x), self.h)
        out = self.drop(raw)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(out),raw,out

    def reset(self):
        for h in self.h: h.zero_()
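A quick check (mine, not from the book) that the tying really shares one parameter tensor between the two layers:
model = LMModel7(len(vocab), 64, 2, 0.4)
model.h_o.weight is model.i_h.weight  # True: updating one updates the other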
learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),
                loss_func=CrossEntropyLossFlat(), metrics=accuracy,
                cbs=[ModelResetter, RNNRegularizer(alpha=2, beta=1)])
RNNRegularizer
fastai.callback.rnn.RNNRegularizer
print_source(RNNRegularizer)
class RNNRegularizer(Callback):
    "Add AR and TAR regularization"
    order,run_valid = RNNCallback.order+1,False
    def __init__(self, alpha=0., beta=0.): store_attr()
    def after_loss(self):
        if not self.training: return
        if self.alpha: self.learn.loss_grad += self.alpha * self.rnn.out.float().pow(2).mean()
        if self.beta:
            h = self.rnn.raw_out
            if len(h)>1: self.learn.loss_grad += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()
learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)
learn.model
LMModel7(
(i_h): Embedding(30, 64)
(rnn): LSTM(64, 64, num_layers=2, batch_first=True)
(drop): Dropout(p=0.4, inplace=False)
(h_o): Linear(in_features=64, out_features=30, bias=True)
)
learn.fit_one_cycle(15, 1e-2, wd=0.1)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.620218 | 1.797085 | 0.484294 | 00:01 |
1 | 1.622718 | 1.452620 | 0.652181 | 00:01 |
2 | 0.864787 | 0.726230 | 0.773275 | 00:01 |
3 | 0.434755 | 0.699705 | 0.828613 | 00:01 |
4 | 0.225359 | 0.579946 | 0.842855 | 00:01 |
5 | 0.126518 | 0.571510 | 0.850911 | 00:01 |
6 | 0.076041 | 0.444107 | 0.874349 | 00:01 |
7 | 0.051340 | 0.366569 | 0.882487 | 00:01 |
8 | 0.037389 | 0.547799 | 0.854818 | 00:01 |
9 | 0.027291 | 0.392787 | 0.880615 | 00:01 |
10 | 0.022100 | 0.354383 | 0.889648 | 00:01 |
11 | 0.018304 | 0.380172 | 0.885417 | 00:01 |
12 | 0.015668 | 0.384031 | 0.885010 | 00:01 |
13 | 0.013562 | 0.389092 | 0.884033 | 00:01 |
14 | 0.012376 | 0.383106 | 0.885254 | 00:01 |
Note: This performance is significantly better than the regular LSTM.
References
Previous: Notes on fastai Book Ch. 11
Next: Notes on fastai Book Ch. 13