Transformer Decoder Block
The decoder block is the fundamental building unit of the transformer.
Each block combines attention, feedforward networks, normalization, and residual connections into a repeatable structure.
Structure of a Decoder Block
A decoder block has two main parts:
- Multi-Head Self-Attention (MHA) → lets tokens exchange information.
- Feedforward Network (FFN) → transforms the attended features into richer representations.
Surrounding these are:
- RMSNorm → stabilizes training by normalizing activations.
- Residual Connections → ensure information from earlier layers isn’t lost.
The primary block flow is:
Input → Norm → Attention → Residual → Norm → Feedforward → Residual → Output
This “pre-norm” setup (normalize before each sub-layer) is known to improve stability in deep transformers.
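As a minimal sketch of this flow in PyTorch (module names like norm1, attention, and ffwd mirror the walkthrough below but are illustrative rather than the project's exact code; nn.RMSNorm requires a recent PyTorch release):

```python
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Pre-norm decoder block: normalize -> sub-layer -> add residual, twice."""
    def __init__(self, attention: nn.Module, ffwd: nn.Module, n_embd: int):
        super().__init__()
        self.attention = attention        # multi-head self-attention sub-layer
        self.ffwd = ffwd                  # feedforward sub-layer
        self.norm1 = nn.RMSNorm(n_embd)   # pre-attention normalization
        self.norm2 = nn.RMSNorm(n_embd)   # pre-FFN normalization

    def forward(self, x, freqs_complex):
        # Attention sub-layer with residual connection
        x = x + self.attention(self.norm1(x), freqs_complex)
        # Feedforward sub-layer with residual connection
        x = x + self.ffwd(self.norm2(x))
        return x
```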
Example Walkthrough
Let’s step through what happens inside one decoder block.
Suppose we have an input tensor x of shape (batch, seq_len, n_embd).
1. First Normalization
h = self.norm1(x)
- The input x is first normalized with RMSNorm. This ensures the activations are scaled to a stable range before entering attention.
- Unlike LayerNorm, RMSNorm does not subtract the mean; it only rescales activations by their root-mean-square.
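A minimal RMSNorm sketch is shown below; the eps value and learned per-feature gain are typical choices, assumed here rather than taken from the project:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Rescales activations by their root-mean-square; no mean subtraction."""
    def __init__(self, n_embd: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(n_embd))  # learned per-feature gain

    def forward(self, x):
        # Root-mean-square over the embedding dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)
```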
2. Multi-Head Self-Attention
attn_out = self.attention(h, freqs_complex)
- Rotary Position Embeddings (RoPE) are applied to Q and K to inject positional info.
- Attention computes how strongly each token attends to others in the sequence.
- The output has the same shape as the input: (batch, seq_len, n_embd).
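As a rough sketch of the RoPE step, assuming freqs_complex holds precomputed unit-magnitude complex rotations of shape (seq_len, head_dim // 2), the rotation can be applied to the query and key tensors like this:

```python
import torch

def apply_rope(x: torch.Tensor, freqs_complex: torch.Tensor) -> torch.Tensor:
    """Rotate query/key features by position-dependent complex phases (RoPE sketch).

    x:             (batch, seq_len, n_heads, head_dim), head_dim must be even
    freqs_complex: (seq_len, head_dim // 2), unit-magnitude complex numbers
    """
    # Treat consecutive feature pairs as complex numbers
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast the rotations over the batch and head dimensions
    freqs = freqs_complex.unsqueeze(0).unsqueeze(2)  # (1, seq_len, 1, head_dim // 2)
    # Complex multiplication rotates each pair, then convert back to real pairs
    x_rotated = torch.view_as_real(x_complex * freqs).reshape(*x.shape)
    return x_rotated.type_as(x)
```

Inside the attention module this would typically be used as q = apply_rope(q, freqs_complex) and k = apply_rope(k, freqs_complex) before the attention scores are computed; the values V are left unrotated.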
3. First Residual Connection
h = x + attn_out
- The original input x is added back to the attention output. This is called a residual connection (or skip connection).
Why is this important?
- Imagine stacking dozens of layers. Without skip connections, the network could "forget" the original signal after being transformed multiple times.
- By adding x back, we preserve the original information while also giving the model access to the new transformed features from attention.
- During backpropagation, residuals also help gradients flow more smoothly, preventing vanishing or exploding gradients.
- In practice, you can think of it as: the model learns adjustments (deltas) on top of the original input, instead of rewriting it from scratch every time.
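A toy example (illustrative only, not project code) makes the gradient argument concrete: with out = x + f(x), the gradient with respect to x is 1 + f'(x), so a direct path survives even when the sub-layer's own derivative is tiny.

```python
import torch

def tiny_sublayer(t):
    # Stand-in for a sub-layer whose derivative is very small (0.001)
    return 0.001 * t

# With a residual connection: the gradient stays close to 1
x = torch.tensor(2.0, requires_grad=True)
(x + tiny_sublayer(x)).backward()
print(x.grad)   # tensor(1.0010)

# Without the residual connection: the gradient has almost vanished
x2 = torch.tensor(2.0, requires_grad=True)
tiny_sublayer(x2).backward()
print(x2.grad)  # tensor(0.0010)
```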
4. Second Normalization
h_norm = self.norm2(h)
- This keeps the values stable before passing into the FFN.
5. Feedforward Network
ffn_out = self.ffwd(h_norm)
- Adds nonlinearity and transformation capacity.
- Output shape: (batch, seq_len, n_embd).
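The feedforward here is the SwiGLU-style network mentioned under "In This Project" below; a common formulation (layer names, hidden size, and dropout placement are assumptions, not necessarily the project's exact choices) looks like:

```python
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """SwiGLU-style FFN: SiLU-gated up-projection followed by a down-projection."""
    def __init__(self, n_embd: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w_gate = nn.Linear(n_embd, hidden_dim, bias=False)
        self.w_up = nn.Linear(n_embd, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (batch, seq_len, n_embd) -> (batch, seq_len, hidden_dim) -> (batch, seq_len, n_embd)
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
```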
6. Second Residual Connection
out = h + ffn_out
- Again, the model does not overwrite what the attention sub-layer produced. Instead, it layers on additional transformations from the FFN.
- By the time you stack many decoder blocks, each one is contributing refinements while keeping the original context intact.
- This makes the network much more robust and trainable.
Final output shape: (batch, seq_len, n_embd).
In This Project
- Attention type: defaults to standard multi-head self-attention, with optional MLA for efficiency.
- Normalization: RMSNorm used everywhere (simpler than LayerNorm, but empirically stable).
- Activation: SiLU-based feedforward (SwiGLU).
- Dropout: applied after projections, mainly used during fine-tuning (SFT/RLHF).
- Residuals: used after both the attention and FFN sublayers.
Together, these form the repeating backbone of the SimpleLLaMA model.
By stacking many of these blocks, the network can build increasingly complex representations of text sequences.
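A hedged sketch of the stacking (class and factory names here are assumptions, not the project's API) is simply applying the blocks in sequence while the tensor shape stays (batch, seq_len, n_embd) throughout:

```python
import torch.nn as nn

class DecoderStack(nn.Module):
    """Applies a sequence of identical decoder blocks one after another."""
    def __init__(self, n_layers: int, make_block):
        super().__init__()
        # make_block() returns one decoder block, e.g. the DecoderBlock sketched earlier
        self.blocks = nn.ModuleList([make_block() for _ in range(n_layers)])

    def forward(self, x, freqs_complex):
        for block in self.blocks:
            x = block(x, freqs_complex)  # shape unchanged: (batch, seq_len, n_embd)
        return x
```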