Discrete diffusion model - 1
Discrete diffusion models are going viral. People from the image generation field are wondering whether accelerated sampling can be used to generate discrete sequences. However, this is hard to achieve with simple ideas such as DDIM, which we have seen in the continuous diffusion models used for image generation. The fundamental challenge stems from the difference between continuous and discrete state spaces.
The mathematical “trick” of DDIM relies entirely on the data living in a continuous vector space (like \(\mathbb{R}^n\), where images live).
Let’s look at the deterministic DDIM equation again:
\begin{equation} x_{t-1} = \underbrace{\sqrt{\bar{\alpha}_{t-1}} \hat{x}_0(x_t, t)}_{\text{Component 1}} + \underbrace{\sqrt{1 - \bar{\alpha}_{t-1}} \cdot \epsilon_\theta(x_t, t)}_{\text{Component 2}} \end{equation}
This equation is a linear combination of two vectors in the same space as the data. It works because:
- Scaling: You can scale a vector (an image) by a scalar like \(\sqrt{\bar{\alpha}_{t-1}}\).
- Addition: You can add two vectors (the “clean” component and the “noise” component) together.
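To make these two operations concrete, here is a minimal pure-Python sketch of one deterministic DDIM step. It assumes \(\hat{x}_0\) is recovered from the noise prediction in the usual DDIM way, \(\hat{x}_0 = (x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta)/\sqrt{\bar{\alpha}_t}\); `ddim_step` is a hypothetical helper, and `eps_pred` is a plain list standing in for \(\epsilon_\theta(x_t, t)\). Every line is just scaling and adding real numbers.

```python
import math
from typing import List


def ddim_step(x_t: List[float], eps_pred: List[float],
              alpha_bar_t: float, alpha_bar_prev: float) -> List[float]:
    """One deterministic DDIM update (sigma = 0) on a real-valued vector.

    eps_pred is a stand-in for the model's noise prediction eps_theta(x_t, t).
    """
    # Predicted clean sample: x0_hat = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    x0_hat = [(x - math.sqrt(1.0 - alpha_bar_t) * e) / math.sqrt(alpha_bar_t)
              for x, e in zip(x_t, eps_pred)]
    # Component 1 (scaled clean estimate) + Component 2 (scaled noise estimate)
    return [math.sqrt(alpha_bar_prev) * x0 + math.sqrt(1.0 - alpha_bar_prev) * e
            for x0, e in zip(x0_hat, eps_pred)]
```

If `x_t` was built from a known clean vector and noise, and `eps_pred` equals that noise, the step returns \(\sqrt{\bar{\alpha}_{t-1}}\,x_0 + \sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon\) up to floating-point error.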
In a discrete diffusion model (like for text), these operations are meaningless.
- The Data: Your data \(x_t\) is not a vector of real numbers. It’s a sequence of integers (token IDs) from a finite vocabulary \(V = \{1, 2, ..., \vert V \vert\}\).
- The “Noise”: The forward process isn’t adding Gaussian noise. It’s applying a probabilistic transition matrix \(Q_t\). This usually means “randomly replace a token with [MASK]” or “randomly replace a token with another token.”
- The Problem: What is \(\sqrt{0.5} \times \text{"hello"}\)? What is \(\text{"hello"} + \text{"world"}\)? These operations are undefined. You cannot perform linear interpolation on token IDs.
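A toy illustration of why this fails, using two hypothetical vocabularies that contain the same tokens in a different order. Treating token IDs as numbers and taking the midpoint of “the” and “brown” yields a token that depends only on how the vocabulary happens to be indexed; `naive_interpolate` is an illustrative helper, not part of any library.

```python
from typing import List


def naive_interpolate(vocab: List[str], a: str, b: str, w: float) -> str:
    """Treat token IDs as real numbers, blend them linearly, and round back to a token.

    This mimics a DDIM-style linear combination on token IDs; the result is
    determined by the arbitrary ID ordering, not by the meaning of the tokens.
    """
    ia, ib = vocab.index(a), vocab.index(b)
    blended = (1.0 - w) * ia + w * ib
    return vocab[int(round(blended))]


# Two hypothetical vocabularies: same tokens, different (equally valid) ID orders.
vocab_a = ["[MASK]", "the", "quick", "brown", "fox"]
vocab_b = ["[MASK]", "the", "fox", "brown", "quick"]

print(naive_interpolate(vocab_a, "the", "brown", 0.5))  # -> quick
print(naive_interpolate(vocab_b, "the", "brown", 0.5))  # -> fox
```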
Because the math of DDPM/DDIM doesn’t apply, discrete diffusion models (like D3PM or MaskGIT) use a completely different mathematical framework.
- Forward Process \(q(x_t \vert x_{t-1})\):
  - This is defined by a transition matrix \(Q_t\).
  - For example, \(Q_t(j \vert i)\) might be the probability of token \(i\) at time \(t-1\) becoming token \(j\) at time \(t\) (e.g., 90% chance of staying the same, 5% chance of becoming [MASK], 5% chance of becoming a random token).
- Reverse Process \(p_\theta(x_{t-1} \vert x_t)\):
  - The model \(\theta\) is trained to invert this transition.
  - It takes the corrupted text \(x_t\) (e.g., “The [MASK] [MASK] fox”) and predicts the probability distribution for the previous state \(x_{t-1}\) (e.g., “The [MASK] brown fox”).
  - Crucially, the model outputs a categorical distribution over the entire vocabulary at each token position.
  - \[p_\theta(x_{t-1} \vert x_t) = \text{softmax}(\dots)\]
- Sampling (The “Denoising” Step):
  - To get \(x_{t-1}\) from \(x_t\), you must sample from this predicted categorical distribution.
  - \[x_{t-1} \sim p_\theta(\cdot \vert x_t)\]
  - This step is inherently stochastic: you are drawing a token from a probability distribution.
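The forward and reverse pieces above can be sketched in pure Python. This is a minimal illustration under stated assumptions, not D3PM’s reference code: [MASK] is treated as an absorbing state, the “random token” case redraws uniformly over the non-mask tokens (so it can land on the same token), and the model is replaced by hand-written per-position logits for \(x_{t-1}\). The names `make_transition_matrix`, `forward_step`, and `reverse_sample_step` are hypothetical.

```python
import math
import random
from typing import List, Sequence


def make_transition_matrix(vocab_size: int, mask_id: int,
                           p_mask: float = 0.05,
                           p_random: float = 0.05) -> List[List[float]]:
    """Build Q_t as a row-stochastic matrix: Q[i][j] = q(x_t = j | x_{t-1} = i).

    Illustrative assumptions: [MASK] is absorbing, and with probability
    p_random a non-mask token is redrawn uniformly over the non-mask tokens.
    """
    non_mask = [j for j in range(vocab_size) if j != mask_id]
    q = []
    for i in range(vocab_size):
        row = [0.0] * vocab_size
        if i == mask_id:
            row[mask_id] = 1.0  # absorbing state: masked stays masked
        else:
            row[i] = 1.0 - p_mask - p_random  # keep the token
            for j in non_mask:
                row[j] += p_random / len(non_mask)  # uniform random replacement
            row[mask_id] = p_mask  # replace with [MASK]
        q.append(row)
    return q


def forward_step(x_prev: Sequence[int], q: List[List[float]],
                 rng: random.Random) -> List[int]:
    """Sample x_t ~ q(x_t | x_{t-1}): each position draws from row Q[x_{t-1}[i]]."""
    ids = range(len(q))
    return [rng.choices(ids, weights=q[tok])[0] for tok in x_prev]


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn one position's logits into a categorical distribution over the vocabulary."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reverse_sample_step(logits_per_pos: Sequence[Sequence[float]],
                        rng: random.Random) -> List[int]:
    """Ancestral step: x_{t-1}[i] ~ Categorical(softmax(logits[i])) at each position."""
    return [rng.choices(range(len(lg)), weights=softmax(lg))[0]
            for lg in logits_per_pos]
```

With a five-token vocabulary whose ID 0 is [MASK] and the default probabilities, each non-mask row keeps its token with probability 0.9125 (0.9 plus its own share of the uniform redraw), moves to each other token with 0.0125, and moves to [MASK] with 0.05. In a real model the logits passed to `reverse_sample_step` would come from a network conditioned on \(x_t\).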
So, what is the discrete “deterministic” equivalent?
You can’t have the DDIM-style interpolation, but you can make the discrete sampling step deterministic. The closest equivalent to a “deterministic” step in a discrete model is not DDIM, but greedy decoding (i.e., argmax). Instead of sampling from the predicted distribution \(p_\theta(\cdot \vert x_t)\), you just choose the single most likely token at each position:
\begin{equation} x_{t-1}^{(i)} = \arg\max_{v \in V} \, p_\theta\big(x_{t-1}^{(i)} = v \,\vert\, x_t\big) \quad \text{for each position } i \end{equation}
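The contrast between this greedy step and the stochastic sampling step can be sketched as follows. The per-position probabilities are hypothetical values standing in for \(p_\theta(\cdot \vert x_t)\), and `greedy_step` / `sampled_step` are illustrative names.

```python
import random
from typing import List, Sequence


def greedy_step(probs_per_pos: Sequence[Sequence[float]]) -> List[int]:
    """Deterministic step: x_{t-1}[i] = argmax_v p(v | x_t) at every position."""
    return [max(range(len(p)), key=p.__getitem__) for p in probs_per_pos]


def sampled_step(probs_per_pos: Sequence[Sequence[float]],
                 rng: random.Random) -> List[int]:
    """Stochastic step: x_{t-1}[i] ~ Categorical(p(. | x_t)) at every position."""
    return [rng.choices(range(len(p)), weights=p)[0] for p in probs_per_pos]


# Hypothetical model output for a 2-token sequence over a 3-token vocabulary.
probs = [[0.1, 0.6, 0.3],
         [0.4, 0.35, 0.25]]

print(greedy_step(probs))  # -> [1, 0], identical on every call
rng = random.Random(0)
print(sampled_step(probs, rng))  # varies from call to call
```

Calling `greedy_step` repeatedly on the same input always returns `[1, 0]`, while repeated `sampled_step` calls generally return different sequences. That run-to-run variety is exactly what greedy decoding gives up.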
This is deterministic, but it is not “DDIM” and doesn’t share its mathematical properties. It is simply greedy decoding, which often leads to lower-quality and less diverse results compared to stochastic sampling. The challenge of creating a “fast sampling” method for discrete models that also maintains high quality (the way DDIM does for continuous models) is an active area of research. These methods are typically called “ancestral sampling” modifications, not “implicit” models.