Optimizing softmax on GPU
Introduction
Softmax is a fundamental operation in deep learning, particularly in the attention mechanisms of transformer models. While mathematically simple, its efficient implementation on GPU hardware presents interesting challenges and optimization opportunities. In this blog post, we'll explore the mathematical properties of softmax, implement multiple GPU kernels with increasing sophistication, and analyze their performance characteristics. The complete code is available in the GitHub repository.
Softmax: Translation Invariance
For a vector \(\mathbf{x} = [x_1, x_2, ..., x_n]\), the softmax function is defined as:
\[\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}\]
A key property of softmax is its translation invariance:
\[\text{softmax}(\mathbf{x} - c) = \text{softmax}(\mathbf{x})\]
where \(c\) is any constant. This property is crucial for numerical stability, as we can set \(c = \max(\mathbf{x})\) to avoid overflow in the exponential calculations; this is known as "safe softmax."
This property leads to an important insight: softmax can be computed incrementally using a reduction approach. Suppose we partition a vector into two intervals \(I_1\) and \(I_2\).
Let \(m_1 = \max(I_1)\), \(m_2 = \max(I_2)\), and \(m = \max(m_1, m_2)\).
Then for the exponential sums:
\[l = l_1 \cdot e^{m_1 - m} + l_2 \cdot e^{m_2 - m}\]
where \(l_1 = \sum_{i \in I_1} e^{x_i - m_1}\) and \(l_2 = \sum_{i \in I_2} e^{x_i - m_2}\).
This incremental computation forms the theoretical foundation for FlashAttention’s tiling strategy, which reduces HBM (High Bandwidth Memory) accesses by computing softmax in SRAM (static RAM) with minimal data movement.
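To make the merge rule concrete, here is a minimal sketch of combining two partial results \((m_1, l_1)\) and \((m_2, l_2)\); the struct and function names are illustrative and not taken from the repository:

#include <math.h>

// Partial softmax statistics for an interval: the running max and
// the sum of exponentials taken relative to that max.
struct SoftmaxPartial {
    float m;  // max over the interval
    float l;  // sum of exp(x_i - m) over the interval
};

// Merge two partial results into one, rescaling each sum to the new max.
__host__ __device__ inline SoftmaxPartial mergePartials(SoftmaxPartial a,
                                                        SoftmaxPartial b) {
    SoftmaxPartial out;
    out.m = fmaxf(a.m, b.m);
    out.l = a.l * expf(a.m - out.m) + b.l * expf(b.m - out.m);
    return out;
}

Because the merge is associative, partial results can be combined in any order, which is exactly what a tiled or parallel reduction needs.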
Implementation Journey: Five GPU Kernels
We implemented five progressively optimized softmax kernels in CUDA to demonstrate different optimization techniques. Let’s examine each approach:
Kernel 0: The Baseline
// Kernel 0: one thread per row; three sequential passes over the row
__global__ void softmax_basic(const float* input, float* output, int batch_size, int dim) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= batch_size) return;

    const float* row_input = input + row * dim;
    float* row_output = output + row * dim;

    // Pass 1: find the row maximum (for numerical stability)
    float row_max = -INFINITY;
    for (int i = 0; i < dim; i++) {
        row_max = fmaxf(row_max, row_input[i]);
    }

    // Pass 2: compute the sum of exponentials
    float sum_exp = 0.0f;
    for (int i = 0; i < dim; i++) {
        sum_exp += expf(row_input[i] - row_max);
    }

    // Pass 3: normalize
    for (int i = 0; i < dim; i++) {
        row_output[i] = expf(row_input[i] - row_max) / sum_exp;
    }
}
This naive implementation makes three separate passes over each row and recomputes the exponentials during normalization. Memory accesses are also poorly coalesced: since each thread owns an entire row, consecutive threads in a warp read addresses that are dim floats apart, so each access touches scattered cache lines instead of contiguous ones.
Kernel 1: Shared Memory Optimization
Kernel 1 introduces shared memory for parallel reduction within a thread block. Each block processes one row, with threads cooperatively computing the maximum and sum through shared memory reductions.
Key Insight: Shared memory is ~100x faster than global memory, making reduction operations significantly faster.
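The kernel itself is not listed in this post; below is a minimal sketch of the idea, assuming one block per row and a power-of-two block size (names are illustrative, not the repository's implementation):

// Sketch: one block per row; threads cooperate through shared memory.
// Assumes blockDim.x is a power of two (e.g. 256).
__global__ void softmax_shared(const float* input, float* output,
                               int batch_size, int dim) {
    extern __shared__ float sdata[];
    int row = blockIdx.x;
    if (row >= batch_size) return;
    const float* row_input = input + row * dim;
    float* row_output = output + row * dim;
    int tid = threadIdx.x;

    // Each thread scans a strided slice of the row for its local max.
    float local_max = -INFINITY;
    for (int i = tid; i < dim; i += blockDim.x)
        local_max = fmaxf(local_max, row_input[i]);
    sdata[tid] = local_max;
    __syncthreads();

    // Tree reduction in shared memory for the row maximum.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    float row_max = sdata[0];
    __syncthreads();

    // Same pattern for the sum of exponentials.
    float local_sum = 0.0f;
    for (int i = tid; i < dim; i += blockDim.x)
        local_sum += expf(row_input[i] - row_max);
    sdata[tid] = local_sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    float sum_exp = sdata[0];

    // Normalize.
    for (int i = tid; i < dim; i += blockDim.x)
        row_output[i] = expf(row_input[i] - row_max) / sum_exp;
}

A launch along the lines of softmax_shared<<<batch_size, 256, 256 * sizeof(float)>>>(...) gives each row its own block and enough dynamic shared memory for the reductions.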
Kernel 2: Warp-Level Primitives
Kernel 2 leverages CUDA’s warp shuffle instructions (__shfl_sync) for efficient warp-level reductions:
__inline__ __device__ float warpReduceMax(float val) {
    // Shuffle-based tree reduction: each step halves the number of active lanes.
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
    return val;
}
Advantage: Warp shuffle instructions bypass shared memory entirely, using register-to-register communication within a warp.
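The matching sum reduction follows the same pattern; a sketch, assuming the same WARP_SIZE constant of 32:

__inline__ __device__ float warpReduceSum(float val) {
    // Same shuffle pattern as warpReduceMax, accumulating instead of taking the max.
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

After a __shfl_down_sync reduction only lane 0 holds the final value, so it is typically broadcast back to the whole warp with __shfl_sync(0xffffffff, val, 0) before normalization.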
Kernel 3: 2D Block Organization
Kernel 3 organizes threads in a 2D block (32×4), where each warp processes one row independently. This improves occupancy by allowing multiple warps per block while maintaining efficient warp-level reductions.
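The thread layout is the interesting part here; as an illustrative sketch (assuming a 32×4 block and the warp-reduce helpers above, not the repository's exact code):

// Sketch: blockDim = (32, 4); each warp (one value of threadIdx.y) owns one row.
__global__ void softmax_warp_per_row(const float* input, float* output,
                                     int batch_size, int dim) {
    int row = blockIdx.x * blockDim.y + threadIdx.y;  // one row per warp
    int lane = threadIdx.x;                           // lane within the warp
    if (row >= batch_size) return;
    const float* row_input = input + row * dim;
    float* row_output = output + row * dim;

    // Warp-level max reduction over the row.
    float local_max = -INFINITY;
    for (int i = lane; i < dim; i += 32)
        local_max = fmaxf(local_max, row_input[i]);
    float row_max = warpReduceMax(local_max);
    row_max = __shfl_sync(0xffffffff, row_max, 0);  // broadcast lane 0's result

    // Warp-level sum reduction over the row.
    float local_sum = 0.0f;
    for (int i = lane; i < dim; i += 32)
        local_sum += expf(row_input[i] - row_max);
    float sum_exp = warpReduceSum(local_sum);
    sum_exp = __shfl_sync(0xffffffff, sum_exp, 0);

    // Normalize.
    for (int i = lane; i < dim; i += 32)
        row_output[i] = expf(row_input[i] - row_max) / sum_exp;
}

With four warps per block, each block covers four rows and no shared memory or block-wide synchronization is needed.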
Kernel 4: Multi-Row Per Warp with Templates
The most sophisticated kernel uses template parameters to process multiple rows per warp:
template<int ROWS_PER_WARP, int COLS_PER_THREAD>
__global__ void softmax_warp_multi_row(const float* input, float* output,
                                       int batch_size, int dim) {
    // Each warp processes ROWS_PER_WARP rows
    // Each thread processes COLS_PER_THREAD elements per row
    // Uses register caching for maximum performance
}
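The full kernel body lives in the repository; as a purely illustrative example of how the template parameters might be chosen for our benchmark shape (dim = 128), a hypothetical dispatcher could look like this:

// Hypothetical dispatch: with 32 lanes and dim = 128, each thread covers
// 128 / 32 = 4 columns; packing 2 rows per warp amortizes per-row overhead.
void launch_softmax_multi_row(const float* d_in, float* d_out,
                              int batch_size, int dim) {
    constexpr int ROWS_PER_WARP = 2;
    constexpr int COLS_PER_THREAD = 4;   // assumes dim == 128
    constexpr int WARPS_PER_BLOCK = 4;
    dim3 block(32, WARPS_PER_BLOCK);
    int rows_per_block = WARPS_PER_BLOCK * ROWS_PER_WARP;
    int grid = (batch_size + rows_per_block - 1) / rows_per_block;
    softmax_warp_multi_row<ROWS_PER_WARP, COLS_PER_THREAD>
        <<<grid, block>>>(d_in, d_out, batch_size, dim);
}

Compile-time knowledge of ROWS_PER_WARP and COLS_PER_THREAD lets the compiler fully unroll the per-thread loops and keep each thread's slice of the rows in registers.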
Performance Comparison
We benchmarked all kernels on an NVIDIA RTX 4500 Ada GPU with the following configuration:
- Batch size: 32768
- Dimension: 128
- Repetitions: 100
- CUDA version: 12.8
| Kernel | Description | Avg Time (ms) | Throughput | Speedup vs Baseline |
|---|---|---|---|---|
| 0 | Baseline | 0.29562 | 14.2M elements/ms | 1.0× |
| 1 | Shared Memory | 0.09587 | 43.8M elements/ms | 3.1× |
| 2 | Warp Primitives | 0.03016 | 139.1M elements/ms | 9.8× |
| 3 | 2D Block | 0.01770 | 236.9M elements/ms | 16.7× |
| 4 | Multi-row per Warp | 0.01669 | 251.2M elements/ms | 17.7× |
Key Observations
- Warp-level optimizations dominate: Kernels 2-4 show dramatic improvements by leveraging warp shuffle instructions, with Kernel 4 achieving a 17.7× speedup over the baseline.
- Memory hierarchy matters: Kernel 1's 3.1× improvement comes primarily from using shared memory for reductions.
- Algorithm meets architecture: Kernel 4's template-based approach demonstrates how algorithm design must consider hardware constraints (warp size, register count).
Conclusion
Softmax optimization on GPU demonstrates the beautiful interplay between mathematical theory and hardware architecture, from the basic translation invariance property to sophisticated warp-level implementations. The 17.7× speedup from Kernel 0 to Kernel 4 shows how far we can push performance when we deeply understand both the algorithm and the hardware. These techniques form the foundation for more complex optimizations like FlashAttention.