From Denoising Diffusion to ODEs

Categories: DDPM, score

Published: 02 07 2023

Setting & Goals

Consider an empirical data distribution \(\pi_{\mathrm{data}}\). In order to simulate approximate samples from \(\pi_{\mathrm{data}}\), Denoising Diffusion Probabilistic Models (DDPM) simulate a forward diffusion process \(\{X_t\}_{[0,T]}\) on an interval \([0,T]\). The diffusion is initialized at the data distribution, i.e. \(X_0 \sim \pi_{\mathrm{data}}\), and is chosen so that the distribution of \(X_T\) is very close to a known and tractable reference distribution \(\pi_{\mathrm{ref}}\), e.g. a Gaussian distribution. Denote by \(p_t(dx)\) the marginal distribution at time \(0 \leq t \leq T\), i.e. \(\mathop{\mathrm{\mathbb{P}}}(X_t \in dx) = p_t(dx)\). By choosing a forward diffusion with simple and tractable transition probabilities, e.g. an Ornstein-Uhlenbeck process, it is relatively easy to estimate \(\nabla \log p_t(x)\) from simulated data: this can be formulated as a simple regression problem. This allows one to simulate the diffusion backward in time and generate approximate samples from \(\pi_{\mathrm{data}}\). Why this is useful is another question…
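To make the regression formulation concrete, here is a minimal sketch of denoising score matching. It assumes the Ornstein-Uhlenbeck forward process \(dX = -\tfrac12 X \, dt + dB\) used later in this post, whose conditional law \(X_t \mid X_0 \sim \mathcal{N}(e^{-t/2} X_0, \, 1-e^{-t})\) provides closed-form regression targets; the function `score_model` is a hypothetical stand-in for whatever approximator (typically a neural network) one fits by minimizing this loss.

```python
import numpy as np

rng = np.random.default_rng(0)

def dsm_targets(x0, t):
    """Sample x_t | x_0 for the OU process dX = -X/2 dt + dB and return the
    denoising score matching regression target grad_{x_t} log p(x_t | x_0)."""
    mean = np.exp(-t / 2.0) * x0
    var = 1.0 - np.exp(-t)
    xt = mean + np.sqrt(var) * rng.standard_normal(x0.shape)
    return xt, -(xt - mean) / var

def dsm_loss(score_model, x0_batch, T=5.0):
    """Monte-Carlo estimate of the denoising score matching objective."""
    t = rng.uniform(0.05, T, size=x0_batch.shape)   # avoid t ~ 0 where the variance vanishes
    xt, target = dsm_targets(x0_batch, t)
    return np.mean((score_model(t, xt) - target) ** 2)

# Usage with a toy bimodal data batch and a crude candidate score (that of a standard Gaussian).
x0 = rng.choice([-2.0, 2.0], size=1000) + 0.3 * rng.standard_normal(1000)
print(dsm_loss(lambda t, x: -x, x0))
```

Minimizing this loss over a family of functions \(s_\theta(t,x)\) recovers (an approximation of) \(\nabla \log p_t(x)\).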

The fact that the mapping from data samples at time \(t=0\) to (approximate) Gaussian samples at time \(t=T\) is stochastic and described by a diffusion process is cumbersome. It would be much more convenient to build a deterministic mapping between the data distribution \(\pi_{\mathrm{data}}\) and the Gaussian reference distribution \(\pi_{\mathrm{ref}}\): this would allow one to associate a likelihood to data samples and to easily “encode”/“decode” data samples. To do this, one can try to replace diffusions by Ordinary Differential Equations (ODEs).

The diffusion-ODE trick

Consider an arbitrary diffusion process \(dX = \mu(t,X) \, dt + dB\) with associated distribution \(p_t(dx)\) at time \(t\). The Fokker-Planck equation that describes the evolution of \(p_t\) reads

\[ \partial_t p_t = -\mathop{\mathrm{div}}( p_t \, \mu) + \frac12 \, \Delta p_t \]

If there were no diffusion term \(dB\) and \(X\) were instead describing the evolution of a differential equation \(dX/dt = F(t,X)\), the associated evolution of the density of \(X\) would simply read

\[ \partial_t p_t = -\mathop{\mathrm{div}}( p_t \, F). \tag{1}\]

If one can find a vector field \(F(t,x)\) such that

\[ -\mathop{\mathrm{div}}( p_t \, \mu) + \frac12 \, \Delta p_t = -\mathop{\mathrm{div}}( p_t \, F), \]

then one can basically replace diffusions by ODEs. The diffusion-ODE trick is the simple remark that

\[ F(t,x) = \mu(t,x) \; {\color{red} - \frac12 \, \nabla_x \log p_t(x)} \]

does exactly this, as simple algebra immediately shows. The additional term \({\color{red} - \frac12 \, \nabla_x \log p_t(x)}\) is intuitive. The coefficient \({\color{blue} 1/2}\) is there because one is trying to match the term \({\color{blue} (1/2)} \, \Delta p_t\) in the Fokker-Planck equation. And the overall term \({\color{red} - \frac12 \, \nabla_x \log p_t(x)}\) drives the ODE in the direction where the probability density \(p_t\) is small, i.e. it follows the negative gradient of the log-density: it is exactly trying to imitate the diffusion term \(dB\).
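For completeness, the algebra in question is just the identity \(\Delta p_t = \mathop{\mathrm{div}}(\nabla p_t) = \mathop{\mathrm{div}}(p_t \, \nabla \log p_t)\), so that

\[ -\mathop{\mathrm{div}}( p_t \, \mu) + \frac12 \, \Delta p_t \;=\; -\mathop{\mathrm{div}}\Big( p_t \, \Big[ \mu - \frac12 \, \nabla \log p_t \Big] \Big) \;=\; -\mathop{\mathrm{div}}( p_t \, F). \]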

What this means is that a diffusion process \(dX = \mu(t,X) \, dt + dB\) started from \(X_0 \sim p_0\) and with marginal distribution \(X_t \sim p_t(dx)\) can be imitated by the ODE \(dY/dt = \mu(t,Y) - \frac12 \, \nabla \log p_t(Y)\) started from \(Y_0 \sim p_0\). At any time \(t>0\), the marginal distributions of \(X_t\) and \(Y_t\) both exactly equal \(p_t(dx)\).
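As a quick numerical sanity check, consider a one-dimensional Gaussian example where everything is available in closed form: \(\mu(t,x) = -\tfrac12 x\) and \(p_0 = \mathcal{N}(0, \sigma_0^2)\), so that \(p_t = \mathcal{N}(0, v_t)\) with \(v_t = \sigma_0^2 e^{-t} + 1 - e^{-t}\) and \(\nabla \log p_t(x) = -x / v_t\). The following sketch (all constants are illustrative choices) simulates both the diffusion and the ODE and compares their marginal variances.

```python
import numpy as np

rng = np.random.default_rng(1)

sigma0_sq = 4.0
v = lambda t: sigma0_sq * np.exp(-t) + 1.0 - np.exp(-t)   # variance of p_t
score = lambda t, x: -x / v(t)                            # exact score of p_t = N(0, v_t)

n, T, n_steps = 50_000, 3.0, 600
dt = T / n_steps
x_sde = np.sqrt(sigma0_sq) * rng.standard_normal(n)   # particles following the diffusion
y_ode = np.sqrt(sigma0_sq) * rng.standard_normal(n)   # particles following the ODE

t = 0.0
for _ in range(n_steps):
    x_sde += -0.5 * x_sde * dt + np.sqrt(dt) * rng.standard_normal(n)   # Euler-Maruyama
    y_ode += (-0.5 * y_ode - 0.5 * score(t, y_ode)) * dt                # Euler for the ODE
    t += dt

print(f"SDE var ≈ {x_sde.var():.3f}, ODE var ≈ {y_ode.var():.3f}, exact v_T = {v(T):.3f}")
```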

The diffusion-ODE trick: application to DDPM

Consider a DDPM with forward dynamics given by an Ornstein-Uhlenbeck (OU) process

\[ d\overrightarrow{X} = -\frac12 \, \overrightarrow{X} \, dt + dB \]

and initial condition \(X_0 \sim p_0 \equiv \pi_{\mathrm{data}}\). As explained in these notes, it is relatively straightforward to estimate the score function

\[ \mathcal{S}(t,x_t) \; = \; \nabla \log p_t(x_t) \]

from data. This means that the forward OU process can be replaced by the forward ODE

\[ \frac{d}{dt} \overrightarrow{Y_t} = -\frac12 \, \overrightarrow{Y_t} - \frac12 \mathcal{S}(t,\overrightarrow{Y_t}) = F(t,\overrightarrow{Y_t}) \]

with \(F(t,x) = -\tfrac{1}{2} x -\tfrac{1}{2} \mathcal{S}(t,x)\). Similarly, the reverse diffusion (i.e. the “denoising” diffusion) defined as \(\overleftarrow{X}_s = X_{T-s}\) follows the dynamics

\[d\overleftarrow{X}_s = \Big[ \frac12 \overleftarrow{X}_s + \mathcal{S}(T-s,\overleftarrow{X}_s) \Big] \, ds + dW.\]
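Simulating this denoising diffusion (e.g. by Euler-Maruyama) is essentially what DDPM sampling amounts to. Here is a minimal sketch in the one-dimensional Gaussian toy case introduced above, where the score is available exactly; in practice \(\mathcal{S}\) would be the learned regression estimate.

```python
import numpy as np

rng = np.random.default_rng(2)

sigma0_sq = 4.0                                           # pi_data = N(0, sigma0^2) (toy case)
v = lambda t: sigma0_sq * np.exp(-t) + 1.0 - np.exp(-t)   # variance of p_t
score = lambda t, x: -x / v(t)                            # stand-in for the learned score

T, n_steps = 5.0, 1000
ds = T / n_steps
x = np.sqrt(v(T)) * rng.standard_normal(100_000)   # start from (approximately) pi_ref

# Euler-Maruyama simulation of the denoising diffusion.
for k in range(n_steps):
    s = k * ds
    x += (0.5 * x + score(T - s, x)) * ds + np.sqrt(ds) * rng.standard_normal(x.shape)

print(f"variance of denoised samples ≈ {x.var():.2f}  (data variance: {sigma0_sq})")
```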

As described for the first time in the beautiful article (Song et al. 2020), the diffusion-ODE trick now shows that the denoising diffusion can be replaced by a denoising ODE with dynamics

\[ \begin{align} \frac{d}{ds} \overleftarrow{Y_s} &= \frac12 \overleftarrow{Y_s} + \mathcal{S}(T-s,\overleftarrow{Y_s}) - \frac12 \, \mathcal{S}(T-s,\overleftarrow{Y_s})\\ &= \frac12 \overleftarrow{Y_s} + \frac12 \, \mathcal{S}(T-s,\overleftarrow{Y_s}) = -F(T-s, \overleftarrow{Y_s}). \end{align} \]

Interestingly [and I do not know whether there was an obvious way of seeing this from the start], this shows that the forward and backward ODEs are actually the same ODE, simply run forward and backward in time. They correspond to the ODE described by the vector field

\[ F(t,x) \; = \; -\frac{1}{2} x -\frac{1}{2} \mathcal{S}(t,x). \tag{2}\]

The animation below displays the denoising ODE and the associated vector field of Equation 2.
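In the same spirit, here is a minimal sketch of a denoising-ODE sampler together with the encode/decode round-trip that the deterministic formulation makes possible. It reuses the Gaussian toy case from the previous sketches, where the score is known in closed form; `score` is again a stand-in for a learned network.

```python
import numpy as np

rng = np.random.default_rng(3)

sigma0_sq = 4.0
v = lambda t: sigma0_sq * np.exp(-t) + 1.0 - np.exp(-t)
score = lambda t, x: -x / v(t)                 # stand-in for a learned score network

def F(t, x):
    """Vector field of Equation 2."""
    return -0.5 * x - 0.5 * score(t, x)

def integrate(x, forward, T=5.0, n_steps=1000):
    """Euler integration of the forward ODE dY/dt = F(t, Y) (forward=True)
    or of the denoising ODE dY/ds = -F(T - s, Y) (forward=False)."""
    dt = T / n_steps
    for k in range(n_steps):
        t = k * dt if forward else T - k * dt
        x = x + (1.0 if forward else -1.0) * F(t, x) * dt
    return x

# Sampling: push (approximate) reference samples through the denoising ODE.
samples = integrate(np.sqrt(v(5.0)) * rng.standard_normal(10_000), forward=False)
print(f"variance of generated samples ≈ {samples.var():.2f}  (data variance: {sigma0_sq})")

# Encode/decode round-trip: the forward and backward ODEs are inverses, up to discretization error.
x = np.array([1.7, -0.3, 2.5])
print(np.abs(integrate(integrate(x, forward=True), forward=False) - x).max())
```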

Likelihood computation

With the diffusion-ODE trick, we have just seen that it is possible to build a vector field \(F: [0,T] \times \mathbb{R}^d \to \mathbb{R}^d\) such that the forward ODE

\[ \frac{d}{dt} \overrightarrow{Y_t} = F(t,\overrightarrow{Y_t}) \qquad \textrm{initialized at} \qquad \overrightarrow{Y}_0 \sim \pi_{\mathrm{data}} \tag{3}\]

and the backward ODE defined as

\[ \frac{d}{ds} \overleftarrow{Y_s} = -F(T-s,\overleftarrow{Y_s}) \qquad \textrm{initialized at} \qquad \overleftarrow{Y}_0 \sim \pi_{\mathrm{ref}} \]

are such that \(\overrightarrow{Y}_T\) is approximately distributed according to \(\pi_{\mathrm{ref}}\) and \(\overleftarrow{Y}_T\) is approximately distributed according to \(\pi_{\mathrm{data}}\).

In general, consider a vector field \(F(t,x)\) and a bunch of particles distributed according to a distribution \(p_t\) at time \(t\). If each particle follows the vector field for an amount of time \(\delta \ll 1\), the particles that were in the vicinity of some \(x \in \mathbb{R}^d\) at time \(t\) end up in the vicinity of \(x + F(t,x) \, \delta\) at time \(t+\delta\). At the same time, a volume element \(dx\) around \(x \in \mathbb{R}^d\) gets stretched by a factor \(1+\delta \, \mathop{\mathrm{Tr}}[\mathop{\mathrm{\mathrm{Jac}}}F(t,x)] = 1 + \delta \mathop{\mathrm{div}}F(t,x)\) while following the vector field \(F\), which means that the density of particles at time \(t+\delta\) around \(x + F(t,x) \, \delta\) equals \(p_t(x) / [1 + \delta \mathop{\mathrm{div}}F(t,x)]\). In other words, \(\log p_{t+\delta}(x + F(t,x) \, \delta) \approx \log p_t(x) - \delta \, \mathop{\mathrm{div}}F(t,x)\). This means that if one follows a trajectory of \(\tfrac{d}{dt} X_t = F(t,X_t)\), one gets

\[ \log p_T(X_T) = \log p_0(X_0) - \int_{t=0}^{T} \mathop{\mathrm{div}}F(t, X_t) \, dt. \]

That is the Lagrangian description of the density \(p_t\) of particles. Indeed, one could directly obtain this identity by differentiating \(p_t(X_t)\) with respect to time and using the continuity Equation 1: since \(\frac{d}{dt} p_t(X_t) = \partial_t p_t(X_t) + \nabla p_t(X_t) \cdot F(t,X_t) = -\mathop{\mathrm{div}}(p_t F)(X_t) + \nabla p_t(X_t) \cdot F(t,X_t) = - p_t(X_t) \, \mathop{\mathrm{div}}F(t,X_t)\), dividing by \(p_t(X_t)\) and integrating in time gives the formula above. When applied to the DDPM, this gives a way to assign a likelihood to data samples, namely

\[ \log \pi_{\mathrm{data}}(x) = \log \pi_{\mathrm{ref}}(\overrightarrow{Y_T}) + \int_{t=0}^{T} \mathop{\mathrm{div}}F(t, \overrightarrow{Y_t})\, dt \]

where \(\overrightarrow{Y_t}\) is the trajectory of the forward ODE Equation 3 initialized at \(\overrightarrow{Y_0} = x\). Note that in high-dimensional settings, it may be computationally expensive to compute the divergence term \(\mathop{\mathrm{div}}F(t, \overrightarrow{Y_t})\) exactly, since it typically is \(d\) times slower than a gradient computation; for this reason, it is often advocated to use the Hutchinson trace estimator to obtain an unbiased estimate of it at a much lower computational cost.
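Here is a minimal sketch of this likelihood computation in the same one-dimensional Gaussian toy case, where the exact log-density is available for comparison. In one dimension the divergence is just \(\partial_x F\) and can be written down explicitly; in high dimension one would instead replace that line by the Hutchinson estimator \(\mathop{\mathrm{div}}F(t,x) \approx \varepsilon^\top \mathop{\mathrm{Jac}}_x F(t,x)\, \varepsilon\) with \(\varepsilon \sim \mathcal{N}(0, I_d)\), which only requires a Jacobian-vector product.

```python
import numpy as np

sigma0_sq = 4.0
v = lambda t: sigma0_sq * np.exp(-t) + 1.0 - np.exp(-t)
score = lambda t, x: -x / v(t)                     # stand-in for a learned score
F = lambda t, x: -0.5 * x - 0.5 * score(t, x)      # vector field of Equation 2
div_F = lambda t, x: -0.5 + 0.5 / v(t)             # exact divergence in this 1-D toy case

def log_likelihood(x, T=5.0, n_steps=2000):
    """Integrate the forward ODE started at x and accumulate the divergence term."""
    dt = T / n_steps
    y, int_div = x, 0.0
    for k in range(n_steps):
        t = k * dt
        int_div += div_F(t, y) * dt
        y = y + F(t, y) * dt
    log_ref = -0.5 * np.log(2 * np.pi) - 0.5 * y**2   # log pi_ref(Y_T) with pi_ref = N(0, 1)
    return log_ref + int_div

x = 1.3
exact = -0.5 * np.log(2 * np.pi * sigma0_sq) - 0.5 * x**2 / sigma0_sq
print(f"ODE-based log-likelihood: {log_likelihood(x):.3f},  exact log pi_data(x): {exact:.3f}")
```

The small residual gap between the two numbers comes from the discretization of the ODE and from the fact that \(\overrightarrow{Y}_T\) is only approximately distributed according to \(\pi_{\mathrm{ref}}\).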

References

Song, Yang, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. 2020. “Score-Based Generative Modeling Through Stochastic Differential Equations.” ICLR 2021. https://arxiv.org/abs/2011.13456.