Efficient Deep Learning
1. Memory Requirements
The memory required to load a model with $N$ parameters stored at $p$ bytes per parameter is $N \cdot p$ bytes: $4N$ bytes in FP32, $2N$ bytes in FP16/BF16.
For training, we need:
- Weights: $4N$ bytes (FP32)
- Gradients: $4N$ bytes
- Optimiser states (Adam keeps first and second moments): $8N$ bytes

for a total of roughly $16N$ bytes, before activations. To find the activation memory we additionally need the batch size, sequence length, and hidden dimension, since activation storage scales with all three.
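The arithmetic above can be sketched as a small helper (the function name and the 7B example are illustrative; defaults assume FP32 weights/gradients and FP32 Adam moments):

```python
def training_memory_bytes(n_params: int,
                          bytes_weights: int = 4,
                          bytes_grads: int = 4,
                          bytes_optim: int = 8) -> int:
    """Estimate training memory (weights + gradients + optimiser states),
    excluding activations. Defaults: FP32 weights and gradients, plus
    Adam's two FP32 moment buffers (8 bytes per parameter)."""
    return n_params * (bytes_weights + bytes_grads + bytes_optim)

n = 7_000_000_000  # e.g. a 7B-parameter model
print(f"load (FP32):  {4 * n / 1e9:.0f} GB")                      # 28 GB
print(f"train (FP32 + Adam): {training_memory_bytes(n) / 1e9:.0f} GB")  # 112 GB
```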
2. FLOPs
For two vectors $a, b \in \mathbb{R}^n$: $a + b$ is $n$ FLOPs; $a^\top b$ (dot product) is $2n$ FLOPs ($n$ multiplications and $n$ additions). For $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$, $AB$ is $2mnp$ FLOPs.
For one transformer layer with sequence length $s$, model dimension $d$, $h$ attention heads of dimension $d_h = d/h$, and FFN dimension $4d$:
- Layer Norm: computing the mean, variance, normalisation, scale, and shift involves several element-wise passes over the input: $O(sd)$ FLOPs (negligible).
- QKV projections ($Q = XW_Q$, $K = XW_K$, $V = XW_V$): three matrix multiplications of $2sd^2$ FLOPs each: $6sd^2$ FLOPs.
- Attention scores ($QK^\top$): per head, $2s^2 d_h$; across heads $2s^2 d$ FLOPs (since $h \cdot d_h = d$). The scaling by $1/\sqrt{d_h}$ adds $s^2$ element-wise operations per head (negligible).
- Softmax: applied to each row of the $s \times s$ score matrix for each head. Each row requires $s$ exponentiations, $s - 1$ additions, and $s$ divisions: $O(hs^2)$ FLOPs (negligible).
- Attention values ($AV$, with $A = \mathrm{softmax}(QK^\top / \sqrt{d_h})$): per head, $2s^2 d_h$; across heads $2s^2 d$ FLOPs.
- Output projection ($W_O \in \mathbb{R}^{d \times d}$): $2sd^2$ FLOPs.
- First linear ($W_1 \in \mathbb{R}^{d \times 4d}$): $8sd^2$ FLOPs.
- GeLU activation: element-wise over $4sd$ elements (negligible).
- Second linear ($W_2 \in \mathbb{R}^{4d \times d}$): $8sd^2$ FLOPs.

Total: roughly $24sd^2 + 4s^2 d$ FLOPs per layer.
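The per-layer count can be collected into one function (keeping only the matrix-multiply terms and dropping the negligible element-wise work):

```python
def layer_flops(s: int, d: int) -> int:
    """Approximate FLOPs for one transformer layer:
    24*s*d^2 from the QKV/output projections and the two FFN linears,
    plus 4*s^2*d from the QK^T and AV products.
    LayerNorm, softmax, and GeLU are ignored as negligible."""
    return 24 * s * d**2 + 4 * s**2 * d

# Sanity check of the two regimes: for short sequences the s*d^2 term
# dominates; the attention term grows quadratically in s.
print(layer_flops(2048, 4096))
```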
3. Training Efficiency
3.1 Gradient Accumulation
ALWAYS DO THIS!
Memory requirements scale linearly with batch size, so the largest batch that fits in memory may still be small, and a small batch size leads to unstable training because the gradient estimates are noisy.
Gradient Accumulation iteratively computes gradients over smaller micro-batches and accumulates them before performing an update. The optimiser only steps once per effective batch.
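A minimal PyTorch sketch (the model, data, and accumulation count are illustrative). Scaling each micro-batch loss by the number of accumulation steps makes the accumulated gradient equal to the full-batch gradient:

```python
import copy
import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()

x, y = torch.randn(32, 16), torch.randn(32, 1)
accum_steps = 4

# Reference: one full-batch backward pass on an identical copy.
ref = copy.deepcopy(model)
loss_fn(ref(x), y).backward()

# Gradient accumulation: backward over micro-batches, step once.
opt.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = loss_fn(model(xb), yb) / accum_steps  # scale so grads match full batch
    loss.backward()                              # grads accumulate in .grad

grads_match = torch.allclose(model.weight.grad, ref.weight.grad, atol=1e-5)
opt.step()        # a single optimiser step for the whole effective batch
opt.zero_grad()
print(grads_match)
```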
3.2 Gradient Checkpointing
ALWAYS DO THIS!
Gradient Checkpointing trades compute for memory by only storing a subset of intermediate activations during the forward pass. During backpropagation, the missing activations are recomputed as needed, which increases training time but reduces memory usage. For an $n$-layer network, storing a checkpoint every $\sqrt{n}$ layers scales activation memory down to $O(\sqrt{n})$, at the cost of roughly one extra forward pass.
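A minimal sketch with `torch.utils.checkpoint` (the toy layer stack is illustrative); the checkpointed forward produces the same values as the plain one, it just stores fewer activations and recomputes them on backward:

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layers = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(8))

def forward(x, use_checkpoint=False):
    for layer in layers:
        if use_checkpoint:
            # Only the segment input is stored; activations inside the
            # segment are recomputed during the backward pass.
            x = checkpoint(lambda t, l=layer: torch.relu(l(t)), x,
                           use_reentrant=False)
        else:
            x = torch.relu(layer(x))
    return x

x = torch.randn(4, 64, requires_grad=True)
out_plain = forward(x).sum()
out_ckpt = forward(x, use_checkpoint=True).sum()
print(torch.allclose(out_plain, out_ckpt))  # same result, less stored memory
```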
3.3 Mixture of Experts (MoE)
MoE allows increasing parameter count without a proportional increase in computational cost. It consists of multiple "expert" sub-networks, but only a subset of these experts are activated for each input, based on a gating mechanism. A router determines which expert to use, by finding the similarity between the input and the expert centroids (learned during training).
Each token is sent to the top-$k$ experts (commonly $k = 1$ or $2$); their outputs are combined, weighted by the router's gating scores.
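A dense top-$k$ routing sketch (function and variable names are ours; real MoE implementations dispatch tokens so that only the selected experts run):

```python
import torch

def moe_layer(x, experts, w_gate, k=2):
    """Top-k mixture-of-experts: route each token to its k highest-scoring
    experts and combine their outputs, weighted by the renormalised gate
    scores. Written densely for clarity."""
    scores = torch.softmax(x @ w_gate, dim=-1)        # (tokens, n_experts)
    top_w, top_idx = scores.topk(k, dim=-1)           # router's top-k choices
    top_w = top_w / top_w.sum(dim=-1, keepdim=True)   # renormalise over top-k
    out = torch.zeros_like(x)
    for slot in range(k):
        for e, expert in enumerate(experts):
            mask = top_idx[:, slot] == e              # tokens routed to expert e
            if mask.any():
                out[mask] += top_w[mask, slot, None] * expert(x[mask])
    return out

d, n_experts = 16, 4
experts = [torch.nn.Linear(d, d) for _ in range(n_experts)]
w_gate = torch.randn(d, n_experts)   # expert "centroids"; learned in practice
y = moe_layer(torch.randn(10, d), experts, w_gate)
```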
3.4 Finetuning
After pretraining we may want to adapt the model to a specific downstream task.
Low Rank Adaptation (LoRA) is a finetuning method that adds low-rank matrices to the original model's weights, allowing efficient adaptation with far fewer trainable parameters. For a weight matrix $W_0 \in \mathbb{R}^{d \times k}$ in a pretrained model, LoRA freezes $W_0$ and learns an update $\Delta W = BA$ with $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and rank $r \ll \min(d, k)$, so the forward pass becomes $h = W_0 x + \frac{\alpha}{r} BAx$.
QLoRA further reduces memory by quantising the pretrained model's weights to 4-bit precision (NF4), while the LoRA adapters and the computation remain in higher precision.
Quantisation maps a continuous range of values to a finite set of discrete values.
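A minimal LoRA wrapper around a linear layer (the class name and initialisation constants are illustrative). With $B$ initialised to zero, the adapted layer starts out identical to the pretrained one:

```python
import torch

class LoRALinear(torch.nn.Module):
    """h = W0 x + (alpha/r) * B A x, with W0 frozen; only A and B train."""
    def __init__(self, base: torch.nn.Linear, r=8, alpha=16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)   # freeze pretrained weights
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        d_out, d_in = base.weight.shape
        self.A = torch.nn.Parameter(torch.randn(r, d_in) * 0.01)
        self.B = torch.nn.Parameter(torch.zeros(d_out, r))  # zero init
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

base = torch.nn.Linear(64, 64)
lora = LoRALinear(base)
x = torch.randn(2, 64)
# B = 0 at init, so the adapted layer reproduces the base layer exactly.
print(torch.allclose(lora(x), base(x)))
```

Only `A` and `B` receive gradients, so the trainable parameter count drops from $d \cdot k$ to $r(d + k)$ per adapted matrix.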
3.5 Mixed Precision
Mixed Precision Training uses lower-precision data types (like FP16 or BF16) for certain parts of the training process, while keeping critical operations in higher precision (like FP32). This can significantly reduce memory usage and increase computational speed without sacrificing model performance. The key is to maintain numerical stability by using higher precision for operations that are sensitive to rounding errors, such as gradient accumulation and weight updates.
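A sketch of the pattern using CPU BF16 autocast (on CUDA one would typically use FP16 with `torch.amp.GradScaler` to avoid gradient underflow; BF16 usually needs no scaler). The forward pass runs in low precision while the weights and the update stay in FP32:

```python
import torch

model = torch.nn.Linear(32, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 32), torch.randn(8, 1)

# Forward pass and loss in low precision; the master weights stay FP32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()   # gradients land in FP32 alongside the FP32 weights
opt.step()        # weight update performed in full precision
print(model.weight.dtype)
```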
4. Inference Efficiency
4.1 KV Caching
Autoregressive language models generate text one token at a time. A naive implementation recomputes the full attention over all previous tokens at each generation step, leading to massive redundant computation. The key-value (KV) cache stores the key and value projections for all previously generated tokens, so that at each new generation step only the key and value for the new token need to be computed. The insight comes from the causal attention mask: when generating token $t$, its query attends only to tokens $1, \dots, t$, and the keys and values of earlier tokens never change once computed, so they can be cached and reused.
For a model with $L$ layers, $h$ heads of dimension $d_h$, batch size $b$, sequence length $s$, and $p$ bytes per value, the KV cache occupies $2 \cdot L \cdot b \cdot s \cdot h \cdot d_h \cdot p$ bytes (the factor of 2 covers keys and values).
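A single-head NumPy sketch comparing the naive and cached decoders (toy dimensions, random weights); both produce identical outputs, but the cached version computes each token's key and value exactly once:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d = 8
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
tokens = rng.standard_normal((5, d))   # embeddings of 5 generated tokens

# Naive: at step t, recompute K and V for every token seen so far.
def step_naive(t):
    q = tokens[t] @ Wq
    K = tokens[:t + 1] @ Wk
    V = tokens[:t + 1] @ Wv
    return softmax(q @ K.T / np.sqrt(d)) @ V

# Cached: compute each token's K and V once, reuse them at later steps.
K_cache, V_cache, outs_cached = [], [], []
for t in range(5):
    K_cache.append(tokens[t] @ Wk)     # only the new token's key...
    V_cache.append(tokens[t] @ Wv)     # ...and value are computed
    q = tokens[t] @ Wq
    K, V = np.stack(K_cache), np.stack(V_cache)
    outs_cached.append(softmax(q @ K.T / np.sqrt(d)) @ V)

same = all(np.allclose(step_naive(t), outs_cached[t]) for t in range(5))
print(same)  # caching changes the cost, not the result
```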