From Denoising Diffusion to ODEs


Setting & Goals

Consider an empirical data distribution $\pi_{\text{data}}$. In order to simulate approximate samples from $\pi_{\text{data}}$, Denoising Diffusion Probabilistic Models (DDPM) simulate a forward diffusion process $\{X_t\}_{t \in [0,T]}$ on an interval $[0,T]$. The diffusion is initialized at the data distribution, i.e. $X_0 \sim \pi_{\text{data}}$, and is chosen so that the distribution of $X_T$ is very close to a known and tractable reference distribution $\pi_{\text{ref}}$, e.g. a Gaussian distribution. Denote by $p_t(dx)$ the marginal distribution at time $0 \le t \le T$, i.e. $\mathbb{P}(X_t \in dx) = p_t(dx)$. By choosing a forward process with simple and tractable transition probabilities, e.g. an Ornstein-Uhlenbeck process, it is relatively easy to estimate the score $\nabla \log p_t(x)$ from simulated data: this can be formulated as a simple regression problem. This allows one to simulate the diffusion backward in time and generate approximate samples from $\pi_{\text{data}}$. Why this is useful is another question…
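For concreteness, the regression problem alluded to here is usually the standard denoising score matching objective. Writing $q_t(x_t \mid x_0)$ for the (tractable) forward transition density, one minimizes

$$\mathbb{E}_{X_0 \sim \pi_{\text{data}},\; X_t \sim q_t(\cdot \mid X_0)}\left\| S(t, X_t) - \nabla_{x_t} \log q_t(X_t \mid X_0) \right\|^2$$

over functions $S$: the minimizer is exactly $S^\star(t,x) = \nabla \log p_t(x)$.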

The fact that the mapping from data samples at time $t=0$ to (approximate) Gaussian samples at time $t=T$ is stochastic and described by a diffusion process is cumbersome. It would be much more convenient to build a deterministic mapping between the data distribution $\pi_{\text{data}}$ and the Gaussian reference distribution $\pi_{\text{ref}}$: this would allow one to associate a likelihood to data samples and to easily “encode”/“decode” data samples. To do this, one can try to replace diffusions by Ordinary Differential Equations (ODEs).

The diffusion-ODE trick

Consider an arbitrary diffusion process $dX = \mu(t,X)\,dt + dB$ with associated distribution $p_t(dx)$ at time $t$. The Fokker-Planck equation that describes the evolution of $p_t$ reads

$$\partial_t p_t = -\operatorname{div}(p_t\,\mu) + \tfrac{1}{2}\,\Delta p_t.$$

If there were no diffusion term $dB$ and $X$ instead followed the differential equation $dX/dt = F(t,X)$, the associated evolution of the density of $X$ would simply read

$$\partial_t p_t = -\operatorname{div}(p_t\,F). \tag{1}$$

If one can find a vector field $F(t,x)$ such that

$$-\operatorname{div}(p_t\,\mu) + \tfrac{1}{2}\,\Delta p_t = -\operatorname{div}(p_t\,F),$$

then one can basically replace diffusions by ODEs. The diffusion-ODE trick is the simple remark that

$$F(t,x) = \mu(t,x) - \tfrac{1}{2}\,\nabla_x \log p_t(x)$$

does exactly this, as a one-line computation shows. The additional term $-\tfrac{1}{2}\nabla_x \log p_t(x)$ is intuitive. The coefficient $1/2$ is there because one is trying to match the term $\tfrac{1}{2}\Delta p_t$ in the Fokker-Planck equation. And the overall term $-\tfrac{1}{2}\nabla_x \log p_t(x)$ drives the ODE towards regions where the probability density $p_t$ is small, i.e. it follows the negative gradient of the log-density: it is exactly trying to imitate the spreading effect of the diffusion term $dB$.
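The computation in question: since $p_t \nabla_x \log p_t = \nabla p_t$,

$$\operatorname{div}(p_t\,F) = \operatorname{div}(p_t\,\mu) - \tfrac{1}{2}\operatorname{div}\!\left(p_t \nabla_x \log p_t\right) = \operatorname{div}(p_t\,\mu) - \tfrac{1}{2}\,\Delta p_t,$$

so that $-\operatorname{div}(p_t\,F) = -\operatorname{div}(p_t\,\mu) + \tfrac{1}{2}\,\Delta p_t$ is exactly the right-hand side of the Fokker-Planck equation.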

What this means is that a diffusion process $dX = \mu(t,X)\,dt + dB$ started from $X_0 \sim p_0$ and with marginal distribution $X_t \sim p_t(dx)$ can be imitated by an ODE process $dY/dt = \mu(t,Y) - \tfrac{1}{2}\nabla \log p_t(Y)$ started from $Y_0 \sim p_0$. At any time $t > 0$, the marginal distributions of $X_t$ and $Y_t$ both exactly equal $p_t(dx)$.
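This is easy to check numerically. Here is a minimal one-dimensional sanity check, assuming an OU drift $\mu(t,x) = -x/2$ and a Gaussian initial law $N(m_0, \sigma_0^2)$, so that $p_t$ and its score are available in closed form; the SDE and ODE particle clouds should then have matching marginals:

```python
import numpy as np

rng = np.random.default_rng(0)
T, n_steps, n_particles = 3.0, 2000, 100_000
dt = T / n_steps
m0, s0 = 2.0, 0.5  # initial law p_0 = N(m0, s0^2)

def score(t, x):
    # closed-form score of p_t = N(m_t, v_t) under dX = -X/2 dt + dB
    m_t = np.exp(-t / 2) * m0
    v_t = np.exp(-t) * s0**2 + (1.0 - np.exp(-t))
    return -(x - m_t) / v_t

x = m0 + s0 * rng.standard_normal(n_particles)  # SDE particles, X_0 ~ p_0
y = x.copy()                                    # ODE particles, same start
for k in range(n_steps):
    t = k * dt
    x += -0.5 * x * dt + np.sqrt(dt) * rng.standard_normal(n_particles)
    y += (-0.5 * y - 0.5 * score(t, y)) * dt    # dY/dt = mu - (1/2) grad log p_t

# both clouds should match the exact marginal N(m_T, v_T)
print(x.mean(), x.std())
print(y.mean(), y.std())
```

Both print statements agree, up to Monte-Carlo and discretization error, with the exact values $m_T = e^{-T/2} m_0$ and $v_T = e^{-T}\sigma_0^2 + 1 - e^{-T}$.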

The diffusion-ODE trick: application to DDPM

Consider a DDPM with forward dynamics given by an Ornstein-Uhlenbeck (OU) process

$$dX = -\tfrac{1}{2}X\,dt + dB$$

and initial condition $X_0 \sim p_0 \equiv \pi_{\text{data}}$. As explained in these notes, it is relatively straightforward to estimate the score function

$$S(t, x_t) = \nabla \log p_t(x_t)$$

from data. This means that the forward OU process can be replaced by the forward ODE

$$\frac{d}{dt} Y_t = -\tfrac{1}{2} Y_t - \tfrac{1}{2} S(t, Y_t) = F(t, Y_t)$$

with $F(t,x) = -\tfrac{1}{2}x - \tfrac{1}{2}S(t,x)$. Similarly, the reverse diffusion (i.e. the “denoising” diffusion) defined as $\overline{X}_s = X_{T-s}$ follows the dynamics

$$d\overline{X}_s = \left\{\tfrac{1}{2}\overline{X}_s + S(T-s, \overline{X}_s)\right\} ds + dW.$$

As described for the first time in the beautiful article (Song et al. 2020), the diffusion-ODE trick now shows that the denoising diffusion can be replaced by a denoising ODE with dynamics

$$\frac{d}{ds}\overline{Y}_s = \tfrac{1}{2}\overline{Y}_s + S(T-s, \overline{Y}_s) - \tfrac{1}{2}S(T-s, \overline{Y}_s) = \tfrac{1}{2}\overline{Y}_s + \tfrac{1}{2}S(T-s, \overline{Y}_s) = -F(T-s, \overline{Y}_s).$$

Interestingly [and I do not know whether there was an obvious way of seeing this from the start], this shows that the forward and backward ODEs are actually the same, just run forward and backward in time. They correspond to the ODE described by the vector field

$$F(t,x) = -\tfrac{1}{2}x - \tfrac{1}{2}S(t,x). \tag{2}$$

The animation below displays the denoising ODE and the associated vector field of Equation 2.
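In code, integrating the denoising ODE only takes a few lines. The sketch below is a minimal Euler discretization, assuming a score function `score(t, x)` is available (a learned network in practice, or the exact score in toy examples; the name is a placeholder):

```python
import numpy as np

def sample_denoising_ode(score, T=8.0, n_steps=1000, n_samples=10_000, d=1, rng=None):
    """Euler scheme for the denoising ODE dY/ds = -F(T-s, Y) with
    F(t, x) = -x/2 - S(t, x)/2, initialized at pi_ref = N(0, I)."""
    rng = rng or np.random.default_rng(0)
    ds = T / n_steps
    y = rng.standard_normal((n_samples, d))  # Y_0 ~ pi_ref
    for k in range(n_steps):
        t = T - k * ds                       # corresponding forward time
        y += (0.5 * y + 0.5 * score(t, y)) * ds
    return y                                 # approximately distributed as pi_data
```

Plugging in the closed-form Gaussian score from the earlier snippet (with $T$ large enough that $p_T \approx \pi_{\text{ref}}$) approximately recovers samples from $N(m_0, \sigma_0^2)$.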

Likelihood computation

With the diffusion-ODE trick, we have just seen that it is possible to build a vector field $F : [0,T] \times \mathbb{R}^d \to \mathbb{R}^d$ such that the forward ODE

$$\frac{d}{dt} Y_t = F(t, Y_t) \qquad \text{initialized at} \quad Y_0 \sim \pi_{\text{data}} \tag{3}$$

and the backward ODE defined as

$$\frac{d}{ds} \overline{Y}_s = -F(T-s, \overline{Y}_s) \qquad \text{initialized at} \quad \overline{Y}_0 \sim \pi_{\text{ref}}$$

are such that $Y_T \sim \pi_{\text{ref}}$ and $\overline{Y}_T \sim \pi_{\text{data}}$.

In general, consider a vector field $F(t,x)$ and a bunch of particles distributed according to a distribution $p_t$ at time $t$. If each particle follows the vector field for a small amount of time $\delta \ll 1$, the particles that were in the vicinity of some $x \in \mathbb{R}^d$ at time $t$ end up in the vicinity of $x + F(t,x)\,\delta$ at time $t+\delta$. At the same time, a volume element $dx$ around $x \in \mathbb{R}^d$ gets stretched by a factor $1 + \delta \operatorname{Tr}[\operatorname{Jac} F(t,x)] = 1 + \delta \operatorname{div} F(t,x)$ while following the vector field $F$, which means that the density of particles at time $t+\delta$ and around $x + F(t,x)\,\delta$ equals $p_t(x) / [1 + \delta \operatorname{div} F(t,x)]$. In other words, $\log p_{t+\delta}(x + F(t,x)\,\delta) \approx \log p_t(x) - \delta \operatorname{div} F(t,x)$. This means that if one follows a trajectory of $\frac{d}{dt} X_t = F(t, X_t)$, one gets

$$\log p_T(X_T) = \log p_0(X_0) - \int_{0}^{T} \operatorname{div} F(t, X_t)\,dt.$$
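A quick way to verify this identity: differentiating along the trajectory and using the continuity Equation 1,

$$\frac{d}{dt} \log p_t(X_t) = \frac{\partial_t p_t(X_t)}{p_t(X_t)} + F \cdot \nabla \log p_t(X_t) = \frac{-\operatorname{div}(p_t\,F) + F \cdot \nabla p_t}{p_t}(X_t) = -\operatorname{div} F(t, X_t),$$

since $\operatorname{div}(p_t\,F) = p_t \operatorname{div} F + F \cdot \nabla p_t$; integrating over $[0,T]$ gives the display above.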

That is the Lagrangian description of the density $p_t$ of particles: as the computation above shows, the identity follows directly from differentiating $\log p_t(X_t)$ with respect to time while using the continuity Equation 1. When applied to the DDPM, this gives a way to assign a likelihood to data samples, namely

$$\log \pi_{\text{data}}(x) = \log \pi_{\text{ref}}(Y_T) + \int_{0}^{T} \operatorname{div} F(t, Y_t)\,dt$$

where $Y_t$ is the trajectory of the forward ODE Equation 3 initialized at $Y_0 = x$. Note that in high-dimensional settings, it may be computationally expensive to compute the divergence term $\operatorname{div} F(t, Y_t)$, since it typically is $d$ times slower than a gradient computation; for this reason, it is often advocated to use the Hutchinson trace estimator to get an unbiased estimate of it at a much lower computational cost.
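The Hutchinson estimator rests on $\operatorname{div} F(x) = \operatorname{Tr}[\operatorname{Jac} F(x)] = \mathbb{E}_z\!\left[z^\top \operatorname{Jac} F(x)\, z\right]$ for any probe $z$ with $\mathbb{E}[z z^\top] = I$. Here is a minimal sketch assuming only black-box access to $F$; the Jacobian-vector product is approximated by a central finite difference (with automatic differentiation one would compute it exactly), so each probe costs two evaluations of $F$ instead of $d$:

```python
import numpy as np

def hutchinson_div(F, t, x, n_probes=16, eps=1e-4, rng=None):
    """Estimate div_x F(t, x) = Tr[Jac F] via E_z[z^T Jac F z] with
    Rademacher probes z; Jac F(x) @ z is approximated by central differences."""
    rng = rng or np.random.default_rng(0)
    d = x.shape[-1]
    acc = 0.0
    for _ in range(n_probes):
        z = rng.choice([-1.0, 1.0], size=d)                    # E[z z^T] = I
        jvp = (F(t, x + eps * z) - F(t, x - eps * z)) / (2 * eps)
        acc += z @ jvp
    return acc / n_probes
```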

References

Song, Yang, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. 2020. “Score-Based Generative Modeling Through Stochastic Differential Equations.” ICLR 2021. https://arxiv.org/abs/2011.13456.