ZeRO-1 Distributed Optimizer - A Deep Dive

Overview

This is a deep dive into ZeRO Stage 1 — the Distributed Optimizer from the ZeRO paper. It covers the fundamentals and a complete toy implementation (~300 lines of core code). The implementation discussed throughout is open-sourced at zero_optim_toy, built on the original work by @xiabingquan.

Prerequisite: familiarity with Distributed Data Parallel (DDP).

The key idea: within a data-parallel group, all \(N\) ranks hold identical optimizer states. ZeRO-1 partitions the optimizer state across ranks so each rank holds only \(1/N\) of it — the fp32 master parameters, Adam’s first moment (\(m\)), and second moment (\(v\)). After a local optimizer step on each shard, ranks exchange results via all-gather to recover the full updated parameters.

Memory Savings

For a bf16 model with Adam optimizer and \(P\) parameter elements:

| Memory Item | Vanilla DDP (per rank) | ZeRO-1 (per rank) |
|---|---|---|
| fp32 master params | \(4P\) bytes | \(4P / N\) bytes |
| Adam \(m\) (first moment) | \(4P\) bytes | \(4P / N\) bytes |
| Adam \(v\) (second moment) | \(4P\) bytes | \(4P / N\) bytes |
| Optimizer state subtotal | \(12P\) bytes | \(12P / N\) bytes |

With \(N=8\), optimizer state drops from \(12P\) to \(1.5P\). Model parameters (bf16) and gradient buffers remain at full size on each rank.

Communication Volume

| Operation | Vanilla DDP | ZeRO-1 |
|---|---|---|
| Gradient sync | all-reduce (\(\sim 2P\)) | reduce-scatter (\(\sim P\)) |
| Parameter broadcast | none | all-gather (\(\sim P\)) |
| Total | \(\sim 2P\) | \(\sim 2P\) |

ZeRO-1 adds an all-gather round, but total communication volume stays the same — which is why it is effectively standard for large-scale training today.
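
A quick sanity check with the standard ring-collective cost model: over \(N\) ranks, a ring all-reduce moves \(\frac{2(N-1)}{N}P \approx 2P\) per rank, while reduce-scatter and all-gather each move \(\frac{N-1}{N}P \approx P\), so the two schemes transfer the same volume:

\[
\underbrace{\tfrac{N-1}{N}P}_{\text{reduce-scatter}} + \underbrace{\tfrac{N-1}{N}P}_{\text{all-gather}} = \tfrac{2(N-1)}{N}P \approx 2P.
\]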

Applicability

Limitations

How It Works

Partitioning Rationale

Within a data-parallel (or combined dp-cp) group, all ranks have identical model parameters, and therefore identical optimizer states. Instead of storing \(N\) redundant copies, we keep one copy split into \(N\) shards, one per rank.

Communication Pattern

Vanilla DDP — each rank needs the full reduced gradient to update the full parameters:

backward → gradient all-reduce → optimizer.step()

ZeRO-1 — each rank only needs \(1/N\) of the reduced gradient for its own shard:

backward → gradient reduce-scatter → optimizer.step(shard) → parameter all-gather

reduce-scatter = reduce + scatter in a single collective. Each rank receives only the segment of the reduced gradient that it needs for its optimizer shard.
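
For concreteness, a minimal standalone sketch of the collective's semantics in PyTorch (not part of the toy implementation; assumes NCCL with one GPU per rank):

import torch
import torch.distributed as dist

# Launch with e.g. `torchrun --nproc_per_node=4 demo.py`.
dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank)

full = torch.full((8 * world,), float(rank + 1), device=device)  # stand-in "full gradient"
shard = torch.empty(8, device=device)                            # this rank's 1/N slice

# SUM across ranks element-wise, then each rank keeps only slice `rank` of the result.
dist.reduce_scatter_tensor(shard, full)
# Every element of `shard` now equals 1 + 2 + ... + world.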

Process Groups

From the optimizer’s perspective, fixing (tp_rank, pp_rank), the ranks across (dp_rank, cp_rank) form a group with identical optimizer states. The communication group for ZeRO-1 is therefore the dp-cp group (or a subset thereof).

With redundant optimizer replicas (covered under Redundant Optimizers below), gradient sync becomes a two-phase process: a reduce-scatter within the intra group, followed by an all-reduce across the inter group. Parameter sync (all-gather) happens only within the intra group.

Core Implementation

The implementation follows a clean three-layer architecture: Buffer → DDP → Distributed Optimizer. Each layer has a distinct responsibility and communicates through clear interfaces.

Buffer Layer

Responsibility: manage contiguous memory, provide shard-level read/write and communication primitives.

Why a buffer is necessary: model parameters and gradients are scattered across separate allocations in memory. Communicating per-tensor means one NCCL call per parameter — the kernel launch overhead becomes prohibitive. By remapping parameters and gradients into a contiguous buffer, a single collective call covers the entire buffer.
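
To make the contrast concrete, a hedged sketch (hypothetical shapes; assumes an already-initialized process group, and is illustrative only — a real run would do one or the other):

import torch
import torch.distributed as dist

grads = [torch.randn(4096, 4096), torch.randn(4096)]  # hypothetical per-param grads

# Naive: one collective per tensor; kernel-launch overhead dominates when
# the model has thousands of small parameters.
for g in grads:
    dist.all_reduce(g)

# Buffered: flatten once, launch a single collective over everything.
flat = torch.cat([g.reshape(-1) for g in grads])
dist.all_reduce(flat)
# The real Buffer goes further: params/grads are permanent *views* into the
# contiguous buffer, so no per-step copy into `flat` is needed at all.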

The buffer manages the low-precision (e.g., bf16) parameters and gradients used in forward/backward — not the fp32 master parameters inside the optimizer (those are handled separately by the Distributed Optimizer layer).

Key functions:

  1. Accept a list of parameters (with their dtypes) from the DDP layer and allocate corresponding contiguous buffers.
  2. Gradient buffer: always created (gradient communication is fundamental to DDP). Parameter buffer: only created when ZeRO-1 is enabled (vanilla DDP updates full parameters locally, no parameter communication needed).
  3. Remap param.data and param.main_grad to views into the buffer.
  4. Provide communication primitives: reduce_scatter_grads(), allgather_params(), get_grad_shard(param), write_param_shard(param, data).

Sharding and parameter boundaries: the buffer is evenly partitioned across \(N\) ranks along its contiguous memory — partition boundaries do not respect parameter boundaries. A single parameter may be split across two ranks. The upside: each shard has exactly equal size, making reduce-scatter / all-gather direct operations on uniform slices without padding. The tradeoff: the optimizer must handle partial parameter fragments.
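
A small worked example (hypothetical sizes): two parameters with numels 5 and 7, \(N=4\), so the 12-element buffer splits into four 3-element shards:

buffer index : 0  1  2 | 3  4  5 | 6  7  8 | 9 10 11
owner rank   :  rank 0 |  rank 1 |  rank 2 |  rank 3
parameter    : [---- A ----][--------- B ---------]

Param A (numel 5) spans indices 0-4: rank 0 owns A[0:3], while rank 1 owns the fragment A[3:5] plus B[0]. Every shard has exactly 3 elements, but rank 1's optimizer sees only a partial parameter.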

DDP Layer

Responsibility: group parameters by type and precision, create buffers for each group, and provide unified gradient/parameter sync interfaces.

Typical grouping example:

| Param Type | Precision | Example |
|---|---|---|
| dense | bf16 | Attention/FFN linear weights |
| MoE | bf16 | Expert FFN linear weights |
| MoE | fp32 | Router weights |

The DDP layer provides:

    • sync_grads(): copy autograd gradients into the grad buffer, then reduce-scatter
    • sync_params(): all-gather the param buffer to restore full parameters
    • Shard access for the optimizer: get_param_shard(), get_grad_shard(), write_param_shard()

Distributed Optimizer Layer

Responsibility: manage optimizer state creation and update, coordinate precision conversion between low-precision (bf16) and high-precision (fp32) representations.

This layer:

  1. Creates fp32 master copies and optimizer states only for the shard owned by the current rank (\(1/N\) the size).
  2. Wraps a standard Adam optimizer over the fp32 shard.
  3. Each training step executes:
    • Gradient upcast: low-precision grad shard → shard_main_param.grad (fp32)
    • Parameter update: standard Adam.step() on the fp32 shard
    • Parameter writeback: updated fp32 shard → low-precision param buffer shard
    • Parameter sync: trigger all-gather to restore full parameters

Execution Flow

Initialization

1. DDP creates buffers
   ├─ Groups params by (type, precision), calls Buffer layer
   ├─ Remaps param.data and param.main_grad to buffer views
   └─ Computes shard range for each rank

2. Distributed Optimizer initialization
   ├─ Queries DDP for each parameter's shard info on this rank
   ├─ Creates fp32 master copy: shard_main_param = buffer_view[start:end].clone().float()
   └─ Creates Adam optimizer over the fp32 shard (Adam is unaware of distribution)

Training Step

Forward
  └─ Uses full bf16 parameters in param buffer

Backward
  └─ autograd writes gradients to param.main_grad (views into grad buffer)

Gradient Sync (DDP)
  └─ reduce_scatter_grads(): each rank gets 1/N reduced gradient shard

Optimizer Step (Distributed Optimizer)
  ├─ 1. Upcast: shard_main_param.grad = grad_shard.float()
  ├─ 2. Adam.step() on the fp32 shard (local, 1/N size)
  ├─ 3. Writeback: write_param_shard(param, shard_main_param.to(bf16))
  └─ 4. Sync: allgather_params() → full parameters restored

Next Forward
  └─ Uses fully updated parameters

Data Precision Pipeline

grad buffer (bf16, full)
        │
   reduce-scatter
        │
   grad shard (bf16, 1/N)
        │
   .float()
        │
shard_main_param.grad (fp32, 1/N)
        │
   Adam.step()
        │
shard_main_param (fp32, 1/N)
        │
   .to(bf16)
        │
param buffer shard (bf16, 1/N)
        │
   all-gather
        │
param buffer (bf16, full)

Optional Optimizations

Communication-Computation Overlap

Bucketing

Further subdividing buffers into buckets gives finer-grained overlap and reduces per-collective latency. Each bucket communicates independently.
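
A sketch of the idea (hypothetical helper, not from the repo; assumes each bucket is a contiguous slice of the grad buffer with a matching shard view):

import torch.distributed as dist

def reduce_scatter_buckets(buckets, shards, group=None):
    # One async reduce-scatter per bucket: buckets whose gradients are ready
    # can start communicating while later buckets are still in backward.
    handles = [
        dist.reduce_scatter_tensor(shard, bucket, group=group, async_op=True)
        for bucket, shard in zip(buckets, shards)
    ]
    # Wait only when the results are needed (e.g. right before optimizer.step()).
    for handle in handles:
        handle.wait()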

Buffer Padding & Alignment

Checkpoint Save & Load

The challenge: optimizer state is distributed across ranks. Common strategies:

FP8 Quantization Awareness

When some parameters use FP8 precision for forward/backward, extra quantize/dequantize steps are needed: fp8 → fp32 → optimizer → fp32 → fp8. FP8 and non-FP8 parameters must live in separate buffers.
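
A heavily simplified sketch of the round-trip (assumes PyTorch's torch.float8_e4m3fn dtype and a single toy per-tensor scale; real FP8 recipes track amax/scale histories per tensor):

import torch

w_fp32 = torch.randn(4, 4)
scale = 16.0 / w_fp32.abs().max()                 # toy per-tensor scale factor

# quantize: fp32 -> fp8 for forward/backward
w_fp8 = (w_fp32 * scale).to(torch.float8_e4m3fn)

# dequantize: fp8 -> fp32 before the fp32 optimizer math
w_back = w_fp8.to(torch.float32) / scale
# ... the fp32 Adam step runs here, then re-quantize as above for the next step.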

Redundant Optimizers

In the redundant case, \(K\) optimizer replicas are kept within the dp-cp group. Each replica covers a subset (intra group) of size \(N/K\).

For dp-cp group size \(N\) and total parameters \(P\):

| Operation | \(K=1\) (no redundancy) | \(K>1\) (with redundancy) |
|---|---|---|
| Gradient reduction | reduce-scatter, \(\sim P\) | reduce-scatter (intra) + all-reduce (inter), \(\sim P + \frac{2P}{N} \times (K-1)\) |
| Parameter sync | all-gather, \(\sim P\) | all-gather (intra), \(\sim P \times \frac{N/K-1}{N/K}\) |
| Total | \(\sim 2P\) | \(> 2P\) |

Total communication strictly increases with redundancy on uniform-bandwidth topologies. However, the pattern changes — which can be advantageous under heterogeneous bandwidth.

Example: \(N=8\), dp=4 (cross-node, slow link), cp=2 (intra-node, NVLink), \(K=4\):

| Operation | \(K=1\) slow-link data | \(K=4\) slow-link data |
|---|---|---|
| Gradient reduction | \(\sim P\) | \(\sim 6P/8\) (only inter all-reduce on slow link) |
| Parameter sync | \(\sim P\) | \(0\) (intra all-gather entirely on NVLink) |
| Slow-link total | \(\sim 2P\) | \(\sim 6P/8\) |

Slow-link traffic drops by ~62.5%. The derivation for \(6P/8\): inter all-reduce with \(K=4\) ranks, each holding \(P/(N/K) = P/2\) data after intra reduce-scatter. Ring all-reduce cost: \(2 \times (K-1)/K \times P/2 = 2 \times 3/4 \times P/2 = 3P/4 = 6P/8\).

Caveat: NCCL already applies hierarchical algorithms for large cross-node communicators, so even with \(K=1\) the actual behavior may approximate a two-phase pattern. The marginal benefit of manual partitioning (\(K>1\)) depends on whether NCCL’s automatic optimization is sufficient. The implementation complexity is non-trivial — two process groups, two-phase gradient sync with synchronization dependencies between them.
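
A sketch of the two-phase sync under \(K\) replicas (group construction omitted; `intra_group` is assumed to span the \(N/K\) ranks sharing a replica, `inter_group` one rank per replica):

import torch.distributed as dist

def sync_grads_redundant(grad_buffer, grad_shard, intra_group, inter_group):
    # Phase 1: reduce-scatter inside the replica (fast links) -- each rank
    # gets its shard, reduced over the N/K ranks of this replica only.
    dist.reduce_scatter_tensor(grad_shard, grad_buffer, group=intra_group)
    # Phase 2: all-reduce the shard across the K replicas (slow links) to
    # finish the global sum. Must not start before phase 1 completes.
    dist.all_reduce(grad_shard, group=inter_group)

def sync_params_redundant(param_buffer, param_shard, intra_group):
    # Parameter sync never leaves the replica: every replica computes the
    # same updates, so the all-gather stays on the fast intra links.
    dist.all_gather_into_tensor(param_buffer, param_shard, group=intra_group)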

Code Walkthrough

The complete implementation is at zero_optim_toy. The codebase also includes a utils.py module with shared distributed testing helpers, a package __init__.py, and a CLI via __main__.py (python -m zero_optim_toy test). Here we walk through each layer.

Step 1: Model

A simple test model — stacked MLP blocks with residual connections:

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class ToyModel(nn.Module):
    def __init__(self, hidden_dim=512, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([MLP(hidden_dim) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return x

Step 2: Buffer Layer

The Buffer takes a list of tensors, flattens them into contiguous memory, and exposes shard views:

import math

import torch
import torch.distributed as dist

class Buffer:
    def __init__(self, tensors, rank, world_size, dtype, device):
        numels = [t.numel() for t in tensors]
        total_numel = sum(numels)

        # Pad to multiple of world_size for even reduce-scatter partitioning
        self.padded_numel = math.ceil(total_numel / world_size) * world_size
        self.shard_numel = self.padded_numel // world_size

        self._data = torch.zeros(self.padded_numel, dtype=dtype, device=device)

        # _local_shard is a view — no extra memory allocation
        shard_start = rank * self.shard_numel
        shard_end = shard_start + self.shard_numel
        self._local_shard = self._data[shard_start:shard_end]

        # Create views into _data for each original tensor (exposed via get_views())
        self._views = []
        offset = 0
        for t, numel in zip(tensors, numels):
            view = self._data[offset : offset + numel].view(t.shape)
            self._views.append(view)
            offset += numel

    def get_views(self):
        # Per-parameter views, used by the DDP layer to remap param.data
        return self._views

Communication primitives — symmetric and minimal:

def reduce_scatter(self, group=None):
    # full buffer → local shard (SUM reduction)
    dist.reduce_scatter_tensor(self._local_shard, self._data, group=group)

def all_gather(self, group=None):
    # local shard → full buffer
    dist.all_gather_into_tensor(self._data, self._local_shard, group=group)

Since _local_shard is a view into _data, writing to the shard writes directly into the correct region of the full buffer.
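
A quick single-process check of the aliasing (no collective is called, so no process group is needed; sizes reuse the worked example above):

import torch

buf = Buffer([torch.randn(5), torch.randn(7)], rank=1, world_size=4,
             dtype=torch.float32, device="cpu")
buf._local_shard.fill_(42.0)
# rank 1 owns indices [3, 6) of the 12-element buffer
assert (buf._data[3:6] == 42.0).all()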

Step 3: DDP Layer

The DDP layer creates two Buffers — one for parameters, one for gradients — and remaps param.data and param.main_grad:

class DistributedDataParallel:
    def __init__(self, module, rank, world_size, process_group=None, device=None):
        self.process_group = process_group  # used by sync_grads / sync_params
        self.params = list(module.parameters())
        self.param_buffer = Buffer(self.params, rank, world_size,
                                   dtype=torch.bfloat16, device=device)
        self.grad_buffer = Buffer(self.params, rank, world_size,
                                  dtype=torch.bfloat16, device=device)
        self._init_buffers()

    def _init_buffers(self):
        param_views = self.param_buffer.get_views()
        grad_views = self.grad_buffer.get_views()
        for param, p_view, g_view in zip(self.params, param_views, grad_views):
            p_view.copy_(param.data)
            param.data = p_view         # param.data now points into buffer
            param.main_grad = g_view    # attach gradient view

Gradient sync: the critical handoff from autograd to the buffer system. PyTorch’s autograd writes to param.grad (a separate tensor), so we must manually copy into the grad buffer view:

def sync_grads(self):
    for param in self.params:
        if param.grad is not None:
            param.main_grad.copy_(param.grad)
            param.grad = None
    self.grad_buffer.reduce_scatter(self.process_group)

Parameter sync is a single all-gather:

def sync_params(self):
    self.param_buffer.all_gather(self.process_group)

Step 4: Distributed Optimizer

The simplest layer — clone the rank’s parameter shard as an fp32 master, hand it to standard Adam:

class DistributedOptimizer:
    def __init__(self, ddp, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.ddp = ddp
        self.shard_fp32 = torch.nn.Parameter(
            ddp.get_param_shard().clone().float(),
            requires_grad=False,
        )
        self.adam = torch.optim.Adam([self.shard_fp32], lr=lr, betas=betas, eps=eps)

    def step(self):
        # 1. reduced grad shard (bf16) → fp32
        self.shard_fp32.grad = self.ddp.get_grad_shard().float()
        # 2. Adam updates fp32 master
        self.adam.step()
        # 3. fp32 → bf16, write back to param buffer shard
        self.ddp.write_param_shard(self.shard_fp32.data.bfloat16())
        # 4. all-gather to restore full parameters
        self.ddp.sync_params()

Adam sees only one fp32 tensor. The distribution is entirely transparent to the optimizer.
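
For orientation, a sketch of how the pieces wire together (argument names follow the toy classes above; assumes a torchrun launch with one GPU per rank):

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank)

model = ToyModel().to(device)
ddp = DistributedDataParallel(model, rank, world_size, device=device)
optimizer = DistributedOptimizer(ddp, lr=1e-3)

x = torch.randn(8, 512, device=device, dtype=torch.bfloat16)
loss = model(x).sum() / world_size   # matches the test's loss scaling
loss.backward()
ddp.sync_grads()    # copy grads into the buffer + reduce-scatter
optimizer.step()    # local fp32 Adam, writeback, all-gather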

Step 5: Testing

The test strategy: run the exact same precision path (bf16 forward/backward → fp32 Adam → bf16 writeback) on a single process without ZeRO-1, and compare parameters step-by-step against the multi-process ZeRO-1 run.

All ranks use the same random seed for input data, and the loss is divided by world_size:

loss = output.sum() / world_size
loss.backward()

ddp.sync_grads()
optimizer.step()

Since all ranks have identical gradients, after reduce-scatter(SUM), each shard equals N * (grad / N) = grad, matching the reference exactly.

Tests cover three scenarios:

The verification criterion is atol=0, rtol=0 — bit-exact equality. Since the precision path is identical, ZeRO-1 results should match single-GPU training exactly.

Profile Results

Model: 10-layer MLP with \(H=4096\), ~1.34B parameters, world_size=4.

| Memory Item | Baseline (1 GPU) | ZeRO-1 (4 GPUs) |
|---|---|---|
| bf16 model params | 2.50 GB | 2.50 GB |
| grad_buffer | — | 2.50 GB |
| fp32 master params | 5.00 GB | 1.25 GB |
| Adam \(m\) | 5.00 GB | 1.25 GB |
| Adam \(v\) | 5.00 GB | 1.25 GB |
| Total | 17.50 GB | 8.75 GB |

Theoretical saving: 50%. Optimizer state alone drops from 15.00 GB to 3.75 GB, a 75% reduction (i.e., \(1 - 1/N\) with \(N = 4\)). However, ZeRO-1 introduces a persistent grad_buffer (2.50 GB) that the baseline doesn't have: the baseline materializes param.grad temporarily during backward and frees it during optimizer.step(). The net saving comes out to ~42.7%, consistent with actual profile measurements.
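
The arithmetic behind the table (the values line up with GiB; a quick re-derivation of the headline numbers):

# 10 layers, each with fc1: 4096 x 16384 (+ bias) and fc2: 16384 x 4096 (+ bias)
P = 10 * (4096 * 16384 + 16384 + 16384 * 4096 + 4096)   # ~1.342e9 params
GiB = 1024 ** 3
print(2 * P / GiB)        # bf16 params or grads:        ~2.50
print(4 * P / GiB)        # one fp32 copy (master/m/v):  ~5.00
print(4 * P / 4 / GiB)    # the same, sharded over N=4:  ~1.25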

Summary

ZeRO-1’s core insight is simple: since data-parallel ranks hold identical optimizer states, keep one copy and shard it. The three-layer architecture (Buffer → DDP → Distributed Optimizer) cleanly separates distributed communication from optimization logic. With no increase in total communication volume, optimizer state memory drops to \(1/N\).

For even larger models, ZeRO Stages 2 and 3 extend this idea further — sharding gradients and parameters as well — which would be the natural next topic to explore.