Attention Head Pruning
Layer-wise Pruning of Transformer Attention Heads for Efficient Language Modeling
Introduction
Attention head pruning is an effective method for reducing the computational cost of Transformer models without significantly impacting their performance. This paper introduces a trainable, layer-wise attention head pruning approach, where the model learns which heads to remove during training using a gating mechanism and L₀ sparsity loss. By gradually eliminating less important heads, the method achieves efficient model compression while maintaining accuracy.
This pruning method can be characterized along the following dimensions:
- Granularity: The approach operates at the attention head level, making it a structured pruning technique that removes entire heads rather than individual weights or neurons. This ensures efficient execution on modern hardware.
- Pruning Criteria: The decision to prune an attention head is based on a trainable gating mechanism that evaluates its contribution to the model’s output. An L₀ sparsity loss encourages the removal of heads that minimally impact performance, making the process data-driven rather than heuristic (a sketch of one common way to make such a penalty trainable follows this list).
- Pruning Ratio: The pruning ratio is learned automatically during training rather than being manually predefined. The model determines the optimal number of heads to remove based on performance preservation and sparsity constraints, ensuring a balance between efficiency and accuracy.
- Need for Retraining: Unlike traditional pruning methods that require full retraining after pruning, this approach does not need retraining from scratch. Instead, it includes a short fine-tuning phase after pruning to allow the remaining heads to adapt, leading to a more stable and practical pruning framework.
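The L₀ penalty on the gates is not differentiable on its own, so in practice it is optimized through a continuous relaxation. The sketch below shows one common choice, the hard-concrete gate of Louizos et al. (2018); this is an illustrative assumption about how the gating and sparsity term could be implemented, not necessarily the paper’s exact parameterization.

```python
import math
import torch
import torch.nn as nn

class HardConcreteGate(nn.Module):
    """One stochastic gate per attention head, using the hard-concrete
    relaxation (Louizos et al., 2018). Illustrative sketch: the paper's
    exact gate parameterization and hyperparameters may differ."""

    def __init__(self, num_heads, beta=2 / 3, gamma=-0.1, zeta=1.1):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.zeros(num_heads))  # trainable per-head logit
        self.beta, self.gamma, self.zeta = beta, gamma, zeta

    def forward(self):
        if self.training:
            # Sample a relaxed Bernoulli variable per head (reparameterized).
            u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)
            s = torch.sigmoid((u.log() - (1 - u).log() + self.log_alpha) / self.beta)
        else:
            s = torch.sigmoid(self.log_alpha)
        # Stretch to (gamma, zeta) and clip to [0, 1] so exact zeros can occur.
        return (s * (self.zeta - self.gamma) + self.gamma).clamp(0.0, 1.0)

    def expected_active_heads(self):
        # Expected number of non-zero gates; adding this term to the task
        # loss acts as the (relaxed) L0 sparsity penalty.
        return torch.sigmoid(self.log_alpha - self.beta * math.log(-self.gamma / self.zeta)).sum()
```

Because the penalty is the expected number of open gates rather than a preset count, the pruning ratio emerges from training instead of being fixed in advance.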
Research Gap
While attention head pruning has been explored in prior work, most existing methods focus on removing heads after pretraining using heuristic importance measures or sensitivity analysis. However, these approaches often require additional retraining steps and do not fully optimize the pruning process during model training. Furthermore, traditional attention head pruning does not proportionally reduce the feedforward module’s computational cost, limiting overall efficiency gains. This paper addresses these gaps by introducing layer-wise, trainable attention head pruning, where the model learns which heads to remove during training, guided by a trainable gating mechanism and L₀ sparsity loss.
Methodology
This paper proposes a method to prune unnecessary attention heads from Transformer models to reduce computation and model size while maintaining accuracy.
The method proceeds in four steps. Each attention head is assigned a trainable gating parameter that learns whether the head is important or can be removed.
Step 1: Attach a Trainable “Switch” to Each Head
- Every attention head gets a gating parameter (a trainable value).
- This parameter decides whether a head stays active or gets pruned (a minimal sketch of such a gated attention module follows this list).
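The following PyTorch sketch attaches one gate per head to a standard self-attention module. The module name, shapes, and the simple scalar gate are illustrative assumptions chosen for clarity (the stochastic gate sketched earlier could be substituted); this is not the paper’s exact implementation.

```python
import torch
import torch.nn as nn

class GatedMultiHeadSelfAttention(nn.Module):
    """Self-attention with one trainable gate ("switch") per head.
    Illustrative sketch, not the paper's exact implementation."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        # One trainable gating parameter per attention head.
        self.head_gate = nn.Parameter(torch.ones(num_heads))

    def forward(self, x):
        B, T, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, d_model) -> (B, num_heads, T, d_head)
        q, k, v = (t.view(B, T, self.num_heads, self.d_head).transpose(1, 2) for t in (q, k, v))
        attn = torch.softmax(q @ k.transpose(-2, -1) / self.d_head ** 0.5, dim=-1)
        heads = attn @ v  # (B, num_heads, T, d_head)
        # Scale each head's output by its gate; a gate near zero means the
        # head contributes little and becomes a pruning candidate.
        heads = heads * self.head_gate.view(1, -1, 1, 1)
        return self.out(heads.transpose(1, 2).reshape(B, T, -1))
```

Because each gate multiplies an entire head’s output, the pruning stays structured: a head either survives as a whole or is removed as a whole.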
Step 2: Train the Model with Gating Mechanism
- The model is trained as usual, but now it also learns which heads are useful.
- An L₀ sparsity loss helps decide which heads to remove.
- The loss is based on the difference between the layer’s output with and without a specific head.
- If removing a head doesn’t significantly change the output, its gating value decreases, making it more likely to be pruned (a simplified training-step sketch follows this list).
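The training-step sketch below shows one way such a joint objective could look: the usual language-modeling loss plus a weighted sparsity term on the gates. `model.expected_active_heads()` and `sparsity_weight` are assumed helpers/hyperparameters for illustration; the paper’s actual pruning loss, which compares the layer output with and without each head, is more involved.

```python
import torch.nn.functional as F

def pruning_training_step(model, batch, optimizer, sparsity_weight=0.01):
    """One training step with a joint objective: task loss + gate sparsity.
    Illustrative sketch; `model.expected_active_heads()` is an assumed
    helper (e.g. summing the HardConcreteGate terms), not the paper's API."""
    inputs, targets = batch
    logits = model(inputs)
    task_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    # Penalize the expected number of open gates. A head whose removal
    # barely changes the output cannot offset this pressure through the
    # task loss, so its gate drifts toward zero.
    loss = task_loss + sparsity_weight * model.expected_active_heads()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```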
Step 3: Permanently Remove Weak Heads
- Once training is done, heads with low gating parameters are fully deleted from the model.
- This reduces both model size and computational cost (a pruning sketch follows this list).
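Once training ends, the low-gate heads can be sliced out of the weight matrices so that pruned heads no longer consume memory or compute. The sketch below does this for the `GatedMultiHeadSelfAttention` module from Step 1; the threshold and the fused-QKV layout are assumptions of this sketch, not details from the paper.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def prune_heads(attn, threshold=0.05):
    """Permanently delete heads of the GatedMultiHeadSelfAttention sketch
    whose gate fell below `threshold`. Illustrative sketch: the threshold
    and fused-QKV layout are assumptions, not details from the paper."""
    d_head, d_model = attn.d_head, attn.qkv.in_features
    keep = (attn.head_gate.abs() > threshold).nonzero(as_tuple=True)[0]
    # Row indices of the kept heads inside one d_model-sized projection block.
    rows = torch.cat([torch.arange(h * d_head, (h + 1) * d_head) for h in keep.tolist()])
    # The fused QKV weight stacks [Q; K; V], so the same rows repeat at
    # offsets 0, d_model, and 2 * d_model.
    qkv_rows = torch.cat([rows + i * d_model for i in range(3)])
    new_qkv = nn.Linear(d_model, 3 * len(rows))
    new_qkv.weight.copy_(attn.qkv.weight[qkv_rows])
    new_qkv.bias.copy_(attn.qkv.bias[qkv_rows])
    # The output projection consumes the concatenated heads, so prune its columns.
    new_out = nn.Linear(len(rows), d_model)
    new_out.weight.copy_(attn.out.weight[:, rows])
    new_out.bias.copy_(attn.out.bias)
    attn.qkv, attn.out = new_qkv, new_out
    attn.num_heads = len(keep)
    attn.head_gate = nn.Parameter(attn.head_gate[keep].clone())
    return attn
```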
Step 4: Fine-Tune the Pruned Model
- Since removing heads slightly changes the model structure, a short fine-tuning phase is needed.
- This helps the remaining heads adjust and recover any lost performance (a short fine-tuning sketch follows this list).
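A short fine-tuning pass along these lines could look as follows; the optimizer, learning rate, and step count are illustrative assumptions, and the sparsity term is dropped because the architecture is now fixed.

```python
import torch
import torch.nn.functional as F

def finetune_pruned_model(model, dataloader, steps=1000, lr=1e-5):
    """Short fine-tuning pass after pruning: only the task loss is used,
    so the remaining heads can re-adapt to the smaller architecture.
    Illustrative sketch; hyperparameters are assumptions, not the paper's."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for _, (inputs, targets) in zip(range(steps), dataloader):
        logits = model(inputs)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model
```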
Conclusion
This paper presents a layer-wise, trainable attention head pruning approach that enables Transformer models to become more efficient while maintaining strong performance. By introducing trainable gating parameters and an L₀ sparsity loss, the model learns which attention heads are unimportant and gradually removes them during training. Unlike traditional pruning techniques that rely on heuristic importance measures or require full retraining, this method ensures automatic, structured pruning with only a short fine-tuning phase. Additionally, by leveraging the All-Attention Transformer, the approach effectively reduces both computation and parameter count, making it more scalable for real-world deployment. Experimental results demonstrate that this pruning method outperforms conventional head pruning techniques in parameter efficiency and inference speed, highlighting its potential for optimizing Transformer-based language models.