Chapter 3: Recurrent Neural Network
What this chapter is ultimately trying to achieve
To introduce a type of neural network specifically designed to process sequences of data (like words in a sentence) one element at a time, while maintaining an internal “memory” or “state” that captures information from previous elements in the sequence. This “memory” allows RNNs to understand context that spans multiple tokens, which is something BoW or simple n-gram models struggle with significantly.
Let’s break down the key concepts:
3.1 Elman RNN (Simple Recurrent Neural Network)
What it’s ultimately trying to achieve: To process a sequence of inputs (e.g., word embeddings) step-by-step, and at each step, produce an output and update an internal hidden state. This hidden state acts as a compressed summary of the sequence seen so far.
The Core Idea (The Loop): Imagine a standard neural network unit. Now, give it a loop: the output of the unit at a given time step t (specifically, its hidden state h_t) is fed back into the unit as an additional input at the next time step t+1, along with the actual next input from the sequence x_{t+1}.
- Input: At each time step t, the RNN unit takes two things:
  - The current input from the sequence, x_t (e.g., the embedding of the current word).
  - The hidden state from the previous time step, h_{t-1}.
- Calculation: Inside the unit, these inputs are typically transformed by weight matrices and an activation function (often tanh in classic RNNs) to produce:
  - The new hidden state for the current time step, h_t.
  - (Optionally) An output for the current time step, y_t. For language modeling, this y_t would be related to predicting the next word.
- Formula (Conceptual):
  - h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h) (hidden state update)
  - y_t = W_hy * h_t + b_y (output at time t, often passed through softmax for probabilities)

  Where W_hh, W_xh, W_hy are weight matrices and b_h, b_y are bias terms. These weights are shared across all time steps, which is key to how RNNs generalize.
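The hidden-state update can be sketched directly in PyTorch. The dimensions here (hidden size 3, input size 5) are illustrative, not from the book:

```python
import torch

# One conceptual RNN step, h_t = tanh(W_hh @ h_prev + W_xh @ x_t + b_h),
# with illustrative dimensions (hidden size 3, input size 5).
torch.manual_seed(0)
hidden_size, input_size = 3, 5
W_hh = torch.randn(hidden_size, hidden_size)  # hidden-to-hidden weights
W_xh = torch.randn(hidden_size, input_size)   # input-to-hidden weights
b_h = torch.zeros(hidden_size)                # hidden bias

h_prev = torch.zeros(hidden_size)   # initial hidden state
x_t = torch.randn(input_size)       # current input (e.g., a word embedding)
h_t = torch.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)
print(h_t.shape)  # torch.Size([3])
```

Because of the tanh, every component of h_t stays in (-1, 1), which keeps the hidden state bounded from step to step.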
Visualizing It: You can “unroll” an RNN in time. It looks like a chain of identical network units, where the hidden state from one unit is passed to the next.
Realism and Challenges:
- Vanishing/Exploding Gradients: When training RNNs with backpropagation through time (BPTT), gradients can become very small (vanish) or very large (explode) as they are propagated back through many time steps. This makes it hard for simple RNNs to learn long-range dependencies (e.g., connecting a word at the beginning of a long sentence to a word at the end). ReLU helps with vanishing gradients compared to tanh/sigmoid in deep feedforward nets, but the recurrent nature still poses challenges. LSTMs and GRUs (which are more advanced RNN variants, not deeply covered in a 100-page book but important to know about) were developed to mitigate this.
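The vanishing/exploding behavior can be demonstrated with a toy unrolled loop. The weight matrices below are deliberately simple scaled identities (an illustrative choice, not a realistic initialization), so the effect of the spectral norm is easy to see:

```python
import torch

# Minimal sketch: backpropagating through many time steps multiplies the
# gradient by the recurrent weights repeatedly, so it vanishes when their
# spectral norm is < 1 and explodes when it is > 1.
def grad_after_steps(scale, steps=50, hidden=8):
    W_hh = scale * torch.eye(hidden)       # recurrent weights, spectral norm = scale
    h0 = torch.zeros(hidden, requires_grad=True)
    h = h0
    for _ in range(steps):                 # unrolled forward pass (as in BPTT)
        h = torch.tanh(h @ W_hh)
    h.sum().backward()                     # gradient w.r.t. the initial state
    return h0.grad.norm().item()

print(grad_after_steps(0.5))  # vanishing: on the order of 0.5**50, effectively zero
print(grad_after_steps(1.5))  # exploding: grows with every unrolled step
```

LSTMs and GRUs address exactly this: their gating mechanisms give gradients a more direct path through time.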
3.2 Mini-Batch Gradient Descent (Revisited for Sequences)
What it’s ultimately trying to achieve: To efficiently train RNNs (and other large models) by processing multiple sequences in parallel within each training step, rather than one sequence at a time or the entire dataset at once.
The Setup for Sequences: When we feed data to an RNN, it’s often in the shape of (batch_size, sequence_length, embedding_dimensionality).
- batch_size: The number of sequences processed together.
- sequence_length: The number of tokens in each sequence (sequences are often padded to be the same length in a batch).
- embedding_dimensionality: The size of the vector representing each token.
Why it’s important: Processing batches leverages the parallel processing capabilities of modern hardware (like GPUs), making training much faster. It also provides a more stable estimate of the gradient compared to processing single examples (stochastic gradient descent).
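A quick way to see these shapes is to pass a batched tensor through PyTorch’s built-in nn.RNN (the sizes below are illustrative):

```python
import torch
import torch.nn as nn

# A batch shaped (batch_size, sequence_length, embedding_dimensionality),
# fed to PyTorch's built-in RNN; batch_first=True matches this layout.
batch_size, seq_len, emb_dim, hidden_dim = 4, 10, 32, 64
batch = torch.randn(batch_size, seq_len, emb_dim)

rnn = nn.RNN(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True)
output, h_n = rnn(batch)
print(output.shape)  # hidden state at every time step: torch.Size([4, 10, 64])
print(h_n.shape)     # final hidden state: torch.Size([1, 4, 64])
```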
3.3 Programming an RNN (in PyTorch)
What it’s ultimately trying to achieve: To translate the mathematical concept of an RNN unit and a stack of RNN layers into working code.
Key PyTorch Components:
- nn.Module: The base class for all neural network modules in PyTorch. Our RNN unit and the full RNN model will inherit from this.
- nn.Parameter: Wraps a tensor to tell PyTorch that it’s a learnable model parameter (like the weight matrices W_hh, W_xh).
- The __init__ method: Where you define the layers and parameters of your model.
- The forward method: Where you define how the input data flows through the layers to produce an output. For an RNN, this will involve a loop over the time steps of the input sequence.
Implementing the ElmanRNNUnit:
- Initialize weight matrices (Uh for hidden-to-hidden, Wh for input-to-hidden) and a bias term (b).
- The forward method takes current input x and previous hidden state h_prev and computes h_new = tanh(x @ Wh + h_prev @ Uh + b).
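A minimal sketch of such a unit, using the Uh/Wh/b names from the description above (the initialization scale is an illustrative choice, not the book’s exact code):

```python
import torch
import torch.nn as nn

# One Elman RNN unit: h_new = tanh(x @ Wh + h_prev @ Uh + b).
class ElmanRNNUnit(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.Uh = nn.Parameter(torch.randn(emb_dim, emb_dim) * 0.01)  # hidden-to-hidden
        self.Wh = nn.Parameter(torch.randn(emb_dim, emb_dim) * 0.01)  # input-to-hidden
        self.b = nn.Parameter(torch.zeros(emb_dim))                   # bias

    def forward(self, x, h_prev):
        # x: (batch, emb_dim), h_prev: (batch, emb_dim)
        return torch.tanh(x @ self.Wh + h_prev @ self.Uh + self.b)

unit = ElmanRNNUnit(emb_dim=8)
h = unit(torch.randn(4, 8), torch.zeros(4, 8))
print(h.shape)  # torch.Size([4, 8])
```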
Implementing the full ElmanRNN (stacking layers):
- The ElmanRNN class would contain a list of ElmanRNNUnit instances (one for each layer).
- Its forward method would:
  - Initialize hidden states for all layers (usually to zeros).
  - Loop through each token (time step t) in the input sequences of the batch.
  - For each token, loop through each RNN layer:
    - The input to the first layer is the token’s embedding.
    - The input to subsequent layers is the hidden state from the layer below at the same time step.
    - Each layer updates its hidden state.
  - Collect the outputs (usually the hidden states of the last layer at each time step).
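The nested loops above can be sketched as follows. This is a self-contained illustration (unit and sizes are hypothetical; a real implementation would likely also return the final hidden states):

```python
import torch
import torch.nn as nn

# The per-step unit: h_new = tanh(x @ Wh + h_prev @ Uh + b).
class ElmanRNNUnit(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.Uh = nn.Parameter(torch.randn(emb_dim, emb_dim) * 0.01)
        self.Wh = nn.Parameter(torch.randn(emb_dim, emb_dim) * 0.01)
        self.b = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x, h_prev):
        return torch.tanh(x @ self.Wh + h_prev @ self.Uh + self.b)

# A stack of units: outer loop over time steps, inner loop over layers.
class ElmanRNN(nn.Module):
    def __init__(self, emb_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(ElmanRNNUnit(emb_dim) for _ in range(num_layers))

    def forward(self, x):
        # x: (batch_size, seq_len, emb_dim)
        batch_size, seq_len, emb_dim = x.shape
        # One hidden state per layer, initialized to zeros
        h = [torch.zeros(batch_size, emb_dim) for _ in self.layers]
        outputs = []
        for t in range(seq_len):              # loop over time steps
            inp = x[:, t, :]                  # first layer sees the embedding
            for i, layer in enumerate(self.layers):
                h[i] = layer(inp, h[i])       # update this layer's hidden state
                inp = h[i]                    # next layer sees it as input
            outputs.append(h[-1])             # collect the last layer's state
        return torch.stack(outputs, dim=1)    # (batch_size, seq_len, emb_dim)

rnn = ElmanRNN(emb_dim=8, num_layers=2)
out = rnn(torch.randn(4, 10, 8))
print(out.shape)  # torch.Size([4, 10, 8])
```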
3.4 RNN as a Language Model
What it’s ultimately trying to achieve: To use the RNN architecture to perform the core language modeling task: predicting the next token in a sequence.
The Architecture:
- Embedding Layer: Converts input token IDs into dense embedding vectors. This is often nn.Embedding in PyTorch.
- RNN Layers: One or more RNN layers (like our ElmanRNN) process the sequence of embeddings and output a sequence of hidden states (usually from the final RNN layer).
- Output (Linear) Layer / Classification Head: A fully connected linear layer takes the RNN’s hidden state output at each time step t and transforms it into a vector of logits, where the size of this vector is the vocabulary size.
- Softmax (implicitly with CrossEntropyLoss): These logits are then (conceptually, often combined within the loss function) passed through a softmax function to get probabilities for each word in the vocabulary being the next word.
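Wired together, the three pieces look like this. The sketch below uses PyTorch’s built-in nn.RNN in place of a hand-written ElmanRNN, and all sizes are hypothetical:

```python
import torch
import torch.nn as nn

# Embedding -> RNN -> linear head producing per-position logits over the vocabulary.
class RNNLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.RNN(emb_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)  # logits over the vocabulary

    def forward(self, token_ids):
        # token_ids: (batch_size, seq_len)
        emb = self.embedding(token_ids)       # (batch, seq_len, emb_dim)
        hidden_states, _ = self.rnn(emb)      # (batch, seq_len, hidden_dim)
        return self.head(hidden_states)       # (batch, seq_len, vocab_size)

model = RNNLanguageModel(vocab_size=100, emb_dim=16, hidden_dim=32)
logits = model(torch.randint(0, 100, (4, 10)))
print(logits.shape)  # torch.Size([4, 10, 100])
```

Note that no softmax appears in forward: nn.CrossEntropyLoss applies it internally during training.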
Training:
- Input Sequence: A sequence of token IDs, e.g., [token_A, token_B, token_C].
- Target Sequence: The input sequence shifted by one position, e.g., [token_B, token_C, token_D].
- At each time step t, the model processes input_token_t and its goal is to output a high probability for target_token_t.
- The cross-entropy loss is calculated between the predicted probability distribution and the actual target token at each position, and then averaged.
3.5 Embedding Layer (Deeper Dive with nn.Embedding)
What it’s ultimately trying to achieve: To provide a learnable lookup table that maps discrete token indices (integers) to dense, continuous-valued embedding vectors.
How it works in PyTorch (nn.Embedding):
- When you create nn.Embedding(vocab_size, emb_dim), PyTorch initializes a weight matrix of shape (vocab_size, emb_dim) with random values. Each row i of this matrix is the embedding vector for token ID i.
- When you pass a tensor of token IDs to this layer, it simply looks up and returns the corresponding rows (embedding vectors).
- These embedding vectors are learnable parameters. During training, gradients flow back to them, and they get updated to better represent the tokens for the given task.
- padding_idx: You can specify an index to be treated as a padding token. The embedding for this token will be a zero vector and (importantly) will not be updated during training.
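A short demonstration of the lookup and of padding_idx (the vocabulary size and dimensions are illustrative; index 0 is chosen as the padding token here):

```python
import torch
import torch.nn as nn

# A learnable lookup table: 10 tokens, 4-dimensional vectors, index 0 = padding.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
ids = torch.tensor([[1, 2, 0]])        # last position is a padding token
vectors = emb(ids)                     # looks up rows 1, 2, and 0
print(vectors.shape)                   # torch.Size([1, 3, 4])
print(emb.weight[0])                   # the padding row is all zeros
```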
3.6 Training an RNN Language Model (The Full Loop in PyTorch)
What it’s ultimately trying to achieve: To put all the pieces together – data preparation, model instantiation, loss function, optimizer, and the training loop – to actually train an RNN LM.
Key Steps in the Training Loop (per epoch, per batch):
- model.train(): Set the model to training mode.
- Get input_seq and target_seq from the DataLoader.
- Move data to the correct device (CPU/GPU).
- optimizer.zero_grad(): Clear old gradients.
- outputs = model(input_seq): Forward pass to get logits.
- Reshape outputs and target_seq so that the loss can be computed across all tokens in the batch efficiently. Typically, this means flattening them so that each row corresponds to a single token prediction:
  - outputs becomes (batch_size * seq_len, vocab_size)
  - target_seq becomes (batch_size * seq_len)
- loss = criterion(outputs, target_seq): Calculate the cross-entropy loss. Remember nn.CrossEntropyLoss in PyTorch expects raw logits and handles the softmax internally. It also allows an ignore_index parameter, which is crucial for not calculating loss on padding tokens in the target_seq.
- loss.backward(): Backward pass to compute gradients.
- optimizer.step(): Update model parameters.
Reproducibility: Setting seeds (random.seed(), torch.manual_seed(), torch.cuda.manual_seed_all()) is important for consistent results, especially when debugging or comparing experiments.
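The steps above can be condensed into one toy training step. Random tensors stand in for a DataLoader batch, the model is built from PyTorch primitives, and index 0 is assumed to be the padding token (all sizes are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)                                  # reproducibility
vocab_size, emb_dim, hidden_dim = 20, 8, 16

# Toy model: embedding -> RNN -> linear head
embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
rnn = nn.RNN(emb_dim, hidden_dim, batch_first=True)
head = nn.Linear(hidden_dim, vocab_size)
params = list(embedding.parameters()) + list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)       # no loss on padding tokens

# Random stand-ins for one batch from a DataLoader: (batch_size, seq_len)
input_seq = torch.randint(1, vocab_size, (4, 10))
target_seq = torch.randint(1, vocab_size, (4, 10))

optimizer.zero_grad()                                 # clear old gradients
hidden_states, _ = rnn(embedding(input_seq))          # forward pass
outputs = head(hidden_states)                         # (batch, seq_len, vocab)
loss = criterion(outputs.reshape(-1, vocab_size),     # flatten for the loss
                 target_seq.reshape(-1))
loss.backward()                                       # compute gradients
optimizer.step()                                      # update parameters
print(loss.item())
```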
3.7 Dataset and DataLoader (PyTorch Utilities)
What they are ultimately trying to achieve: To provide a standardized and efficient way to load, preprocess, and iterate over data in batches during training.
- Dataset: An abstract class representing your dataset. You need to implement:
  - __init__(self, ...): Load/prepare your data (e.g., read from file, tokenize).
  - __len__(self): Return the total number of samples in the dataset.
  - __getitem__(self, idx): Return the idx-th sample (e.g., an input sequence and its corresponding target sequence, as tensors).
- DataLoader: Wraps a Dataset and provides an iterator to loop over the data in batches. It handles:
  - Batching.
  - Shuffling (optional, but good for training).
  - Parallel data loading using multiple worker processes (optional, for speed).
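A minimal Dataset for language modeling might slice a long stream of token IDs into (input, shifted target) pairs. The class name and toy data below are hypothetical:

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Hypothetical dataset: slices a long token stream into fixed-length windows,
# where the target is the input shifted by one position.
class NextTokenDataset(Dataset):
    def __init__(self, token_ids, seq_len):
        self.token_ids = token_ids
        self.seq_len = seq_len

    def __len__(self):
        return len(self.token_ids) - self.seq_len

    def __getitem__(self, idx):
        chunk = self.token_ids[idx : idx + self.seq_len + 1]
        x = torch.tensor(chunk[:-1])   # input sequence
        y = torch.tensor(chunk[1:])    # target: shifted by one
        return x, y

ds = NextTokenDataset(list(range(100)), seq_len=10)   # toy token stream
loader = DataLoader(ds, batch_size=4, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([4, 10]) torch.Size([4, 10])
```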
3.8 Training Data and Loss Computation (for Language Modeling)
What it’s ultimately trying to achieve: To clarify exactly how input and target sequences are structured for training an autoregressive language model, and how the loss is computed across all positions.
The “Shifted” Target: For an input sequence like [T1, T2, T3, T4], the target sequence is [T2, T3, T4, T5].
- When the model sees T1, it tries to predict T2.
- When it sees T1, T2, it tries to predict T3.
- And so on. The hidden state h_t carries context from T1...T_t to help predict T_{t+1}.
Loss Calculation: The cross-entropy loss is calculated at each position where a prediction is made. The total loss for a sequence is typically the average of these per-position losses. When batching, it’s the average loss over all predictable tokens in the batch.
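Putting the shifted targets and the per-position averaging together (the logits here are random stand-ins for a model's output, and the token IDs are illustrative):

```python
import torch
import torch.nn as nn

# Input [T1, T2, T3, T4] paired with the left-shifted target [T2, T3, T4, T5].
vocab_size, seq_len = 10, 4
input_ids = torch.tensor([[1, 2, 3, 4]])
target_ids = torch.tensor([[2, 3, 4, 5]])    # shifted by one position
logits = torch.randn(1, seq_len, vocab_size) # stand-in for model(input_ids)

criterion = nn.CrossEntropyLoss()
# Flatten so each row is one token prediction; the loss averages over positions
loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
print(loss.item())
```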