This is a deep dive into ZeRO Stage 1 — the Distributed Optimizer from the ZeRO paper. It covers the fundamentals and a complete toy implementation (~300 lines of core code). The implementation discussed throughout is open-sourced at zero_optim_toy, built on the original work by @xiabingquan.
Prerequisite: familiarity with Distributed Data Parallel (DDP).
The key idea: within a data-parallel group, all \(N\) ranks hold identical optimizer states. ZeRO-1 partitions the optimizer state across ranks so each rank holds only \(1/N\) of it — the fp32 master parameters, Adam’s first moment (\(m\)), and second moment (\(v\)). After a local optimizer step on each shard, ranks exchange results via all-gather to recover the full updated parameters.
For a bf16 model with Adam optimizer and \(P\) parameter elements:
| Memory Item | Vanilla DDP (per rank) | ZeRO-1 (per rank) |
|---|---|---|
| fp32 master params | \(4P\) bytes | \(4P / N\) bytes |
| Adam \(m\) (first moment) | \(4P\) bytes | \(4P / N\) bytes |
| Adam \(v\) (second moment) | \(4P\) bytes | \(4P / N\) bytes |
| Optimizer state subtotal | \(12P\) bytes | \(12P / N\) bytes |
With \(N=8\), optimizer state drops from \(12P\) to \(1.5P\). Model parameters (bf16) and gradient buffers remain at full size on each rank.
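As a quick sanity check on the table, the same arithmetic in Python (illustrative only, not part of the repo):

def adam_state_bytes(P: int, N: int) -> tuple[int, int]:
    # fp32 master + Adam m + Adam v, each 4 bytes per element.
    replicated = 12 * P        # vanilla DDP: every rank holds the full state
    sharded = 12 * P // N      # ZeRO-1: each rank holds only a 1/N shard
    return replicated, sharded

print(adam_state_bytes(P=1_000_000_000, N=8))  # (12000000000, 1500000000), i.e. 12P vs 1.5P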
| Operation | Vanilla DDP | ZeRO-1 |
|---|---|---|
| Gradient sync | all-reduce (\(\sim 2P\)) | reduce-scatter (\(\sim P\)) |
| Parameter sync | none | all-gather (\(\sim P\)) |
| Total | \(\sim 2P\) | \(\sim 2P\) |
ZeRO-1 adds an all-gather round, but total communication volume stays the same — which is why it is effectively standard for large-scale training today.
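The equality is easiest to see by recalling that ring all-reduce is itself implemented as a reduce-scatter followed by an all-gather, each moving roughly \(\frac{N-1}{N} P \approx P\) per rank:

\[
\underbrace{\text{all-reduce}}_{\sim 2P} \;=\; \underbrace{\text{reduce-scatter}}_{\sim P} \;+\; \underbrace{\text{all-gather}}_{\sim P}
\]

ZeRO-1 simply pulls the two halves apart and runs the sharded optimizer step in between.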
Within a data-parallel (or combined dp-cp) group, all ranks have identical model parameters, and therefore identical optimizer states. Instead of storing \(N\) redundant copies, we keep one copy split into \(N\) shards, one per rank.
Vanilla DDP — each rank needs the full reduced gradient to update the full parameters:
backward → gradient all-reduce → optimizer.step()
ZeRO-1 — each rank only needs \(1/N\) of the reduced gradient for its own shard:
backward → gradient reduce-scatter → optimizer.step(shard) → parameter all-gather
reduce-scatter = reduce + scatter in a single collective. Each rank receives only the segment of the reduced gradient that it needs for its optimizer shard.
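In PyTorch terms, the collective looks like this (a minimal sketch assuming torch.distributed is already initialized; not taken from the repo):

import torch
import torch.distributed as dist

def reduce_scatter_grad(full_grad: torch.Tensor, world_size: int) -> torch.Tensor:
    # Every rank passes in its full local gradient and gets back only the
    # SUM-reduced 1/N segment it is responsible for.
    shard = torch.empty(full_grad.numel() // world_size,
                        dtype=full_grad.dtype, device=full_grad.device)
    dist.reduce_scatter_tensor(shard, full_grad, op=dist.ReduceOp.SUM)
    return shard  # same result as all_reduce(full_grad) followed by slicing this rank's segment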
From the optimizer’s perspective, fixing (tp_rank, pp_rank), the ranks across (dp_rank, cp_rank) form a group with identical optimizer states. The communication group for ZeRO-1 is therefore the dp-cp group (or a subset thereof).
- No redundancy (num_optimizer_instances=1): the entire dp-cp group is one optimizer's partitioning domain.
- With redundancy (num_optimizer_instances>1): the dp-cp group is further split into intra groups (each one a full partitioning domain holding a complete copy of the sharded optimizer state) and inter groups (linking ranks at the same shard position across copies). Gradient sync then becomes a two-phase process: reduce-scatter within the intra group, then all-reduce across the inter group. Parameter sync (all-gather) only happens within the intra group (a sketch of the group construction follows).
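A sketch of how the two kinds of process groups could be built (hypothetical helper, not from the toy repo, assuming the dp-cp ranks are laid out so that each intra group is a contiguous block):

import torch.distributed as dist

def build_optimizer_groups(dp_cp_ranks: list[int], num_optimizer_instances: int):
    # Every process must call dist.new_group() for every group, in the same order.
    per = len(dp_cp_ranks) // num_optimizer_instances
    intra = [dist.new_group(dp_cp_ranks[i * per:(i + 1) * per])   # sharding domains
             for i in range(num_optimizer_instances)]
    inter = [dist.new_group(dp_cp_ranks[j::per])                  # same shard position across copies
             for j in range(per)]
    return intra, inter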
The implementation follows a clean three-layer architecture: Buffer → DDP → Distributed Optimizer. Each layer has a distinct responsibility and communicates through clear interfaces.
Responsibility: manage contiguous memory, provide shard-level read/write and communication primitives.
Why a buffer is necessary: model parameters and gradients are scattered across separate allocations in memory. Communicating per-tensor means one NCCL call per parameter — the kernel launch overhead becomes prohibitive. By remapping parameters and gradients into a contiguous buffer, a single collective call covers the entire buffer.
The buffer manages the low-precision (e.g., bf16) parameters and gradients used in forward/backward — not the fp32 master parameters inside the optimizer (those are handled separately by the Distributed Optimizer layer).
Key functions:
- Remaps param.data and param.main_grad to views into the buffer.
- Exposes shard-level primitives: reduce_scatter_grads(), allgather_params(), get_grad_shard(param), write_param_shard(param, data).

Sharding and parameter boundaries: the buffer is evenly partitioned across \(N\) ranks along its contiguous memory — partition boundaries do not respect parameter boundaries, so a single parameter may be split across two ranks. The upside: each shard has exactly equal size, making reduce-scatter / all-gather direct operations on uniform slices without per-parameter padding. The tradeoff: the optimizer must handle partial parameter fragments (see the sketch below).
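Handling those fragments amounts to an interval intersection between a parameter's range in the flat buffer and the rank's shard range, roughly as follows (illustrative helper; the repo tracks this inside the Buffer/DDP layers):

def fragment_in_shard(param_offset: int, param_numel: int,
                      shard_start: int, shard_end: int):
    # Intersect [param_offset, param_offset + param_numel) with [shard_start, shard_end).
    lo = max(param_offset, shard_start)
    hi = min(param_offset + param_numel, shard_end)
    if lo >= hi:
        return None            # no element of this parameter lives on this rank
    return lo - param_offset, hi - param_offset   # offsets within the flattened parameter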
Responsibility: group parameters by type and precision, create buffers for each group, and provide unified gradient/parameter sync interfaces.
Typical grouping example:
| Param Type | Precision | Example |
|---|---|---|
| dense | bf16 | Attention/FFN linear weights |
| MoE | bf16 | Expert FFN linear weights |
| MoE | fp32 | Router weights |
The DDP layer provides: gradient sync (sync_grads()), parameter sync (sync_params()), and shard access for the optimizer (get_param_shard(), get_grad_shard(), write_param_shard()).
Responsibility: manage optimizer state creation and update, coordinate precision conversion between low-precision (bf16) and high-precision (fp32) representations.
This layer:

- maintains the fp32 master copy shard_main_param for this rank's parameter shard;
- upcasts the reduced gradient shard into shard_main_param.grad (fp32);
- runs Adam.step() on the fp32 shard;
- downcasts the result to bf16, writes it back into the param buffer shard, and triggers the all-gather.

Initialization
1. DDP creates buffers
├─ Groups params by (type, precision), calls Buffer layer
├─ Remaps param.data and param.main_grad to buffer views
└─ Computes shard range for each rank
2. Distributed Optimizer initialization
├─ Queries DDP for each parameter's shard info on this rank
├─ Creates fp32 master copy: shard_main_param = buffer_view[start:end].clone().float()
└─ Creates Adam optimizer over the fp32 shard (Adam is unaware of distribution)
Training Step
Forward
└─ Uses full bf16 parameters in param buffer
Backward
└─ autograd writes gradients to param.main_grad (views into grad buffer)
Gradient Sync (DDP)
└─ reduce_scatter_grads(): each rank gets 1/N reduced gradient shard
Optimizer Step (Distributed Optimizer)
├─ 1. Upcast: shard_main_param.grad = grad_shard.float()
├─ 2. Adam.step() on the fp32 shard (local, 1/N size)
├─ 3. Writeback: write_param_shard(param, shard_main_param.to(bf16))
└─ 4. Sync: allgather_params() → full parameters restored
Next Forward
└─ Uses fully updated parameters
The precision flow through one step, from reduced bf16 gradients to updated bf16 parameters:

grad buffer (bf16, full)
│
reduce-scatter
│
grad shard (bf16, 1/N)
│
.float()
│
shard_main_param.grad (fp32, 1/N)
│
Adam.step()
│
shard_main_param (fp32, 1/N)
│
.to(bf16)
│
param buffer shard (bf16, 1/N)
│
all-gather
│
param buffer (bf16, full)
Two further optimizations on top of this flow:

- Overlap: after optimizer.step(), asynchronously launch the all-gather for the first bucket. Before each layer's forward, a pre-hook waits on that layer's all-gather and dispatches the next bucket, forming a pipeline (see the sketch below).
- Bucketing: further subdividing buffers into buckets gives finer-grained overlap and reduces per-collective latency; each bucket communicates independently.
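A minimal sketch of the overlap pattern (hypothetical; the toy repo performs a single synchronous all-gather instead), assuming one bucket per layer of the ToyModel:

import torch.distributed as dist

def install_overlap_hooks(model, buckets):
    # buckets[i] = (full_params_i, shard_i): layer i's slice of the param buffer and its local shard.
    handles = [None] * len(buckets)

    def launch(i):
        handles[i] = dist.all_gather_into_tensor(buckets[i][0], buckets[i][1], async_op=True)

    for i, layer in enumerate(model.layers):
        def pre_hook(module, args, i=i):
            handles[i].wait()                  # this layer's params are now up to date
            if i + 1 < len(buckets):
                launch(i + 1)                  # kick off the next bucket in the background
        layer.register_forward_pre_hook(pre_hook)

    return launch  # call launch(0) right after each optimizer.step()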
Two more engineering details:

- Padding: the buffer is padded to a multiple of dp_world_size (required for even partitioning in reduce-scatter).
- Checkpointing: the challenge is that optimizer state is distributed across ranks. Common strategies are gathering the shards back into the full state before saving, or saving each rank's shard as part of a distributed checkpoint (a sketch of the gather approach follows).
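The gather approach could look like this (illustrative only; not taken from any particular framework). Adam's m and v shards can be gathered the same way:

import torch
import torch.distributed as dist

def gather_full_master_params(shard_fp32: torch.Tensor, world_size: int) -> torch.Tensor:
    # Every rank contributes its 1/N fp32 shard; the concatenation (minus padding)
    # is the full master parameter vector, ready to be saved from rank 0.
    full = torch.empty(shard_fp32.numel() * world_size,
                       dtype=shard_fp32.dtype, device=shard_fp32.device)
    dist.all_gather_into_tensor(full, shard_fp32)
    return full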
When some parameters use FP8 precision for forward/backward, extra quantize/dequantize steps are needed: fp8 → fp32 → optimizer → fp32 → fp8. FP8 and non-FP8 parameters must live in separate buffers.
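A sketch of the requantization step, assuming a PyTorch build that exposes torch.float8_e4m3fn (real FP8 training also carries per-tensor scaling factors, omitted here):

import torch

def write_back_fp8(param_fp8_shard: torch.Tensor, shard_main_param: torch.Tensor):
    # The master copy is already fp32, so the optimizer step itself is unchanged;
    # the extra work is quantizing the updated fp32 values back into fp8 storage.
    param_fp8_shard.copy_(shard_main_param.to(torch.float8_e4m3fn))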
In the redundant case, \(K\) optimizer replicas are kept within the dp-cp group. Each replica covers a subset (intra group) of size \(N/K\).
For dp-cp group size \(N\) and total parameters \(P\):
| Operation | \(K=1\) (no redundancy) | \(K>1\) (with redundancy) |
|---|---|---|
| Gradient reduction | reduce-scatter, \(\sim P\) | reduce-scatter (intra) + all-reduce (inter), \(\sim P + \frac{2P}{N} \times (K-1)\) |
| Parameter sync | all-gather, \(\sim P\) | all-gather (intra), \(\sim P \times \frac{N/K-1}{N/K}\) |
| Total | \(\sim 2P\) | \(> 2P\) |
Total communication strictly increases with redundancy on uniform-bandwidth topologies. However, the pattern changes — which can be advantageous under heterogeneous bandwidth.
Example: \(N=8\), dp=4 (cross-node, slow link), cp=2 (intra-node, NVLink), \(K=4\):
| Operation | \(K=1\) slow-link data | \(K=4\) slow-link data |
|---|---|---|
| Gradient reduction | \(\sim P\) | \(\sim 6P/8\) (only inter all-reduce on slow link) |
| Parameter sync | \(\sim P\) | \(0\) (intra all-gather entirely on NVLink) |
| Slow-link total | \(\sim 2P\) | \(\sim 6P/8\) |
Slow-link traffic drops by ~62.5%. The derivation for \(6P/8\): inter all-reduce with \(K=4\) ranks, each holding \(P/(N/K) = P/2\) data after intra reduce-scatter. Ring all-reduce cost: \(2 \times (K-1)/K \times P/2 = 2 \times 3/4 \times P/2 = 3P/4 = 6P/8\).
Caveat: NCCL already applies hierarchical algorithms for large cross-node communicators, so even with \(K=1\) the actual behavior may approximate a two-phase pattern. The marginal benefit of manual partitioning (\(K>1\)) depends on whether NCCL’s automatic optimization is sufficient. The implementation complexity is non-trivial — two process groups, two-phase gradient sync with synchronization dependencies between them.
The complete implementation is at zero_optim_toy. The codebase also includes a utils.py module with shared distributed testing helpers, a package __init__.py, and a CLI via __main__.py (python -m zero_optim_toy test). Here we walk through each layer.
A simple test model — stacked MLP blocks with residual connections:
import torch.nn as nn

class MLP(nn.Module):
def __init__(self, hidden_dim: int):
super().__init__()
self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
class ToyModel(nn.Module):
def __init__(self, hidden_dim=512, num_layers=3):
super().__init__()
self.layers = nn.ModuleList([MLP(hidden_dim) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
x = x + layer(x)
return x
The Buffer takes a list of tensors, flattens them into contiguous memory, and exposes shard views:
import math
import torch
import torch.distributed as dist

class Buffer:
def __init__(self, tensors, rank, world_size, dtype, device):
numels = [t.numel() for t in tensors]
total_numel = sum(numels)
# Pad to multiple of world_size for even reduce-scatter partitioning
self.padded_numel = math.ceil(total_numel / world_size) * world_size
self.shard_numel = self.padded_numel // world_size
self._data = torch.zeros(self.padded_numel, dtype=dtype, device=device)
# _local_shard is a view — no extra memory allocation
shard_start = rank * self.shard_numel
shard_end = shard_start + self.shard_numel
self._local_shard = self._data[shard_start:shard_end]
# Create views into _data for each original tensor
self._views = []
offset = 0
for t, numel in zip(tensors, numels):
view = self._data[offset : offset + numel].view(t.shape)
self._views.append(view)
offset += numel
Communication primitives — symmetric and minimal:
def reduce_scatter(self, group=None):
# full buffer → local shard (SUM reduction)
dist.reduce_scatter_tensor(self._local_shard, self._data, group=group)
def all_gather(self, group=None):
# local shard → full buffer
dist.all_gather_into_tensor(self._data, self._local_shard, group=group)
Since _local_shard is a view into _data, writing to the shard writes directly into the correct region of the full buffer.
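One accessor the DDP layer relies on, get_views(), is not shown above; a minimal version (assuming it simply exposes the per-tensor views built in __init__) would be:

    def get_views(self):
        # One view per original tensor, in registration order, all aliasing self._data.
        return self._views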
The DDP layer creates two Buffers — one for parameters, one for gradients — and remaps param.data and param.main_grad:
import torch

class DistributedDataParallel:
    def __init__(self, module, rank, world_size, process_group=None, device=None):
        self.process_group = process_group  # used by sync_grads() / sync_params() below
        self.params = list(module.parameters())
self.param_buffer = Buffer(self.params, rank, world_size,
dtype=torch.bfloat16, device=device)
self.grad_buffer = Buffer(self.params, rank, world_size,
dtype=torch.bfloat16, device=device)
self._init_buffers()
def _init_buffers(self):
param_views = self.param_buffer.get_views()
grad_views = self.grad_buffer.get_views()
for param, p_view, g_view in zip(self.params, param_views, grad_views):
p_view.copy_(param.data)
param.data = p_view # param.data now points into buffer
param.main_grad = g_view # attach gradient view
Gradient sync: the critical handoff from autograd to the buffer system. PyTorch’s autograd writes to param.grad (a separate tensor), so we must manually copy into the grad buffer view:
def sync_grads(self):
for param in self.params:
if param.grad is not None:
param.main_grad.copy_(param.grad)
param.grad = None
self.grad_buffer.reduce_scatter(self.process_group)
Parameter sync is a single all-gather:
def sync_params(self):
self.param_buffer.all_gather(self.process_group)
The simplest layer — clone the rank’s parameter shard as an fp32 master, hand it to standard Adam:
import torch

class DistributedOptimizer:
def __init__(self, ddp, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
self.ddp = ddp
self.shard_fp32 = torch.nn.Parameter(
ddp.get_param_shard().clone().float(),
requires_grad=False,
)
self.adam = torch.optim.Adam([self.shard_fp32], lr=lr, betas=betas, eps=eps)
def step(self):
# 1. reduced grad shard (bf16) → fp32
self.shard_fp32.grad = self.ddp.get_grad_shard().float()
# 2. Adam updates fp32 master
self.adam.step()
# 3. fp32 → bf16, write back to param buffer shard
self.ddp.write_param_shard(self.shard_fp32.data.bfloat16())
# 4. all-gather to restore full parameters
self.ddp.sync_params()
Adam sees only one fp32 tensor. The distribution is entirely transparent to the optimizer.
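The optimizer above also calls get_param_shard(), get_grad_shard(), and write_param_shard() on the DDP object; these are not shown in the DDP listing. A minimal sketch (the repo's exact signatures may differ) is just thin wrappers over the two buffers' local shards:

    def get_param_shard(self):
        # This rank's 1/N view of the bf16 param buffer (no copy).
        return self.param_buffer._local_shard

    def get_grad_shard(self):
        # This rank's 1/N view of the reduced bf16 grad buffer.
        return self.grad_buffer._local_shard

    def write_param_shard(self, data):
        # Write updated bf16 values into the local shard; since the shard aliases
        # the full buffer, the subsequent all-gather broadcasts them to all ranks.
        self.param_buffer._local_shard.copy_(data)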
The test strategy: run the exact same precision path (bf16 forward/backward → fp32 Adam → bf16 writeback) on a single process without ZeRO-1, and compare parameters step-by-step against the multi-process ZeRO-1 run.
All ranks use the same random seed for input data, and the loss is divided by world_size:
loss = output.sum() / world_size
loss.backward()
ddp.sync_grads()
optimizer.step()
Since all ranks have identical gradients, after reduce-scatter(SUM), each shard equals N * (grad / N) = grad, matching the reference exactly.
Tests cover three scenarios; one of them uses hidden_dim=13 (a parameter count not divisible by world_size) to exercise the padding path. The verification criterion is atol=0, rtol=0 — bit-exact equality. Since the precision path is identical, ZeRO-1 results should match single-GPU training exactly.
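The comparison itself can be expressed with torch.testing.assert_close at zero tolerance (illustrative; the repo's test helper may differ):

import torch

def assert_bitwise_equal(zero1_param: torch.Tensor, reference_param: torch.Tensor):
    # rtol=atol=0 means a single differing bit in the bf16 values fails the test.
    torch.testing.assert_close(zero1_param, reference_param, rtol=0, atol=0)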
Model: 10-layer MLP with \(H=4096\), ~1.34B parameters, world_size=4.
| Baseline (1 GPU) | ZeRO-1 (4 GPUs) | |
|---|---|---|
| bf16 model params | 2.50 GB | 2.50 GB |
| grad_buffer | — | 2.50 GB |
| fp32 master params | 5.00 GB | 1.25 GB |
| Adam \(m\) | 5.00 GB | 1.25 GB |
| Adam \(v\) | 5.00 GB | 1.25 GB |
| Total | 17.50 GB | 8.75 GB |
Theoretical saving: 50%. Optimizer state drops from 15.00 GB to 3.75 GB — a 75% reduction (i.e., \(1 - 1/N\)). However, ZeRO-1 introduces a persistent grad_buffer (2.50 GB) that baseline doesn’t have (baseline creates param.grad temporarily during backward and frees it during optimizer.step()). Net saving is ~42.7%, consistent with actual profile measurements.
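The table's entries follow directly from the parameter count (note the "GB" figures are binary, i.e. GiB); as a rough check:

# 10 MLP layers with hidden_dim=4096: fc1 is 4096x16384 (+bias), fc2 is 16384x4096 (+bias).
P = 10 * (4096 * 16384 + 16384 + 16384 * 4096 + 4096)      # 1,342,382,080 ≈ 1.34B
GiB = 2 ** 30
print(f"bf16 params          : {2 * P / GiB:.2f} GiB")      # ≈ 2.50
print(f"fp32 master, baseline: {4 * P / GiB:.2f} GiB")      # ≈ 5.00
print(f"fp32 master, ZeRO-1  : {4 * P / 4 / GiB:.2f} GiB")  # ≈ 1.25 with world_size=4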
ZeRO-1’s core insight is simple: since data-parallel ranks hold identical optimizer states, keep one copy and shard it. The three-layer architecture (Buffer → DDP → Distributed Optimizer) cleanly separates distributed communication from optimization logic. With no increase in total communication volume, optimizer state memory drops to \(1/N\).
For even larger models, ZeRO Stages 2 and 3 extend this idea further — sharding gradients and parameters as well — which would be the natural next topic to explore.