Masked Discrete Diffusion

Categories: DDPM, score

Published: 21 October 2025

Modified: 21 October 2025

Masked Discrete Diffusion

We consider a finite state space \[ \mathcal{X} = \{M, 1,2, \ldots, V\}, \] where \(M\) is a special masked state and \(1, \ldots, V\) correspond to token values in a vocabulary of size \(V\). On the time interval \([0,T]\), we define a continuous-time Markov chain with initial distribution \(p_0\) and time-dependent infinitesimal rate matrix \(Q_t \in \mathbb{R}^{(V+1)\times(V+1)}\) so that for any \(x \ne y\), \[ \mathbb{P}(X_{t+h} = y \mid X_t = x) = Q_t(x,y) \, h + o(h). \]

If the total jump rate out of state \(x\) is \(J_t(x) = \sum_{y \ne x} Q_t(x,y)\), then

\[\mathbb{P}(X_{t+h} = x \mid X_t = x) = 1 - J_t(x) \, h + o(h).\]
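The two relations above can be checked on a small example. The sketch below is a minimal illustration, with hypothetical helper names and plain Python lists: it fills the diagonal of a rate matrix with \(-J_t(x)\) at a fixed time \(t\) and builds the first-order transition matrix \(I + hQ_t\), whose entries match the off-diagonal and diagonal formulas up to \(o(h)\).

```python
from typing import List

Matrix = List[List[float]]


def complete_rate_matrix(offdiag: Matrix) -> Matrix:
    """Fill the diagonal with -J_t(x), where J_t(x) = sum_{y != x} Q_t(x, y).

    Off-diagonal entries of `offdiag` are the jump rates Q_t(x, y) at a fixed t;
    its diagonal is ignored. Every row of the result sums to zero.
    """
    n = len(offdiag)
    Q = [row[:] for row in offdiag]
    for x in range(n):
        J = sum(Q[x][y] for y in range(n) if y != x)
        Q[x][x] = -J
    return Q


def euler_transition(Q: Matrix, h: float) -> Matrix:
    """I + h Q: P(X_{t+h} = y | X_t = x) up to o(h); the diagonal is 1 - J_t(x) h."""
    n = len(Q)
    return [[(1.0 if x == y else 0.0) + h * Q[x][y] for y in range(n)] for x in range(n)]
```

For instance, with rates 1.0 and 2.0 out of state 0 and \(h = 0.1\), the probability of staying in state 0 is \(1 - 3 \times 0.1 = 0.7\). The matrix \(I + hQ_t\) is a valid stochastic matrix only when \(h\,J_t(x) \le 1\) for every \(x\).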

Bayes’ rule implies that the time-reversal of this Markov chain is itself Markov, with infinitesimal rate matrix \(Q_t^{\star}\) satisfying, for \(x \ne y\), \[ Q_t^{\star}(x,y) = \frac{p_t(y)}{p_t(x)} \, Q_t(y,x), \] where \(p_t\) is the marginal distribution of \(X_t\) at time \(t\). In other words, for \(x \ne y\), \[ \mathbb{P}(X_{t-h} = y \mid X_t = x) = Q_t^{\star}(x,y) \, h + o(h). \]
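The reversal formula can be applied entrywise once the marginal \(p_t\) is known. The sketch below is a hypothetical helper on plain Python lists: it skips states with \(p_t(x) = 0\), which the chain never visits, and fills the diagonal so that every row of \(Q_t^{\star}\) sums to zero.

```python
from typing import List

Matrix = List[List[float]]


def reversal_rates(Q: Matrix, p: List[float]) -> Matrix:
    """Reverse-time rates Q*_t(x, y) = p_t(y) / p_t(x) * Q_t(y, x) for x != y.

    `Q` is the forward rate matrix at time t and `p` the marginal p_t.
    Rows of states with p_t(x) = 0 are left at zero (they are never visited).
    The diagonal is minus the row sum, so Q* is a valid rate matrix.
    """
    n = len(Q)
    Qs = [[0.0] * n for _ in range(n)]
    for x in range(n):
        if p[x] == 0.0:
            continue
        for y in range(n):
            if y != x:
                Qs[x][y] = p[y] / p[x] * Q[y][x]
        Qs[x][x] = -sum(Qs[x][y] for y in range(n) if y != x)
    return Qs
```

For example, take \(V = 2\) with states ordered \((M, 1, 2)\), a masking rate of 2 out of each token, and \(p_t = (0.5, 0.15, 0.35)\). The reversed chain leaves \(M\) at rate \(0.15/0.5 \times 2 = 0.6\) towards token 1 and \(0.35/0.5 \times 2 = 1.4\) towards token 2, and never leaves a token: the time-reversal of one-way masking only unmasks.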

We are interested in modeling a Markov chain that progressively masks the initial value, moving it to the masked state \(M\) as time \(t\) goes from \(0\) to \(T\). Transitions are only allowed from a token \(i \in \{1,\dots,V\}\) to the masked state \(M\), and once in \(M\) the process remains there. Thus, off the diagonal, the only nonzero entries of \(Q_t\) are \(Q_t(i,M)\) for \(i \in \{1,\dots,V\}\). We also take the masking rate \(Q_t(i,M) = \lambda_t\) to be the same for every token \(i\), so that the jump time does not depend on the initial value; the computations below rely on this. As it will be useful later, we denote by \(\tau\) the jump time to \(M\) and assume that \(\tau < T\) almost surely, so that \(X_T = M\) with probability one, and that \(\tau\) has a continuous distribution. In words: the process starts at some token value and at a random time \(\tau\) jumps to the masked state \(M\), where it remains until time \(T\).
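One concrete choice, used here purely as an assumption, is the linear schedule \(\mathbb{P}(\tau \le t) = t/T\), i.e. \(\tau\) uniform on \([0,T)\). It corresponds to the token-independent hazard \(\lambda_t = 1/(T-t)\) and satisfies \(\tau < T\) almost surely. The sketch below writes down \(Q_t\) for a token-independent hazard and samples \(\tau\) by inverse-transform sampling; the function names and the encoding of \(M\) as index 0 are hypothetical.

```python
import random
from typing import Callable, List

MASK = 0  # index of the masked state M; tokens are 1..V


def masking_rate_matrix(V: int, lam: float) -> List[List[float]]:
    """Rate matrix Q_t of the one-way masking chain at a time where the hazard is lam.

    Assumes Q_t(i, M) = lam for every token i; row M is all zeros (M is absorbing).
    """
    Q = [[0.0] * (V + 1) for _ in range(V + 1)]
    for i in range(1, V + 1):
        Q[i][MASK] = lam   # the only allowed move: token i -> M
        Q[i][i] = -lam     # -J_t(i)
    return Q


def sample_jump_time(beta_inv: Callable[[float], float], rng: random.Random) -> float:
    """Inverse-transform sampling of tau, where beta(t) = P(tau <= t) and beta_inv is its inverse."""
    return beta_inv(rng.random())


def state_at(t: float, x0: int, tau: float) -> int:
    """Single-token chain at time t: the initial token before tau, M from tau on."""
    return MASK if tau <= t else x0
```

With the linear schedule one would pass `beta_inv = lambda u: T * u`; the fraction of draws that are masked at time \(t\) then approaches \(t/T\).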

Extension to Sequences

We are interested in modeling sequences of \(L\) discrete tokens, e.g. binary images, genomic sequences, chemical compounds or protein sequences. Each token takes values in \(\{1,2,\ldots,V\}\). We denote by \(\overline{p}_0\) the data distribution over such sequences. For this purpose, we draw an initial sequence from \(\overline{p}_0\) and let each coordinate evolve independently according to the above Markov chain: \[ \overline{X}_t = (X_t^1, \ldots, X_t^L), \] each coordinate with rate matrix \(Q_t\) as defined previously. At time \(T\), the process reaches the fully masked sequence \(\overline{X}_T = (M, \ldots, M)\) with probability one. Denote by \(\tau_i\) the jump time of coordinate \(i\). Since the jump times \(\tau_i\) are almost surely distinct, the off-diagonal entries of the infinitesimal rate matrix \(\overline{Q}_t\) of the joint process are nonzero only when a single coordinate changes. If \(x,\widehat{x} \in \mathcal{X}^L\) differ in a single coordinate \(i\), we have \[ \overline{Q}_t(x,\widehat{x}) = Q_t(x^i, \widehat{x}^i). \]
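A sketch of the forward noising of a whole sequence, assuming the same schedule for every coordinate and encoding \(M\) as 0 (the names are hypothetical): each coordinate gets its own jump time, and sorting the jump times gives the order in which coordinates are masked, one at a time.

```python
import random
from typing import Callable, List, Sequence

MASK = 0  # index of the masked state M; tokens are 1..V


def forward_taus(L: int, beta_inv: Callable[[float], float], rng: random.Random) -> List[float]:
    """One independent jump time per coordinate, all with CDF beta (inverse beta_inv)."""
    return [beta_inv(rng.random()) for _ in range(L)]


def sequence_at(t: float, x0: Sequence[int], taus: Sequence[float]) -> List[int]:
    """Joint state at time t: coordinate i is masked iff tau_i <= t."""
    return [MASK if tau <= t else tok for tok, tau in zip(x0, taus)]
```

Because the \(\tau_i\) are i.i.d., each \(X_t^i\) equals \(M\) independently with probability \(\mathbb{P}(\tau \le t)\), so a sequence can also be noised at a single time \(t\) without drawing the jump times explicitly.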

As before, the time-reversal has infinitesimal rate matrix \(\overline{Q}_t^{\star}\) satisfying \[ \overline{Q}_t^{\star}(\widehat{x}, x) = \frac{\overline{p}_t(x)}{\overline{p}_t(\widehat{x})} \, Q_t(x^i, \widehat{x}^i), \]

where \(\overline{p}_t\) is the marginal distribution of \(\overline{X}_t\). Since \(x\) and \(\widehat{x}\) differ at coordinate \(i\) only, for \(\overline{Q}_t^{\star}(\widehat{x}, x)\) to be nonzero, necessarily \(\widehat{x}^i = M\) and \(x^i \in \{1,\ldots,V\}\). Let \(S = \{j : \widehat{x}^j \neq M\}\) be the set of unmasked coordinates in \(\widehat{x}\). Since the jump times are independent of each other and of the initial tokens, observing configuration \(\widehat{x}\) at time \(t\) requires the \(L - |S|\) masked coordinates to have \(\tau_j < t\) and the \(|S|\) unmasked ones to have \(\tau_j \ge t\): \[ \overline{p}_t(\widehat{x}) = \mathbb{P}(\tau < t)^{L-|S|} \, \mathbb{P}(\tau \ge t)^{|S|} \, \overline{p}_0(\widehat{x}^{S}), \] where \(\overline{p}_0(\widehat{x}^{S})\) denotes the marginal probability under \(\overline{p}_0\) of the tokens \(\widehat{x}^{S}\) on the coordinates in \(S\).

Similarly, since \(x\) differs from \(\widehat{x}\) only at coordinate \(i\), so that \(x^{S} = \widehat{x}^{S}\): \[ \begin{align*} \overline{p}_t(x) &= \mathbb{P}(\tau < t)^{L-|S|-1} \, \mathbb{P}(\tau \ge t)^{|S|+1} \, \overline{p}_0(x^{S \cup \{i\}})\\ &= \mathbb{P}(\tau < t)^{L-|S|-1} \, \mathbb{P}(\tau \ge t)^{|S|+1} \, \overline{p}_0(\widehat{x}^{S})\, \overline{p}_0(x^i \mid \widehat{x}^{S}). \end{align*} \]
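These two factorizations can be sanity-checked by brute force on a toy example. The sketch below assumes a hypothetical joint distribution over \(L = 2\) tokens with \(V = 2\), encodes \(M\) as 0, and treats \(\beta_t = \mathbb{P}(\tau < t)\) as a free number; it relies on the jump times being independent of the token values, and enumerates \(\overline{p}_t\) directly from its definition.

```python
import itertools
from typing import Dict, Tuple

MASK = 0  # index of the masked state M; tokens are 1..V
Seq = Tuple[int, ...]


def marginal_pt(p0: Dict[Seq, float], beta_t: float, V: int, L: int) -> Dict[Seq, float]:
    """Brute-force marginal of the sequence at time t.

    Each coordinate is masked independently with probability beta_t and
    otherwise keeps its initial token, whatever that token is.
    """
    pt: Dict[Seq, float] = {}
    for y in itertools.product(range(V + 1), repeat=L):  # 0 = M, 1..V = tokens
        total = 0.0
        for x0, w in p0.items():
            prob = w
            for yj, xj in zip(y, x0):
                if yj == MASK:
                    prob *= beta_t
                elif yj == xj:
                    prob *= 1.0 - beta_t
                else:
                    prob = 0.0  # y shows a token different from x0: impossible
                    break
            total += prob
        pt[y] = total
    return pt


def p0_marginal(p0: Dict[Seq, float], y: Seq) -> float:
    """Data marginal p0(y^S) of the tokens on the unmasked coordinates S of y."""
    return sum(w for x0, w in p0.items()
               if all(yj == MASK or yj == xj for yj, xj in zip(y, x0)))
```

With the toy table \(\{(1,1): 0.1,\ (1,2): 0.4,\ (2,1): 0.3,\ (2,2): 0.2\}\) and \(\beta_t = 0.25\), one gets \(\overline{p}_t(1,2) = 0.75^2 \times 0.4 = 0.225\) and \(\overline{p}_t(M,2) = 0.25 \times 0.75 \times 0.6 = 0.1125\). Their ratio, 2, equals \(\frac{0.75}{0.25} \times \overline{p}_0(x^1 = 1 \mid x^2 = 2) = 3 \times \frac{0.4}{0.6}\), as the two displayed formulas predict.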

This shows that the time-reversal rate matrix becomes \[ \overline{Q}_t^{\star}(\widehat{x}, x) = R(t) \, \overline{p}_0(x^i \mid \widehat{x}^{S})\, Q_t(x^i, M), \tag{1}\]

with time-dependent scalar \(R(t) = \frac{\mathbb{P}(\tau \ge t)}{\mathbb{P}(\tau < t)}\). To simulate the reverse process, which progressively unmasks a fully masked sequence, one only needs to model the conditional distributions \(\overline{p}_0(x^i \mid \widehat{x}^{S})\) of the data distribution \(\overline{p}_0\). This is precisely the prediction task of masked language models such as BERT, which estimate token probabilities conditioned on the visible context.
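Writing \(\beta_t = \mathbb{P}(\tau \le t)\) and assuming a token-independent masking hazard \(\lambda_t = \dot{\beta_t}/(1-\beta_t)\), one has \(R(t)\,\lambda_t = \dot{\beta_t}/\beta_t\): each masked coordinate is unmasked at total rate \(\dot{\beta_t}/\beta_t\), so a coordinate masked at time \(t\) is still masked at an earlier time \(s\) with probability \(\beta_s/\beta_t\). The sketch below uses this on a decreasing time grid, with a hypothetical exact conditional computed from a small probability table as a stand-in for a learned model. Filling newly unmasked coordinates one at a time is a design choice that keeps the sampler exact when the conditionals are exact; many practical samplers fill them in parallel to save model calls.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

MASK = 0  # index of the masked state M; tokens are 1..V
Seq = Tuple[int, ...]


def exact_conditional(p0: Dict[Seq, float], x: Sequence[int], i: int, V: int) -> List[float]:
    """p0(x^i = v | x^S) for v = 1..V, where S = unmasked coordinates of x.

    Toy stand-in for a learned model: sums the table p0 over all sequences
    that agree with x on its unmasked coordinates.
    """
    weights = [0.0] * V
    for seq, w in p0.items():
        if all(xj == MASK or xj == sj for xj, sj in zip(x, seq)):
            weights[seq[i] - 1] += w
    z = sum(weights)
    return [w / z for w in weights]


def reverse_sample(
    p0: Dict[Seq, float],
    V: int,
    L: int,
    beta: Callable[[float], float],
    times: Sequence[float],
    rng: random.Random,
) -> List[int]:
    """Unmask (M, ..., M) along a decreasing grid times[0] = T > ... > times[-1] = 0.

    A coordinate masked at t is still masked at s < t with probability
    beta(s) / beta(t). Newly revealed coordinates are filled one at a time,
    each conditioned on everything revealed so far (any order works when
    the conditionals are exact).
    """
    x = [MASK] * L
    for t, s in zip(times[:-1], times[1:]):
        keep = beta(s) / beta(t)
        newly = [i for i in range(L) if x[i] == MASK and rng.random() >= keep]
        for i in newly:
            probs = exact_conditional(p0, x, i, V)
            x[i] = rng.choices(range(1, V + 1), weights=probs)[0]
    return x
```

With the linear schedule `beta = lambda t: t` on \([0,1]\) and a grid such as `[1.0 - k / 10 for k in range(11)]`, the last step has `keep = 0`, so every coordinate is unmasked at time 0.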

Training

To train the denoising model, Equation 1 shows that it is natural to parametrize the conditional distribution

\[f_\theta(x^i \mid \widehat{x}^{S}) \approx \overline{p}_0(x^i \mid \widehat{x}^{S})\]

for all sets \(S \subset \{1,\ldots,L\}\) with \(i \notin S\). Once this is done, the rate matrix of the learned time-reversal process is defined as:

\[ \overline{Q}_{t,\theta}^{\star}(\widehat{x}, x) = R(t) \, f_\theta(x^i \mid \widehat{x}^{S}) \, Q_t(x^i, M). \]
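Given a model, these learned rates can be tabulated for every single-coordinate move out of \(\widehat{x}\). The sketch below assumes a token-independent hazard \(\lambda_t = \dot{\beta_t}/(1-\beta_t)\), writes \(R(t) = (1-\beta_t)/\beta_t\) with \(\beta_t = \mathbb{P}(\tau \le t)\), and uses a hypothetical callable `model(x_hat, i)` in place of \(f_\theta\).

```python
from typing import Callable, Dict, List, Sequence, Tuple

MASK = 0  # index of the masked state M; tokens are 1..V


def learned_reverse_rates(
    x_hat: Sequence[int],
    t: float,
    beta: Callable[[float], float],
    beta_dot: Callable[[float], float],
    model: Callable[[Sequence[int], int], List[float]],
) -> Dict[Tuple[int, int], float]:
    """Rates Q*_{t,theta}(x_hat, x) for every single-coordinate unmasking move.

    Key (i, v) means "set masked coordinate i to token v". Valid for
    0 < beta(t) < 1. `model(x_hat, i)` is a hypothetical stand-in for
    f_theta(. | x_hat^S) and must return probabilities over tokens 1..V.
    """
    b, bd = beta(t), beta_dot(t)
    R = (1.0 - b) / b       # P(tau >= t) / P(tau < t)
    lam = bd / (1.0 - b)    # Q_t(v, M), assumed the same for every token v
    rates: Dict[Tuple[int, int], float] = {}
    for i, xi in enumerate(x_hat):
        if xi != MASK:
            continue        # unmasked coordinates never change in reverse time
        for v, prob in enumerate(model(x_hat, i), start=1):
            rates[(i, v)] = R * prob * lam
    return rates
```

Summing the returned rates gives \((L-|S|)\,\dot{\beta_t}/\beta_t\) whatever probabilities the model outputs.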

If one denotes by \(\mathbb{P}\) the law of the forward noising process started from \(\overline{p}_0\), and by \(\mathbb{P}_{\theta}\) the law of the time-reversal process started from the fully masked sequence \((M,M, \ldots, M)\) at time \(T\) and with learned denoising model \(f_\theta\), one can train the model by minimizing

\[ D_{\text{KL}} {\left( \mathbb{P}\; || \; \mathbb{P}_{\theta} \right)} = \mathbb{E}_{\mathbb{P}} {\left( \log \frac{\mathbb{P}}{\mathbb{P}_{\theta}} \right)} . \]

Consider a trajectory \(x_{[0,T]}\) of the forward noising process. The jump of the \(i\)-th coordinate at time \(\tau_i\) is denoted by \(\Delta_i := (x_{\tau_i^-}^i, x_{\tau_i^+}^i) = (x_0^i, M)\), and, to simplify notation, the corresponding reverse jump is denoted by \(\Delta_i^{\star}\). A standard computation for continuous-time Markov chains gives the log-likelihood ratio between the two processes:

\[ \log \frac{\mathbb{P}}{\mathbb{P}_{\theta}}(x_{[0,T]}) = \log \overline{p}_0(x_0) + \sum_i \log \frac{\overline{Q}_{\tau_i}(\Delta_i)}{\overline{Q}^{\star}_{\tau_i, \theta}(\Delta_i^{\star})} - \int_0^T {\left\{ \overline{J}_t(x_t) - \overline{J}^{\star}_{t,\theta}(x_t) \right\}} \, dt, \]

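where \(\overline{J}_t\) and \(\overline{J}^{\star}_{t,\theta}\) are the total jump rates of the forward and reverse processes respectively. Since \(f_\theta(\cdot \mid x_t^{S})\) sums to one and the masking rate \(Q_t(v,M)\) does not depend on the token \(v\), each masked coordinate of \(x_t\) contributes \(R(t)\,Q_t(v,M)\) to \(\overline{J}^{\star}_{t,\theta}(x_t)\), which therefore does not depend on \(\theta\). In each reverse jump rate \(\overline{Q}^{\star}_{\tau_i,\theta}(\Delta_i^{\star}) = R(\tau_i)\, f_\theta(x_0^i \mid x_{\tau_i}^{S_{\tau_i}})\, Q_{\tau_i}(x_0^i, M)\), only the factor \(f_\theta\) depends on \(\theta\). Minimizing the KL divergence is thus equivalent to minimizing: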

\[ -\mathbb{E}_{\mathbb{P}} {\left( \sum_{i=1}^L \log f_\theta(X_0^i \mid X_{\tau_i}^{S_{{\tau_i}}}) \right)} \]

where \(S_t\) is the set of unmasked coordinates at time \(t\). It is more convenient to rewrite this quantity as an integral over time so that one can sample a time \(t\) uniformly in \([0,T]\) during training. With the Dirac delta function \(\delta[\tau_i = t]\), we can rewrite this expectation as:

\[ \begin{align*} -\mathbb{E}_{\mathbb{P}} {\left( \sum_{i=1}^L \log f_\theta(X_0^i \mid X_{\tau_i}^{S_{\tau_i}}) \right)} &= -\mathbb{E}_{\mathbb{P}} \int_{t=0}^{T} \sum_{i=1}^L \delta[\tau_i = t] \, \log f_\theta(X_0^i \mid X_{t}^{S_{t}}) \, dt\\ &= -\int_{t=0}^{T} \mathbb{E}_{\mathbb{P}} {\left( \frac{\dot{\beta_t}}{\beta_t} \sum_{i: \, X_t^i=M} \log f_\theta(X_0^i \mid X_{t}^{S_{t}}) \right)} \, dt \end{align*} \]

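where \(\beta_t = \mathbb{P}(\tau \le t)\), so that \(\mathbb{P}(\tau \in dt \mid \tau \le t) = \frac{\dot{\beta_t}}{\beta_t} \, dt\). The second equality holds because coordinate \(i\) can only jump at time \(t\) if it is masked at \(t\), and, given that it is masked at \(t\), its jump time has density \(\dot{\beta_t}/\beta_t\) at \(t\); since the other coordinates are independent of \(\tau_i\), the distribution of \((X_0, X_t^{S_t})\) is the same given \(\tau_i = t\) as given \(X_t^i = M\). For training, it suffices to sample \(t\) uniformly in \([0,T]\), draw \(X_0 \sim \overline{p}_0\), sample the noised configuration \(X_t\) according to the forward process, and evaluate the integrand multiplied by \(T\): this gives an unbiased estimate of the loss. The weight \(\dot{\beta_t}/\beta_t\) is large for small \(t\), counterbalancing the fact that reconstruction is much easier when only a few tokens are masked. Standard denoising diffusion models have a similar “signal-to-noise” weighting term that balances the easy and hard denoising tasks.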
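The recipe above can be sketched as a single-sample estimator, assuming each coordinate of \(X_t\) is masked independently with probability \(\beta_t\) (a consequence of the i.i.d. jump times) and using a hypothetical callable `model(x_t, i)` for \(f_\theta(\cdot \mid X_t^{S_t})\). The factor \(T\) compensates for the density \(1/T\) of the uniformly sampled time.

```python
import math
import random
from typing import Callable, List, Sequence

MASK = 0  # index of the masked state M; tokens are 1..V


def masked_diffusion_loss(
    x0: Sequence[int],
    t: float,
    T: float,
    beta: Callable[[float], float],
    beta_dot: Callable[[float], float],
    model: Callable[[Sequence[int], int], List[float]],
    rng: random.Random,
) -> float:
    """Single-sample estimate of -E sum_i log f_theta(X_0^i | X_{tau_i}^{S_{tau_i}}).

    The caller draws t ~ Uniform(0, T] and x0 ~ p0. Each coordinate is masked
    independently with probability beta(t); only masked coordinates are scored.
    `model` is a hypothetical stand-in for f_theta, returning probabilities over 1..V.
    """
    xt = [MASK if rng.random() < beta(t) else tok for tok in x0]
    weight = T * beta_dot(t) / beta(t)   # T * (beta_dot / beta)
    nll = 0.0
    for i, tok in enumerate(xt):
        if tok == MASK:
            nll -= math.log(model(xt, i)[x0[i] - 1])
    return weight * nll
```

In a training loop one would draw `t = T * (1 - rng.random())`, which lies in \((0, T]\) and keeps \(\beta_t > 0\), draw \(X_0\) from the data, and average the returned values over a batch. With the linear schedule \(\beta_t = t/T\), the weight \(T\,\dot{\beta_t}/\beta_t\) equals \(T/t\).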

Conclusion

Discrete diffusion models with one-way masking are mathematically almost identical to masked language models, so their similar behavior and performance on text generation tasks are no coincidence. The ideas summarized in these notes were developed in a very interesting stream of papers, including (Ou et al. 2024), (Sahoo et al. 2024), (Shi et al. 2024) and a number of more recent works. One potential drawback of such masked discrete diffusion models is that the support of the noising distribution is typically strictly smaller, and often much smaller, than the whole state space. When the denoising model is imperfect and wanders outside this support, it can quickly end up in regions never seen during training, which can lead to poor sample quality and unstable behavior. Other discrete diffusion models, such as those whose noising process reaches the uniform distribution over all tokens at time \(T\), are less affected by this issue, although they face other important computational and modeling challenges. Exciting research directions remain to be explored in this area!

References

Ou, Jingyang, Shen Nie, Kaiwen Xue, Fengqi Zhu, Jiacheng Sun, Zhenguo Li, and Chongxuan Li. 2024. “Your Absorbing Discrete Diffusion Secretly Models the Conditional Distributions of Clean Data.” arXiv Preprint arXiv:2406.03736.
Sahoo, Subham, Marianne Arriola, Yair Schiff, Aaron Gokaslan, Edgar Marroquin, Justin Chiu, Alexander Rush, and Volodymyr Kuleshov. 2024. “Simple and Effective Masked Diffusion Language Models.” Advances in Neural Information Processing Systems 37: 130136–84.
Shi, Jiaxin, Kehang Han, Zhe Wang, Arnaud Doucet, and Michalis Titsias. 2024. “Simplified and Generalized Masked Diffusion for Discrete Data.” Advances in Neural Information Processing Systems 37: 103131–67.