Efficient Deep Learning

1. Memory Requirements

The memory required to load a model with $N$ parameters stored at $b$ bits per parameter is $N \cdot b / 8$ bytes. This is the memory needed to hold the weights for inference.
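A quick sanity check of the weight-memory rule (parameters times bits, divided by 8). This is a minimal sketch: the function name `model_memory_bytes` and the 7-billion-parameter example are illustrative assumptions, not taken from these notes.

```python
def model_memory_bytes(num_params: int, bits_per_param: int) -> float:
    """Bytes needed to hold the weights: parameters * bits / 8."""
    return num_params * bits_per_param / 8


# Hypothetical 7-billion-parameter model at three common precisions.
for bits in (32, 16, 4):
    gib = model_memory_bytes(7_000_000_000, bits) / 2**30
    print(f"{bits:>2}-bit weights: {gib:.1f} GiB")
```

Halving the bits per parameter halves the footprint, which is why quantised weights matter on limited hardware.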

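For training, we need:

$$M_{\text{train}} = M_{\text{weights}} + M_{\text{gradients}} + M_{\text{optimiser}} + M_{\text{activations}}$$

To find $M_{\text{activations}}$, find all the tensors stored for the backward pass (usually the inputs to operations). For each tensor with batch size $B$, sequence length $S$ and dimension $D$, the number of elements is $B \cdot S \cdot D$. Then $M_{\text{activations}} = \frac{b}{8} \sum_{\text{tensors}} B \cdot S \cdot D$ bytes, where $b$ is the number of bits per element.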
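The training budget can be sketched as weights plus gradients plus optimiser state plus stored activations. The code below is an estimate under stated assumptions, not a definitive accounting: it assumes an Adam-style optimiser with two FP32 states per parameter, FP32 weights and gradients, and FP16 activations. The function names and the example shapes are hypothetical.

```python
def activation_bytes(tensor_shapes, bits=16):
    """Sum the elements of every tensor saved for backward, then convert bits to bytes."""
    total_elements = 0
    for shape in tensor_shapes:
        count = 1
        for dim in shape:
            count *= dim
        total_elements += count
    return total_elements * bits / 8


def training_memory_bytes(num_params, tensor_shapes, weight_bits=32, grad_bits=32,
                          optimizer_states=2, optimizer_bits=32, activation_bits=16):
    """Weights + gradients + optimiser state + activations, all in bytes."""
    weights = num_params * weight_bits / 8
    grads = num_params * grad_bits / 8
    # Assumption: Adam-style optimiser keeping two extra values per parameter.
    optim = num_params * optimizer_states * optimizer_bits / 8
    acts = activation_bytes(tensor_shapes, activation_bits)
    return weights + grads + optim + acts


# Hypothetical example: ten saved tensors of shape (B, S, D) = (8, 1024, 768).
saved = [(8, 1024, 768)] * 10
print(training_memory_bytes(125_000_000, saved) / 2**30, "GiB")
```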

2. FLOPs

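For a matrix multiplication $AB$ with $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$: $\text{FLOPs} = 2mnp$ (one multiplication and one addition per term of each inner product).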

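For a transformer layer with batch size $B$, sequence length $S$, model dimension $D$, $H$ attention heads of dimension $d_h = D/H$, and feed-forward dimension $D_{\text{ff}}$:

  1. Layer Norm: computing the mean, variance, normalisation, scale, and shift involves several element-wise passes over the input, $O(BSD)$ FLOPs.
  2. QKV projections ($XW_Q$, $XW_K$, $XW_V$): three matrix multiplications of $2BSD^2$ FLOPs each, $6BSD^2$ FLOPs in total.
  3. Attention scores ($QK^\top$): per head, $2BS^2 d_h$; across heads $2BS^2 D$ FLOPs (since $H d_h = D$). The scaling by $1/\sqrt{d_h}$ adds $BHS^2$ element-wise operations (negligible).
  4. Softmax: applied to each row of the score matrix for each head. Each row requires $S$ exponentiations, $S$ additions, and $S$ divisions, $O(BHS^2)$ FLOPs in total (negligible).
  5. Attention values ($AV$): per head, $2BS^2 d_h$; across heads $2BS^2 D$ FLOPs.
  6. Output projection ($W_O$): $2BSD^2$ FLOPs.
  7. First linear ($D \to D_{\text{ff}}$): $2BSD\,D_{\text{ff}}$ FLOPs.
  8. GeLU activation: element-wise over $BSD_{\text{ff}}$ elements (negligible).
  9. Second linear ($D_{\text{ff}} \to D$): $2BSD\,D_{\text{ff}}$ FLOPs.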
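The steps above can be tallied in a few lines. This is a rough sketch that counts only the matrix multiplications and ignores the negligible element-wise terms. It assumes the usual $2mnp$ cost for an $(m \times n)(n \times p)$ product and, when no feed-forward width is given, a hidden width of $4D$; both the function names and that default are assumptions rather than something these notes fix.

```python
def matmul_flops(m, n, p):
    """(m x n) @ (n x p): one multiply and one add per inner-product term."""
    return 2 * m * n * p


def transformer_layer_flops(B, S, D, d_ff=None):
    """Forward-pass matmul FLOPs for one transformer layer."""
    d_ff = 4 * D if d_ff is None else d_ff   # assumed default width
    tokens = B * S
    flops = {
        "qkv": 3 * matmul_flops(tokens, D, D),
        # Summed over heads: H * d_head = D, so the head count cancels out.
        "scores": B * matmul_flops(S, D, S),
        "values": B * matmul_flops(S, S, D),
        "out_proj": matmul_flops(tokens, D, D),
        "ffn": matmul_flops(tokens, D, d_ff) + matmul_flops(tokens, d_ff, D),
    }
    flops["total"] = sum(flops.values())
    return flops


# Hypothetical sizes: batch 8, sequence 1024, model dimension 768.
for name, value in transformer_layer_flops(8, 1024, 768).items():
    print(f"{name:>8}: {value:.3e}")
```

With the $4D$ assumption the total is $24BSD^2 + 4BS^2D$, so the attention terms only dominate once $S$ grows well beyond $D$.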

3. Training Efficiency

3.1 Gradient Accumulation

ALWAYS DO THIS!

Memory requirements scale linearly with batch size. A small batch size will lead to unstable training, as the gradient estimates will be noisy.

Gradient accumulation computes gradients over several smaller micro-batches in turn and sums them before performing an update, so the optimiser steps only once per effective batch.
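A minimal sketch of the idea on a one-parameter model $y = wx$ with mean-squared-error loss. The function names and the toy data are hypothetical. The point it illustrates is that averaging micro-batch gradients reproduces the full-batch gradient, so only peak memory changes, not the update.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_step(w, xs, ys, micro_batch_size, lr):
    """One optimiser step built from several micro-batch gradients."""
    num_micro = len(xs) // micro_batch_size   # assumes the batch splits exactly
    grad = 0.0
    for i in range(num_micro):
        lo = i * micro_batch_size
        hi = lo + micro_batch_size
        # Divide by num_micro so the sum equals the full-batch mean gradient.
        grad += grad_mse(w, xs[lo:hi], ys[lo:hi]) / num_micro
    return w - lr * grad   # the single optimiser update


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
full_batch = 0.5 - 0.01 * grad_mse(0.5, xs, ys)
print(full_batch, accumulated_step(0.5, xs, ys, micro_batch_size=2, lr=0.01))
```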

3.2 Gradient Checkpointing

ALWAYS DO THIS!

Gradient Checkpointing trades compute for memory by only storing a subset of intermediate activations during the forward pass. During backpropagation, the missing activations are recomputed as needed, which increases training time but reduces memory usage. With checkpoints placed every $\sqrt{n}$ layers of an $n$-layer network, activation memory scales down to $O(\sqrt{n})$, at the cost of roughly one additional forward pass per training step.
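A toy illustration of the recomputation, using a chain of scalar layers $h_{i+1} = \tanh(w_i h_i)$. Everything here (the layer choice, the function names, the `segment` parameter) is an assumption made for illustration. Real frameworks checkpoint whole tensors, but the bookkeeping is the same: keep one activation per segment and rebuild the rest during the backward pass.

```python
import math


def layer(w, h):
    return math.tanh(w * h)


def layer_grad(w, h_in):
    """d layer / d h_in; it needs the layer's input activation."""
    return w * (1.0 - math.tanh(w * h_in) ** 2)


def grad_full(ws, x):
    """Baseline: store every activation (memory grows with the number of layers)."""
    acts = [x]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g = 1.0
    for w, h_in in zip(reversed(ws), reversed(acts[:-1])):
        g *= layer_grad(w, h_in)
    return g


def grad_checkpointed(ws, x, segment):
    """Store only the input of every `segment`-th layer and recompute the rest."""
    checkpoints = {}
    h = x
    for i, w in enumerate(ws):
        if i % segment == 0:
            checkpoints[i] = h          # input to layer i
        h = layer(w, h)
    g = 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(ws))
        # Recompute this segment's layer inputs from its checkpoint.
        acts = [checkpoints[start]]
        for w in ws[start:end - 1]:
            acts.append(layer(w, acts[-1]))
        for w, h_in in zip(reversed(ws[start:end]), reversed(acts)):
            g *= layer_grad(w, h_in)
    return g


weights = [0.5, -1.2, 0.8, 1.1, -0.7, 0.9, 1.3]
print(grad_full(weights, 0.3), grad_checkpointed(weights, 0.3, segment=3))
```

Both functions return the same gradient; the checkpointed version holds at most one segment of activations at a time, plus the checkpoints themselves.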

3.3 Mixture of Experts (MoE)

MoE allows increasing parameter count without a proportional increase in computational cost. It consists of multiple "expert" sub-networks, but only a subset of these experts is activated for each input, based on a gating mechanism. A router determines which experts to use by computing the similarity between the input and the expert centroids (learned during training).

Each token is sent to the top-$k$ experts, where $k$ is a hyperparameter. Routing collapse can happen when the router consistently selects the same expert(s) for all inputs, leading to underutilisation of the model's capacity. To mitigate this: use token dropping (randomly drop some tokens during training to encourage diversity in expert selection) and auxiliary loss (add a loss term that encourages the router to distribute tokens more evenly across experts). A bias term (added to the routing scores to encourage exploration of less frequently selected experts) can also be used.
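The routing step can be sketched as follows, assuming dot-product similarity against the centroids and a softmax over the scores. Renormalising the gate weights over the selected experts is one common choice, not something these notes specify, and all names are hypothetical. The `expert_load` helper reports the per-expert share of assignments, which is the quantity an auxiliary balancing loss would try to even out.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(token, centroids, k):
    """Return [(expert_index, gate_weight)] for the k most similar experts."""
    scores = [sum(t * c for t, c in zip(token, centroid)) for centroid in centroids]
    probs = softmax(scores)
    top = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:k]
    norm = sum(probs[e] for e in top)
    # Assumption: gate weights are renormalised over the chosen experts.
    return [(e, probs[e] / norm) for e in top]


def expert_load(tokens, centroids, k):
    """Fraction of routed assignments that each expert receives."""
    counts = [0] * len(centroids)
    for token in tokens:
        for e, _ in route_top_k(token, centroids, k):
            counts[e] += 1
    total = sum(counts)
    return [c / total for c in counts]


centroids = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
tokens = [[2.0, 1.0], [-3.0, 0.5]]
print(route_top_k(tokens[0], centroids, k=2))
print(expert_load(tokens, centroids, k=2))
```

A load vector concentrated on one or two experts is the signature of routing collapse.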

3.4 Finetuning

After pretraining we may want to adapt the model to a specific downstream task.

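Low Rank Adaptation (LoRA) is a finetuning method that adds low-rank matrices to the original model's weights, allowing for efficient adaptation with fewer trainable parameters. For a weight matrix $W_0 \in \mathbb{R}^{d \times k}$ in a pretrained model, we learn $W = W_0 + \Delta W$, where $\Delta W = BA$, with $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ and $r \ll \min(d, k)$. Only $A$ and $B$ are trained while $W_0$ stays frozen, which significantly reduces memory during finetuning.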
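To make the saving concrete, here is a small sketch of the parameter count and the forward pass with a frozen weight plus a rank-$r$ update, $h = W_0 x + B(Ax)$. The matrix shapes, the helper names and the 4096-wide example are illustrative assumptions.

```python
def lora_param_counts(d, k, r):
    """Trainable parameters: full finetuning (d*k) versus LoRA (r*(d+k))."""
    return d * k, r * (d + k)


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(W0, B, A, x):
    """h = W0 x + B (A x); W0 is frozen, only B and A would be trained."""
    base = matvec(W0, x)
    update = matvec(B, matvec(A, x))
    return [b + u for b, u in zip(base, update)]


full, lora = lora_param_counts(4096, 4096, 8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.4%}")

W0 = [[1.0, 0.0], [0.0, 1.0]]   # frozen pretrained weight (d = k = 2)
B = [[1.0], [2.0]]              # d x r with r = 1
A = [[1.0, 1.0]]                # r x k
print(lora_forward(W0, B, A, [3.0, 4.0]))
```

Computing $B(Ax)$ instead of forming $BA$ keeps the extra cost proportional to $r(d + k)$ rather than $dk$.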

QLoRA further reduces memory by quantising the pretrained model's weights to 4 bits, while keeping the LoRA parameters in higher precision. This allows for efficient finetuning of large models on limited hardware.

Quantisation maps a continuous range of values to a finite set of discrete values.
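As a minimal sketch of such a mapping, the code below uses symmetric absmax quantisation to signed integers. This is a deliberately simple scheme chosen for illustration; it is not the data type QLoRA uses, and the function names are hypothetical.

```python
def quantise_absmax(values, bits=8):
    """Map floats to signed integer codes in [-qmax, qmax] with one shared scale."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / qmax
    if scale == 0:
        scale = 1.0                      # all-zero input: any scale works
    codes = [round(v / scale) for v in values]
    return codes, scale


def dequantise(codes, scale):
    """Recover approximate floats from the integer codes."""
    return [c * scale for c in codes]


weights = [-1.0, 0.0, 0.5, 1.0, 0.013]
codes, scale = quantise_absmax(weights, bits=8)
restored = dequantise(codes, scale)
for w, c, r in zip(weights, codes, restored):
    print(f"{w:+.4f} -> code {c:+4d} -> {r:+.4f}")
```

The rounding error per value is at most half the scale, so fewer bits (a larger scale) means a coarser reconstruction.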

3.5 Mixed Precision

Mixed Precision Training uses lower-precision data types (like FP16 or BF16) for certain parts of the training process, while keeping critical operations in higher precision (like FP32). This can significantly reduce memory usage and increase computational speed without sacrificing model performance. The key is to maintain numerical stability by using higher precision for operations that are sensitive to rounding errors, such as gradient accumulation and weight updates.
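The sensitivity of weight updates to rounding can be seen with a few lines of standard-library Python. The sketch below rounds a value to IEEE half precision via `struct`'s `"e"` format and applies many small updates to a weight near 1.0. The function names and the update size are illustrative assumptions.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def apply_updates(weight, update, steps, low_precision):
    """Add `update` to `weight` repeatedly, optionally storing the weight in FP16."""
    for _ in range(steps):
        weight = weight + update
        if low_precision:
            weight = to_fp16(weight)     # weight is stored in FP16 after every step
    return weight


# Near 1.0, neighbouring FP16 values are about 0.001 apart, so an update of
# 0.0001 is rounded away every single time.
print("FP16 master weight:  ", apply_updates(1.0, 1e-4, 1000, low_precision=True))
print("higher-precision copy:", apply_updates(1.0, 1e-4, 1000, low_precision=False))
```

The FP16 weight never moves from 1.0, while the higher-precision copy ends near 1.1, which is why a higher-precision master copy of the weights is kept for the update step.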

4. Inference Efficiency

4.1 KV Caching

Autoregressive language models generate text one token at a time. A naive implementation recomputes the full attention over all previous tokens at each generation step, leading to massive redundant computation. The key-value (KV) cache stores the key and value projections for all previously generated tokens, so that at each new generation step only the key and value for the new token need to be computed. The insight comes from the causal attention mask: when generating token $t$, the keys and values for tokens $1, \dots, t-1$ do not depend on this new token, so caching avoids recomputing them.
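A small single-head sketch of the difference, in plain Python. The projection matrices, token vectors and function names are all made up for illustration; a real model would batch this across heads and layers. Both decoders produce the same outputs, but the cached one projects each token into key and value space exactly once.

```python
import math


def project(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Single-head attention of one query over the given keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def decode_naive(tokens, Wq, Wk, Wv):
    """Recompute every key and value at every step."""
    outputs = []
    for t in range(1, len(tokens) + 1):
        keys = [project(Wk, x) for x in tokens[:t]]
        values = [project(Wv, x) for x in tokens[:t]]
        outputs.append(attend(project(Wq, tokens[t - 1]), keys, values))
    return outputs


def decode_cached(tokens, Wq, Wk, Wv):
    """Project only the new token and append its key and value to the cache."""
    k_cache, v_cache, outputs = [], [], []
    for x in tokens:
        k_cache.append(project(Wk, x))
        v_cache.append(project(Wv, x))
        outputs.append(attend(project(Wq, x), k_cache, v_cache))
    return outputs


Wq = [[0.2, -0.1], [0.4, 0.3]]
Wk = [[0.5, 0.1], [-0.3, 0.2]]
Wv = [[1.0, 0.0], [0.5, -1.0]]
tokens = [[1.0, 0.0], [0.5, 0.5], [-1.0, 2.0]]
print(decode_naive(tokens, Wq, Wk, Wv) == decode_cached(tokens, Wq, Wk, Wv))
```

The naive loop performs a number of key and value projections that grows quadratically with sequence length; the cached loop performs one pair per token.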

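For a model with $L$ layers, batch size $B$, sequence length $S$, key-value head dimension $d_{kv}$, number of KV heads $n_{kv}$, stored at $b$ bits per element, the KV cache requires $2 \cdot L \cdot B \cdot S \cdot n_{kv} \cdot d_{kv} \cdot b / 8$ bytes, where the factor of 2 accounts for storing both keys and values.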
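The cache size can be estimated as one key and one value vector per layer, per sequence position, per KV head. The sketch below encodes that count; the function name and the example configuration are hypothetical.

```python
def kv_cache_bytes(num_layers, batch_size, seq_len, num_kv_heads, head_dim, bits=16):
    """2 (keys and values) * layers * batch * sequence * KV heads * head dim * bits / 8."""
    elements = 2 * num_layers * batch_size * seq_len * num_kv_heads * head_dim
    return elements * bits / 8


# Hypothetical configuration: 32 layers, 32 KV heads of dimension 128, FP16 cache.
for seq_len in (1024, 4096, 16384):
    gib = kv_cache_bytes(32, 1, seq_len, 32, 128, bits=16) / 2**30
    print(f"sequence length {seq_len:>5}: {gib:.2f} GiB")
```

The cache grows linearly in both batch size and sequence length, so reducing the number of KV heads shrinks it by the same factor.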
