Adjoint Schrodinger Bridge Sampler

SDE
markov
Schrodinger bridge
Published

22 03 2026

Modified

22 03 2026

The adjoint matching and adjoint sampling notes showed how to learn a controlled diffusion that transports samples to a target \(\pi \propto e^{-E}\) using only on-policy trajectories and energy evaluations. The price for that simplicity was the memoryless condition: the base process must satisfy \(\mathbb{P}(X_0, X_1) = \mathbb{P}(X_0) \, \mathbb{P}(X_1)\), which in practice forces a Dirac delta prior \(X_0 = 0\). No Gaussian priors, no domain-specific priors (harmonic oscillators for molecules, etc.).

This note describes the Adjoint Schrodinger Bridge Sampler (ASBS) of (liu2025adjoint?), which removes the memoryless restriction entirely. The main idea: instead of a pure stochastic optimal control problem, solve a full Schrodinger bridge problem. The SB potentials decompose into a forward piece (the adjoint, same as before) and a backward piece (the corrector, which fixes the bias from a non-trivial prior). Alternating between learning these two pieces is exactly IPFP on path-space, with a global convergence guarantee.

The memoryless condition and its cost

Recall from the SOC notes the setup: a base diffusion \[ dX_t = f_t(X_t) \, dt + \sigma_t \, dW_t, \qquad X_0 \sim \mu, \] defines a path measure \(\mathbb{P}\). Adding a control \(u_t(x)\) gives \[ dX_t = {\left[ f_t(X_t) + \sigma_t \, u_t(X_t) \right]} \, dt + \sigma_t \, dW_t, \qquad X_0 \sim \mu, \tag{1}\] with path measure \(\mathbb{P}^u\). The SOC problem \[ \min_u \; \mathbb{E}_{\mathbb{P}^u} {\left[ \int_0^1 \tfrac{1}{2} \|u_t(X_t)\|^2 \, dt + g(X_1) \right]} \tag{2}\] has the optimal control \(u^\star_t(x) = \sigma_t \, \nabla \log h(t,x)\) where \(h(t,x) = \mathbb{E}_\mathbb{P}[e^{-g(X_1)} \mid X_t = x]\) is the exponentiated negative value function. The optimal joint distribution of endpoints is \[ p^\star(X_0, X_1) = \mathbb{P}(X_0, X_1) \, e^{-g(X_1) + V_0(X_0)}, \tag{3}\] where \(V_0(x) = -\log \int \mathbb{P}(X_1 \mid X_0 = x) \, e^{-g(X_1)} \, dX_1\) is the initial value function. To sample from \(\pi\) at time \(t=1\), we need the terminal marginal \(p^\star(X_1) = \pi(X_1)\). Marginalizing Equation 3 over \(X_0\) gives \[ p^\star(X_1) = \int \mathbb{P}(X_0, X_1) \, e^{-g(X_1) + V_0(X_0)} \, dX_0. \]
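
The controlled SDE in Equation 1 is easy to simulate with Euler–Maruyama. The sketch below uses purely illustrative choices, not ones from the paper: \(f_t = 0\), \(\sigma_t = 1\), a Dirac prior at the origin, and a hypothetical linear control \(u_t(x) = -x\) that pulls paths back toward zero.

```python
import numpy as np

rng = np.random.default_rng(0)

def simulate(u, n_paths=50_000, n_steps=200):
    """Euler-Maruyama for Equation 1 with f_t = 0, sigma_t = 1, X_0 = 0."""
    dt = 1.0 / n_steps
    x = np.zeros(n_paths)                          # X_0 ~ delta_0
    for _ in range(n_steps):
        x = x + u(x) * dt + np.sqrt(dt) * rng.normal(size=n_paths)
    return x

x_base = simulate(lambda x: np.zeros_like(x))      # u = 0: scaled Brownian motion
x_ctrl = simulate(lambda x: -x)                    # hypothetical linear control

# Base terminal variance should be near int_0^1 sigma_t^2 dt = 1;
# the controlled process is an OU-like process with smaller variance.
print(x_base.var(), x_ctrl.var())
```

With \(u = 0\) the terminal law is \(\mathcal{N}(0, 1)\); the linear control shrinks the terminal variance toward the Ornstein–Uhlenbeck value \((1 - e^{-2})/2\).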

The initial value function \(V_0(X_0)\) couples \(X_0\) and \(X_1\) in a complicated way. The memoryless condition \(\mathbb{P}(X_0, X_1) = \mathbb{P}(X_0) \, \mathbb{P}(X_1)\) cuts this coupling: \[ p^\star(X_1) \stackrel{\text{memoryless}}{=} \mathbb{P}(X_1) \, e^{-g(X_1)} \underbrace{\int \mathbb{P}(X_0) \, e^{V_0(X_0)} \, dX_0}_{\text{constant}}. \] Setting \(g(x) = \log \frac{\mathbb{P}_1(x)}{\pi(x)}\) then gives \(p^\star(X_1) \propto \pi(X_1)\). The standard way to enforce memorylessness is \(f_t = 0\) and \(\mu = \delta_0\), so the base process is just scaled Brownian motion started at the origin. This works, but rules out informative priors.
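
The memoryless identity can be checked numerically: with a Dirac prior, \(\mathbb{P}_1\) is an explicit Gaussian, and reweighting base samples by \(e^{-g(X_1)}\) with \(g = \log \frac{\mathbb{P}_1}{\pi}\) should recover \(\pi\). Below is a self-normalized importance-sampling sketch in one dimension with illustrative numbers (\(\mathbb{P}_1 = \mathcal{N}(0,1)\), \(\pi = \mathcal{N}(2, 0.5^2)\)); these choices are mine, not the paper's.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(x, m, s):
    return -0.5 * ((x - m) / s) ** 2 - np.log(s * np.sqrt(2 * np.pi))

# Base terminal marginal P_1 = N(0, 1) (Brownian motion from delta_0);
# target pi = N(2, 0.5^2).  Reweight by e^{-g} with g = log(P_1 / pi).
x1 = rng.normal(0.0, 1.0, size=200_000)
log_w = log_normal(x1, 2.0, 0.5) - log_normal(x1, 0.0, 1.0)   # -g(X_1)
w = np.exp(log_w - log_w.max())
w /= w.sum()                                                  # self-normalize

mean = float(np.sum(w * x1))                  # should approach 2.0
var = float(np.sum(w * (x1 - mean) ** 2))     # should approach 0.25
```

The reweighted moments match \(\pi\) because the \(V_0\) factor is a constant here, exactly as in the displayed identity.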

Without the memoryless condition, the \(V_0(X_0)\) term biases the terminal distribution away from \(\pi\). The question is: can we correct for this bias while keeping the scalable adjoint matching framework?

From SOC to Schrodinger bridges

The answer comes from the Schrodinger bridge formulation. Instead of the SOC problem Equation 2 (which only constrains the control cost and terminal cost), solve the full SB problem with both endpoint constraints: \[ \min_{\mathbb{Q}: \, \mathbb{Q}_0 = \mu, \, \mathbb{Q}_1 = \pi} \; D_{\text{KL}}(\mathbb{Q}\,\|\, \mathbb{P}). \tag{4}\]

As discussed in the Schrodinger bridges note, the SB solution has path measure \[ \frac{d\mathbb{Q}^\star}{d\mathbb{P}}(X_{[0,1]}) \propto \widehat{\varphi}_0(X_0) \, \varphi_1(X_1), \] where the time-dependent SB potentials are \[ \left\{ \begin{aligned} \varphi_t(x) &= \mathbb{E}_\mathbb{P}[\varphi_1(X_1) \mid X_t = x], \\ \widehat{\varphi}_t(x) &= \mathbb{E}_\mathbb{P}[\widehat{\varphi}_0(X_0) \mid X_t = x], \end{aligned} \right. \tag{5}\] with boundary conditions determined by the marginal constraints \(\mathbb{Q}_0 = \mu\) and \(\mathbb{Q}_1 = \pi\).

The forward potential \(\varphi_t\) is a Doob h-transform: it generates the optimal forward drift \(\sigma_t^2 \, \nabla \log \varphi_t(x)\). The backward potential \(\widehat{\varphi}_t\) carries information about the initial distribution.

The key theorem of (liu2025adjoint?) connects the SB problem back to SOC.

SOC characteristics of SB. The kinetic-optimal drift \(u_t^\star\) of the SB problem Equation 4 also solves the SOC problem \[ \min_u \; \mathbb{E}_{\mathbb{P}^u} {\left[ \int_0^1 \tfrac{1}{2} \|u_t(X_t)\|^2 \, dt + \log \frac{ \textcolor{blue}{\widehat{\varphi}_1(X_1)}}{\pi(X_1)} \right]} \quad \text{s.t. Equation 1}. \tag{6}\]

Comparing to the adjoint sampling SOC problem where the terminal cost was \(g(x) = \log \frac{\mathbb{P}_1(x)}{\pi(x)}\), the only difference is the replacement \(\mathbb{P}_1 \to \textcolor{blue}{\widehat{\varphi}_1}\). This extra function \(\widehat{\varphi}_1\) is exactly the corrector that absorbs the initial value function bias.

Why does the corrector fix the bias?

Using Equation 3 with the terminal cost from Equation 6, the optimal joint under this SOC problem is \[ p^\star(X_0, X_1) = \mathbb{P}(X_0, X_1) \, \exp {\left( -\log \frac{\widehat{\varphi}_1(X_1)}{\pi(X_1)} + V_0(X_0) \right)} . \] Marginalizing over \(X_0\): \[ \begin{aligned} p^\star(X_1) &= \frac{\pi(X_1)}{\widehat{\varphi}_1(X_1)} \int \mathbb{P}(X_0, X_1) \, e^{V_0(X_0)} \, dX_0 \\ &= \frac{\pi(X_1)}{\widehat{\varphi}_1(X_1)} \int \mathbb{P}(X_1 \mid X_0) \, \widehat{\varphi}_0(X_0) \, dX_0 \\ &= \frac{\pi(X_1)}{\widehat{\varphi}_1(X_1)} \, \widehat{\varphi}_1(X_1) = \pi(X_1). \end{aligned} \] The second line uses \(\mu(X_0) \, e^{V_0(X_0)} = \widehat{\varphi}_0(X_0)\) (which follows from the SB potential definitions), and the third line uses the backward propagation in Equation 5. The corrector \(\widehat{\varphi}_1\) cancels the bias exactly.

In other words, every SB problem decomposes into an SOC problem (the “adjoint” part) plus a corrector function \(\widehat{\varphi}_1\) that accounts for the non-trivial coupling between \(X_0\) and \(X_1\).

The two matching objectives

For the Boltzmann sampling case (\(\pi \propto e^{-E}\), \(f_t = 0\)), the terminal cost in Equation 6 becomes \(g(x) = E(x) + \log \widehat{\varphi}_1(x) + \text{const}\). Applying adjoint matching to this SOC problem yields the adjoint matching (AM) objective: \[ u^\star = \mathop{\mathrm{argmin}}_u \; \mathbb{E}_{\mathbb{P}_{t \mid 0,1} \, p^{\bar{u}}_{0,1}} {\left[ \big\| u_t(X_t) + \sigma_t {\left( \nabla E + \textcolor{blue}{\nabla \log \widehat{\varphi}_1} \right)} (X_1) \big\|^2 \right]} , \tag{7}\] where \(\bar{u} = \texttt{stopgrad}(u)\). Compared to the adjoint sampling objective, the only new ingredient is the \( \textcolor{blue}{\nabla \log \widehat{\varphi}_1}\) term. When \(\mu = \delta_0\) (memoryless case), \(\widehat{\varphi}_1 = \mathbb{P}_1\) and this recovers the standard adjoint matching loss.
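
To make the regression structure of Equation 7 concrete, here is a toy first-stage fit in one dimension with the corrector term set to zero, so the target is just \(-\sigma_t \nabla E(X_1)\). All choices are illustrative: base Brownian motion with \(\sigma_t = 1\), energy \(E(x) = x^2/2\), a linear control fitted at the single time \(t = \tfrac{1}{2}\), and pairs drawn from the base process rather than on-policy.

```python
import numpy as np

rng = np.random.default_rng(1)

# Base Brownian motion from 0 with sigma_t = 1: X_{1/2} ~ N(0, 1/2),
# and X_1 | X_{1/2} ~ N(X_{1/2}, 1/2).  Energy E(x) = x^2 / 2.
n = 200_000
x_half = rng.normal(0.0, np.sqrt(0.5), size=n)
x_one = x_half + rng.normal(0.0, np.sqrt(0.5), size=n)

# AM regression target with zero corrector: -sigma * grad E(X_1) = -X_1.
target = -x_one

# Fit a linear control u_{1/2}(x) = a x + b by least squares.
A = np.stack([x_half, np.ones(n)], axis=1)
(a, b), *_ = np.linalg.lstsq(A, target, rcond=None)

# The least-squares minimizer is the conditional expectation
# E[-X_1 | X_{1/2} = x] = -x, so a ~ -1 and b ~ 0.
print(a, b)
```

This is the same mechanism as Equation 7 in miniature: the fitted control is the conditional expectation of the adjoint target given the current state.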

The corrector \(\nabla \log \widehat{\varphi}_1\) itself has a variational characterization. From the definition Equation 5, \(\widehat{\varphi}_t(x)\) is the expected value of \(\widehat{\varphi}_0(X_0)\) given \(X_t = x\) under the base measure. This leads to the corrector matching (CM) objective: \[ \nabla \log \widehat{\varphi}_1 = \mathop{\mathrm{argmin}}_h \; \mathbb{E}_{p^{u^\star}_{0,1}} {\left[ \big\| h(X_1) - \nabla_{X_1} \log \mathbb{P}(X_1 \mid X_0) \big\|^2 \right]} . \tag{8}\]

This is a regression problem: fit \(h\) to the score of the base transition density \(\nabla_{X_1} \log \mathbb{P}(X_1 \mid X_0)\), evaluated at pairs \((X_0, X_1)\) drawn from the current optimal process. When the base drift is \(f_t = 0\), the transition density \(\mathbb{P}(X_1 \mid X_0)\) is Gaussian and its score is known in closed form.
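
A minimal sketch of the CM regression in one dimension, with an illustrative Gaussian setup (my choice, not the paper's) in which \(\nabla \log \widehat{\varphi}_1\) is known exactly: take \(\widehat{\varphi}_0 = \mathcal{N}(a, b^2)\) and \(\mathbb{P}(X_1 \mid X_0) = \mathcal{N}(X_0, c^2)\), so \(\widehat{\varphi}_1 = \mathcal{N}(a, b^2 + c^2)\) and the fitted corrector should approach \(-(x - a)/(b^2 + c^2)\). Pairs are drawn from the joint \(\mathbb{P}(X_1 \mid X_0)\,\widehat{\varphi}_0(X_0)\) that appears in the derivation below.

```python
import numpy as np

rng = np.random.default_rng(1)
a, b, c = 1.0, 0.8, 0.5            # illustrative Gaussian parameters

# Sample (X_0, X_1) with X_0 ~ phi_hat_0 and X_1 | X_0 from the base kernel.
x0 = rng.normal(a, b, size=100_000)
x1 = x0 + c * rng.normal(size=x0.shape)

# CM regression target: score of the Gaussian transition density,
# grad_{x1} log N(x1; x0, c^2) = -(x1 - x0) / c^2.
target = -(x1 - x0) / c**2

# Fit a linear corrector h(x1) = alpha + beta * x1 by least squares.
A = np.stack([np.ones_like(x1), x1], axis=1)
(alpha, beta), *_ = np.linalg.lstsq(A, target, rcond=None)

# Closed form: grad log phi_hat_1(x) = -(x - a) / (b^2 + c^2),
# so beta ~ -1/(b^2 + c^2) and alpha ~ a/(b^2 + c^2).
print(beta, -1.0 / (b**2 + c**2))
```

The fitted linear model recovers the exact \(\nabla \log \widehat{\varphi}_1\), confirming that the regression target averages to the corrector score.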

Derivation of the corrector matching objective:

From Equation 5, \(\widehat{\varphi}_t(x) = \int \mathbb{P}_{t \mid 0}(x \mid y) \, \widehat{\varphi}_0(y) \, dy\). Differentiating: \[ \begin{aligned} \nabla \log \widehat{\varphi}_t(x) &= \frac{\int \nabla_x \log \mathbb{P}_{t \mid 0}(x \mid y) \, \mathbb{P}_{t \mid 0}(x \mid y) \, \widehat{\varphi}_0(y) \, dy}{\widehat{\varphi}_t(x)} \\ &= \mathbb{E}_{p^\star} {\left[ \nabla_x \log \mathbb{P}_{t \mid 0}(x \mid X_0) \mid X_t = x \right]} . \end{aligned} \] The last line uses Bayes’ rule: \(p^\star(X_0 \mid X_t = x) \propto \mathbb{P}_{t|0}(x \mid X_0) \, \widehat{\varphi}_0(X_0)\). Writing this conditional expectation as a least-squares regression target gives Equation 8.

Alternating optimization = IPFP

The AM objective Equation 7 requires knowing \(\nabla \log \widehat{\varphi}_1\). The CM objective Equation 8 requires samples from \(p^{u^\star}\). Neither can be solved in isolation. The natural approach: alternate. Given a corrector approximation \(h^{(k-1)}\) from stage \(k-1\):

  1. Adjoint matching step: Solve Equation 7 with \(\nabla \log \widehat{\varphi}_1 \approx h^{(k-1)}\) to get \(u^{(k)}\).
  2. Corrector matching step: Solve Equation 8 with \(u^\star \approx u^{(k)}\) to get \(h^{(k)}\).

Initialize with \(h^{(0)} = 0\). The first AM stage then simply regresses \(u^{(1)}\) onto \(\sigma_t \nabla E(X_1)\), i.e. pure energy-guided transport with no corrector.
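
The alternation can be caricatured in closed form. The sketch below is a static one-dimensional Gaussian toy, not the paper's algorithm: both potentials are unnormalized Gaussians \(e^{-p(x-m)^2/2}\), tracked by their (mean, precision) pairs, with illustrative parameters \(\mu = \mathcal{N}(0,1)\), \(\pi = \mathcal{N}(2, 0.25)\), and base kernel variance \(0.25\). Each stage updates one potential so that one marginal constraint holds exactly, mirroring the AM/CM alternation.

```python
import numpy as np

sig2 = 0.25                         # base kernel variance (illustrative)
m_mu, p_mu = 0.0, 1.0               # mu = N(0, 1)
m_pi, p_pi = 2.0, 4.0               # pi = N(2, 0.25)

def conv(m, p):
    """Convolve the potential exp(-p (x - m)^2 / 2) with the N(0, sig2) kernel."""
    return m, p / (1.0 + sig2 * p)

m_bwd, p_bwd = m_mu, p_mu           # initial backward potential: just mu
for _ in range(500):
    # Forward update: phi_1 = pi / (K phi_hat_0), enforcing the target marginal.
    mc, pc = conv(m_bwd, p_bwd)
    p_fwd = p_pi - pc
    m_fwd = (p_pi * m_pi - pc * mc) / p_fwd
    # Backward update: phi_hat_0 = mu / (K phi_1), enforcing the source marginal.
    mc, pc = conv(m_fwd, p_fwd)
    p_bwd = p_mu - pc
    m_bwd = (p_mu * m_mu - pc * mc) / p_bwd

# At the fixed point, the terminal marginal phi_1 * (K phi_hat_0) equals pi.
mc, pc = conv(m_bwd, p_bwd)
p_term = p_fwd + pc
m_term = (p_fwd * m_fwd + pc * mc) / p_term
print(m_term, p_term)               # should approach (2.0, 4.0)
```

After convergence both constraints hold simultaneously, even though each update alone enforces only one of them.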

This alternating scheme has a clean variational interpretation. At each stage:

  • The AM step solves a forward half-bridge: minimize \(D_{\text{KL}}(\mathbb{Q}\,\|\, \mathbb{Q}^{\text{bwd}})\) subject to \(\mathbb{Q}_0 = \mu\), where \(\mathbb{Q}^{\text{bwd}}\) is a backward process defined by \(h^{(k-1)}\) and \(\pi\).

  • The CM step solves a backward half-bridge: minimize \(D_{\text{KL}}(\mathbb{P}^{u^{(k)}} \,\|\, \mathbb{Q})\) subject to \(\mathbb{Q}_1 = \pi\).

Alternating between forward and backward half-bridge projections is exactly the IPFP/Sinkhorn algorithm applied on path-space. Just as Sinkhorn alternates between enforcing the two marginal constraints of a static Schrodinger bridge, ASBS alternates between enforcing the source constraint (\(\mathbb{Q}_0 = \mu\)) and the target constraint (\(\mathbb{Q}_1 = \pi\)).
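
The static analogue is easy to run: discrete Sinkhorn alternately rescales a Gibbs kernel so the coupling matches each marginal in turn. A sketch with illustrative random 5-state marginals and kernel (nothing here is from the paper):

```python
import numpy as np

rng = np.random.default_rng(2)
n = 5
mu = np.full(n, 1.0 / n)                 # source marginal (Q_0 = mu)
pi = rng.dirichlet(np.ones(n))           # target marginal (Q_1 = pi)
K = np.exp(-rng.random((n, n)))          # Gibbs kernel of the base coupling

# IPFP / Sinkhorn: alternately rescale rows and columns so that the
# coupling diag(u) K diag(v) satisfies both marginal constraints.
u, v = np.ones(n), np.ones(n)
for _ in range(200):
    u = mu / (K @ v)                     # enforce the source marginal
    v = pi / (K.T @ u)                   # enforce the target marginal

P = u[:, None] * K * v[None, :]
print(np.abs(P.sum(1) - mu).max(), np.abs(P.sum(0) - pi).max())
```

Both marginal errors drop to machine precision, the discrete counterpart of the path-space convergence discussed next.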

Global convergence

The IPFP interpretation immediately gives a convergence guarantee. The analysis of (de2021diffusion?) shows that IPFP on path-space converges to the SB solution, provided each half-bridge step is solved exactly. Applied to ASBS:

Theorem (Global convergence). If each AM and CM stage is solved to optimality, then \[ \lim_{k \to \infty} u^{(k)} = u^\star, \] where \(u^\star\) is the kinetic-optimal drift solving Equation 4.

In practice, the matching objectives are solved approximately with neural networks, and the convergence is up to approximation errors. The practical observation from (liu2025adjoint?) is that a small number of stages (3-5) suffices.

Summary and relation to prior work

The table below clarifies the relationships:

| Method | Prior | Terminal cost \(g(x)\) | Corrector |
|---|---|---|---|
| Adjoint Sampling | \(\delta_0\) (memoryless) | \(\log \frac{\mathbb{P}_1(x)}{\pi(x)}\) | None (or \(\nabla \log \mathbb{P}_1\)) |
| ASBS | Arbitrary \(\mu\) | \(\log \frac{\widehat{\varphi}_1(x)}{\pi(x)}\) | Learned \(\nabla \log \widehat{\varphi}_1\) |

ASBS reduces to Adjoint Sampling when \(\mu = \delta_0\): the corrector becomes \(\widehat{\varphi}_1 = \mathbb{P}_1\) and the CM step is trivially solved.

The broader picture:

  • The Doob h-transform provides the forward potential \(\varphi_t\) (the “adjoint” piece).
  • The backward potential \(\widehat{\varphi}_t\) is the “corrector” piece unique to the SB formulation.
  • The decomposition \(\text{SB} = \text{SOC (adjoint)} + \text{corrector}\) is the core structural insight of ASBS.
  • Alternating between them is IPFP, connecting back to the Schrodinger bridge note.
  • The adjoint method provides the backward ODE for computing the adjoint state, which underlies the AM objective.

The practical benefit: domain-specific priors (Gaussian, harmonic oscillators for molecular systems) can substantially improve sampling, since the base process already starts in a reasonable region of state space, rather than at the origin.