Mathematics of Transformers
1. Attention Mechanism
The self-attention mechanism is the cornerstone of transformers, allowing the model to weigh the importance of different tokens in a sequence:
Scaled Dot-Product Attention:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
Where:
- \( Q \): Query matrix
- \( K \): Key matrix
- \( V \): Value matrix
- \( d_k \): Dimensionality of the keys
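A minimal NumPy sketch of this formula (the function names, array shapes, and the numerically stabilized `softmax` helper are illustrative assumptions, not a reference implementation):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    # Assumed shapes: Q, K are (seq_len, d_k); V is (seq_len, d_v).
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (seq_len, seq_len) similarity scores
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V                   # weighted sum of value vectors
```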
To improve representation learning, multi-head attention computes multiple attention outputs from different subspaces:
\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O \]
Where each head applies attention to learned linear projections of the inputs:
\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]
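Continuing the NumPy sketch (this reuses `scaled_dot_product_attention` from above; the parameter names `W_Q`, `W_K`, `W_V`, `W_O` are illustrative assumptions):

```python
import numpy as np

def multi_head_attention(Q, K, V, W_Q, W_K, W_V, W_O):
    # W_Q, W_K, W_V: one projection matrix per head (lists of arrays).
    # W_O: output projection applied to the concatenated head outputs.
    heads = [
        scaled_dot_product_attention(Q @ W_Q[i], K @ W_K[i], V @ W_V[i])
        for i in range(len(W_Q))
    ]
    return np.concatenate(heads, axis=-1) @ W_O
```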
2. Positional Encoding
Transformers incorporate positional encoding to account for sequence order:
\[ \text{PE}_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{\frac{2i}{d}}}\right), \quad \text{PE}_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{\frac{2i}{d}}}\right) \]
Where:
- \( pos \): Position index
- \( i \): Dimension index
- \( d \): Embedding size
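A sketch of how these sinusoidal encodings can be generated with NumPy (the function name is an illustrative assumption, and an even embedding size \( d \) is assumed):

```python
import numpy as np

def sinusoidal_positional_encoding(seq_len, d):
    # PE[pos, 2i]   = sin(pos / 10000^(2i/d))
    # PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
    # Assumes d is even.
    pos = np.arange(seq_len)[:, None]              # (seq_len, 1)
    i = np.arange(d // 2)[None, :]                 # (1, d/2)
    angles = pos / np.power(10000.0, 2 * i / d)    # (seq_len, d/2)
    pe = np.zeros((seq_len, d))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe
```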
3. Feedforward Networks
Transformers use position-wise feedforward networks (FFN), applied identically and independently at every position, for additional processing:
\[ \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 \]
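A direct NumPy translation of this formula (the weight shapes noted in the comments are illustrative assumptions):

```python
import numpy as np

def feed_forward(x, W1, b1, W2, b2):
    # x: (seq_len, d_model); W1: (d_model, d_ff); W2: (d_ff, d_model).
    hidden = np.maximum(0.0, x @ W1 + b1)  # ReLU nonlinearity
    return hidden @ W2 + b2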
4. Layer Normalization
Layer normalization helps stabilize training by normalizing activations across the feature dimension:
\[ \text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \]
Where:
- \( \mu \): Mean of \( x \)
- \( \sigma^2 \): Variance of \( x \)
- \( \epsilon \): Small constant for numerical stability
- \( \gamma, \beta \): Learnable parameters
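A minimal NumPy sketch of this normalization (the default `eps` value is an illustrative assumption):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each position over its feature dimension, then rescale and shift.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * gamma + beta
```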
5. Optimization
Transformers are optimized with methods like the Adam optimizer and learning rate scheduling:
Learning Rate Scheduling:
\[ \text{lr} = d^{-0.5} \cdot \min\left(\text{step}^{-0.5}, \text{step} \cdot \text{warmup\_steps}^{-1.5}\right) \]
Where:
- \( d \): Model (embedding) dimensionality
- \( \text{step} \): Current training step
- \( \text{warmup\_steps} \): Number of steps in the linear warmup phase
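The schedule warms up linearly and then decays with the inverse square root of the step. A small sketch (the default of 4000 warmup steps follows the original Transformer paper; the function name is an assumption):

```python
def transformer_lr(step, d_model, warmup_steps=4000):
    # Linear warmup for warmup_steps, then inverse-square-root decay.
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```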
6. Tokenization and Embedding
Input sequences are tokenized and converted to dense vectors using an embedding matrix:
\[ \text{Embedding}(x) = W_e \cdot x \]
Where:
- \( x \): One-hot representation of the input token
- \( W_e \): Learned embedding matrix
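In practice the one-hot product is implemented as a row lookup. A sketch (the layout of `W_e` as `(vocab_size, d_model)` is an illustrative assumption):

```python
import numpy as np

def embed(token_ids, W_e):
    # W_e: (vocab_size, d_model). Indexing a row is equivalent to multiplying
    # the embedding matrix by a one-hot vector for each token.
    return W_e[token_ids]
```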
7. Loss Function
For tasks like language modeling, transformers optimize a cross-entropy loss function:
\[ \mathcal{L} = -\sum_{i=1}^{N} y_i \log(\hat{y}_i) \]
Where:
- \( N \): Number of classes (e.g., vocabulary size)
- \( y_i \): True probability of class \( i \)
- \( \hat{y}_i \): Predicted probability of class \( i \)
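A direct NumPy translation of this loss for a single prediction (the small `eps` guard against \( \log(0) \) is an added assumption):

```python
import numpy as np

def cross_entropy(y_true, y_pred, eps=1e-12):
    # y_true: true distribution (often one-hot); y_pred: predicted probabilities.
    return -np.sum(y_true * np.log(y_pred + eps))
```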
8. Computational Complexity
Self-attention has a computational complexity of \( O(n^2 d) \), where \( n \) is the sequence length and \( d \) the embedding dimensionality, so cost scales quadratically with sequence length. Optimizations such as sparse attention reduce this complexity.
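A tiny illustration of the quadratic term: the score matrix \( QK^T \) alone holds \( n^2 \) entries per attention head.

```python
# The attention score matrix is n x n per head, so the number of pairwise
# scores grows quadratically with sequence length n.
for n in (512, 1024, 2048):
    print(f"{n} tokens -> {n * n:,} pairwise attention scores")
```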
This work is licensed under a Creative Commons Attribution 4.0 International License.
