Adjoint Schrodinger Bridge Sampler

Categories: SDE, markov, Schrodinger bridge

Published: 22 03 2026
Modified: 22 03 2026

The adjoint sampling setup forces \(X_0 = 0\). What if we want a Gaussian prior, or a harmonic oscillator prior for molecules? With a non-trivial prior \(\mu\), the base process couples \(X_0\) and \(X_1\). The SOC solution with the adjoint sampling terminal cost then produces a terminal marginal \(p^\star(X_1) \neq \pi(X_1)\), because the initial value function \(V_0(X_0)\) biases the result. To remove this bias while keeping the scalable adjoint matching framework, we solve a full Schrodinger bridge instead of a plain SOC problem. This is the idea behind the Adjoint Schrodinger Bridge Sampler (ASBS) of (liu2025adjoint?).

From SOC to Schrodinger bridge: the corrector

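Recall from the SB notes that the Schrodinger bridge \(\mathbb{Q}^\star\) minimizing \(D_{\text{KL}}(\mathbb{Q}\,\|\, \mathbb{P})\) subject to \(\mathbb{Q}_0 = \mu\) and \(\mathbb{Q}_1 = \pi\) has path measure \[ \frac{d\mathbb{Q}^\star}{d\mathbb{P}}(\boldsymbol{X}) \;\propto\; \widehat{\varphi}_0(X_0) \, \varphi_1(X_1). \tag{1}\] Because \(\mathbb{Q}_0 = \mu\) is fixed, the initial marginal of \(\mathbb{P}\) does not affect the minimizer. We take it to be Lebesgue measure, so only the transition kernel of \(\mathbb{P}\) enters Equation 1. The time-dependent SB potentials are \[ \varphi_t(x) = \mathbb{E}_\mathbb{P}[\varphi_1(X_1) \mid X_t = x], \qquad \widehat{\varphi}_t(x) = \int \mathbb{P}_{t \mid 0}(x \mid y) \, \widehat{\varphi}_0(y) \, dy. \tag{2}\] They satisfy the boundary conditions \(\varphi_0 \, \widehat{\varphi}_0 = \mu\) and \(\varphi_1 \, \widehat{\varphi}_1 = \pi\), and the SB marginal at every time is \(q_t = \varphi_t \, \widehat{\varphi}_t\). The forward potential \(\varphi_t\) is a Doob h-transform: the SB process has drift \(f_t + \sigma_t^2 \nabla \log \varphi_t\). The backward potential \(\widehat{\varphi}_t\) propagates information about the initial distribution \(\mu\) forward in time.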

Now, the SB solution \(\mathbb{Q}^\star\) is also a controlled diffusion, so it must solve some SOC problem. Which one? From Equation 1, the optimal joint density of endpoints is \[ p^\star(X_0, X_1) \;\propto\; \widehat{\varphi}_0(X_0) \, \mathbb{P}(X_1 \mid X_0) \, \varphi_1(X_1). \tag{3}\] Compare this with the SOC joint. For a terminal cost \(g\) and \(X_0 \sim \mu\), the SOC optimal joint is \(p^\star(X_0, X_1) = \mu(X_0) \, \mathbb{P}(X_1 \mid X_0) \, e^{-g(X_1) + V_0(X_0)}\), where \(V_0(x) = -\log \mathbb{E}_\mathbb{P}[e^{-g(X_1)} \mid X_0 = x]\) is the initial value function. (The HJB notes use a maximization convention where \(g\) has the opposite sign; our minimization convention follows adjoint sampling.) Matching the two expressions requires \[ e^{-g(X_1)} \propto \varphi_1(X_1), \qquad e^{V_0(X_0)} \propto \frac{\widehat{\varphi}_0(X_0)}{\mu(X_0)} \tag{4}\] up to multiplicative constants absorbed into the other factor. The first condition gives \(g(x) = -\log \varphi_1(x) + \text{const}\). To express this in terms of the corrector \(\widehat{\varphi}_1\), use the SB marginal constraint at \(t = 1\): \(q_1(x) = \varphi_1(x) \, \widehat{\varphi}_1(x) = \pi(x)\), so \(\varphi_1(x) = \pi(x) / \widehat{\varphi}_1(x)\). Substituting into \(g = -\log \varphi_1\) gives \[ \textcolor{blue}{g(x) = \log \frac{\widehat{\varphi}_1(x)}{\pi(x)}} + \text{const}. \tag{5}\] This is the modified terminal cost. Compared to the adjoint sampling terminal cost \(g(x) = \log p^{\mathbb{P}}_1(x) / \pi(x)\), the base marginal \(p^{\mathbb{P}}_1\) is replaced by \( \textcolor{blue}{\widehat{\varphi}_1}\), a corrector that accounts for the coupling between \(X_0\) and \(X_1\).

Does this corrector actually remove the \(V_0\) bias? Marginalize the SOC joint over \(X_0\). Use Equation 5 to replace \(e^{-g}\), and use the second condition in Equation 4 to replace \(\mu \, e^{V_0}\): \[ \begin{aligned} p^\star(X_1) &= e^{-g(X_1)} \int \mu(X_0) \, e^{V_0(X_0)} \, \mathbb{P}(X_1 \mid X_0) \, dX_0 \\ &\propto \frac{\pi(X_1)}{\widehat{\varphi}_1(X_1)} \int \widehat{\varphi}_0(X_0) \, \mathbb{P}(X_1 \mid X_0) \, dX_0 \\ &= \frac{\pi(X_1)}{\widehat{\varphi}_1(X_1)} \, \widehat{\varphi}_1(X_1) \;=\; \pi(X_1). \end{aligned} \tag{6}\] The last step is the definition of \(\widehat{\varphi}_1\) in Equation 2. The corrector \(\widehat{\varphi}_1\) in the terminal cost cancels the initial value function bias exactly. The second condition in Equation 4, \(\mu(X_0) \, e^{V_0(X_0)} \propto \widehat{\varphi}_0(X_0)\), is automatically satisfied by the SB boundary condition at \(t = 0\). Since \(e^{-g} \propto \varphi_1\), we have \(e^{V_0(x)} \propto 1/\mathbb{E}_\mathbb{P}[\varphi_1(X_1) \mid X_0 = x] = 1/\varphi_0(x) = \widehat{\varphi}_0(x)/\mu(x)\).
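The cancellation in Equation 6 can be checked numerically on a toy discrete problem. The grid standing in for \(\mathbb{R}\), the Gaussian kernel with \(\nu_1 = 1\), and the two positive potentials below are all hypothetical choices. The trick is to pick \(\widehat{\varphi}_0\) and \(\varphi_1\) first and then define \(\mu\) and \(\pi\) through the boundary conditions, so these potentials are exactly the SB potentials and no bridge has to be solved. The sketch compares the exact SOC terminal marginal under the adjoint sampling cost \(\log(p^{\mathbb{P}}_1/\pi)\) with the one under the corrected cost of Equation 5.

```python
import numpy as np

# Discrete 1-D grid; K[i, j] stands in for P(X_1 = x_j | X_0 = x_i) with nu_1 = 1.
x = np.linspace(-6.0, 6.0, 121)
K = np.exp(-(x[None, :] - x[:, None]) ** 2 / 2.0)
K /= K.sum(axis=1, keepdims=True)

# Pick arbitrary positive potentials (hypothetical choices) ...
hat_phi_0 = np.exp(-0.5 * (x - 1.0) ** 2)
phi_1 = 0.1 + np.exp(-2.0 * (x + 2.0) ** 2)
# ... and define (mu, pi) so that these ARE the SB potentials for this pair of marginals.
phi_0 = K @ phi_1                 # phi_0(x) = E[phi_1(X_1) | X_0 = x]
hat_phi_1 = K.T @ hat_phi_0       # Eq. 2 at t = 1
Z = hat_phi_0 @ phi_0             # total mass of hat_phi_0(y) P(x | y) phi_1(x)
mu = hat_phi_0 * phi_0 / Z        # boundary condition at t = 0
pi = phi_1 * hat_phi_1 / Z        # boundary condition at t = 1


def soc_terminal_marginal(g):
    """Terminal marginal of the exact SOC optimum with X_0 ~ mu and terminal cost g."""
    w = K * np.exp(-g)[None, :]                      # P(X_1 | X_0) e^{-g(X_1)}
    return mu @ (w / w.sum(axis=1, keepdims=True))   # row sums are e^{-V_0(X_0)}


g_naive = np.log((mu @ K) / pi)       # adjoint sampling cost log(p_1^P / pi) with X_0 ~ mu
g_corrected = np.log(hat_phi_1 / pi)  # Eq. 5 (additive constants do not matter)
err_naive = np.abs(soc_terminal_marginal(g_naive) - pi).sum()
err_corrected = np.abs(soc_terminal_marginal(g_corrected) - pi).sum()
print(f"L1 error of terminal marginal: naive {err_naive:.2e}, corrected {err_corrected:.2e}")
```

The corrected cost reproduces \(\pi\) to floating-point precision. The naive cost leaves a visible gap, which is the \(V_0\) bias described above.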

In other words, every SB problem decomposes into two pieces. The first is an SOC problem solved with the standard adjoint (the Doob h-transform piece \(\varphi_t\)). The second is a corrector \(\widehat{\varphi}_1\) in the terminal cost that absorbs the prior bias.

Adjoint matching with corrector

For Boltzmann sampling (\(\pi \propto e^{-E}\), \(f_t = 0\)), the modified terminal cost Equation 5 becomes \(g(x) = E(x) + \log \widehat{\varphi}_1(x) + \text{const}\). Applying adjoint matching to this SOC problem, the lean adjoint is constant and equals \(\nabla g(X_1) = (\nabla E + \textcolor{blue}{\nabla \log \widehat{\varphi}_1})(X_1)\). It is constant because the base drift is zero, so the lean adjoint ODE \(\dot{\tilde{a}} = -(\nabla_x f_t)^\top \tilde{a}\) becomes \(\dot{\tilde{a}} = 0\), as in adjoint sampling. The AM loss is \[ L_\text{AM}(u) = \mathbb{E}_{\mathbb{P}_{t \mid 0,1} \, p^{\bar{u}}_{0,1}} {\left[ \big\| u_t(X_t) + \sigma_t {\left( \nabla E + \textcolor{blue}{\nabla \log \widehat{\varphi}_1} \right)} (X_1) \big\|^2 \right]} , \tag{7}\] where \(\bar{u} = \texttt{stopgrad}(u)\), \(p^{\bar{u}}_{0,1}\) is the joint endpoint distribution under the current control, and \(\mathbb{P}_{t \mid 0,1}\) is the Brownian bridge kernel (since \(f_t = 0\)). Compared to the adjoint sampling loss, the only new ingredient is \( \textcolor{blue}{\nabla \log \widehat{\varphi}_1}\). When \(\mu = \delta_0\), the boundary condition \(\varphi_0 \, \widehat{\varphi}_0 = \mu\) forces \(\widehat{\varphi}_0 \propto \delta_0\), hence \(\widehat{\varphi}_1(x) \propto \mathbb{P}_{1 \mid 0}(x \mid 0) = p_1^{\mathbb{P}}(x)\). The terminal cost Equation 5 then becomes \(g(x) = \log(p_1^{\mathbb{P}}(x)/\pi(x)) + \text{const}\), which is exactly the adjoint sampling terminal cost. The corrector score \(\nabla \log \widehat{\varphi}_1 = \nabla \log p_1^{\mathbb{P}}\) is then known in closed form.
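The following sketch estimates Equation 7 on a single batch. It assumes the control network has already been evaluated at the bridge samples \(X_t\). The function name `am_loss`, the toy double-well energy and the array shapes are hypothetical. The usage plugs in the \(\mu = \delta_0\) corrector \(-x/\nu_1\), for which the loss reduces to the adjoint sampling one.

```python
import numpy as np


def am_loss(u_pred, sigma_t, x1, grad_energy, corrector_score):
    """Batch Monte Carlo estimate of the AM loss (Eq. 7).

    u_pred:          (B, d) control u_t(X_t) evaluated at Brownian-bridge samples X_t
    sigma_t:         (B, 1) diffusion coefficient at the sampled times t
    x1:              (B, d) terminal samples X_1 from the current model (no gradient)
    grad_energy:     callable x -> grad E(x)
    corrector_score: callable x -> current estimate h(x) of grad log hat_phi_1(x)
    """
    # With f_t = 0 the lean adjoint is constant and equals grad g(X_1).
    lean_adjoint = grad_energy(x1) + corrector_score(x1)
    residual = u_pred + sigma_t * lean_adjoint
    return np.mean(np.sum(residual ** 2, axis=-1))


def grad_double_well(x):
    """Gradient of a hypothetical toy energy E(x) = sum_i (x_i^2 - 1)^2."""
    return 4.0 * x * (x ** 2 - 1.0)


def memoryless_corrector(x, nu_1=1.0):
    """mu = delta_0: grad log hat_phi_1 = grad log N(x; 0, nu_1 I) = -x / nu_1."""
    return -x / nu_1


rng = np.random.default_rng(0)
x1 = rng.standard_normal((256, 2))
sigma_t = np.full((256, 1), 0.5)
# The regression target itself has zero loss; an untrained (zero) control does not.
target = -sigma_t * (grad_double_well(x1) + memoryless_corrector(x1))
print(am_loss(target, sigma_t, x1, grad_double_well, memoryless_corrector),
      am_loss(np.zeros_like(x1), sigma_t, x1, grad_double_well, memoryless_corrector))
```

Passing the corrector as a callable keeps the AM step agnostic to where \(h\) comes from: a closed form for \(\mu = \delta_0\), or the current CM network otherwise.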

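The same reciprocal projection trick from adjoint sampling applies. First, simulate the current model once, without gradients, to collect endpoint pairs \((X_0, X_1)\). Then draw \(X_t\) from the Brownian bridge \(\mathbb{P}_{t \mid 0,1}(\cdot \mid X_0, X_1)\). The regression in Equation 7 therefore needs neither the intermediate states of the simulated trajectory nor backpropagation through the SDE solver.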
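For \(f_t = 0\) the base process is a Brownian motion run on the clock \(\nu_t = \int_0^t \sigma_s^2 \, ds\). Its bridge is therefore Gaussian, with mean \(X_0 + \lambda_t (X_1 - X_0)\), where \(\lambda_t = (\nu_t - \nu_0)/(\nu_1 - \nu_0)\), and variance \((\nu_t - \nu_0)(\nu_1 - \nu_t)/(\nu_1 - \nu_0)\). The sketch below draws \(X_t\) this way. The geometric noise schedule and its parameters `sigma_min` and `sigma_max` are assumptions chosen for illustration, not taken from the text.

```python
import numpy as np


def geometric_nu(t, sigma_min=0.01, sigma_max=1.0):
    """nu_t = int_0^t sigma_s^2 ds for an assumed geometric schedule
    sigma_t = sigma_min * (sigma_max / sigma_min) ** t."""
    r = np.log(sigma_max / sigma_min)
    return sigma_min ** 2 * (np.exp(2.0 * r * t) - 1.0) / (2.0 * r)


def sample_bridge(x0, x1, t, nu, rng):
    """Draw X_t ~ P_{t|0,1}(. | X_0, X_1) for the base process dX_t = sigma_t dW_t.

    x0, x1: (B, d) endpoints; t: (B, 1) times in [0, 1]; nu: callable t -> nu_t.
    """
    nu0, nut, nu1 = nu(0.0), nu(t), nu(1.0)
    lam = (nut - nu0) / (nu1 - nu0)                                  # interpolation weight
    var = np.maximum((nut - nu0) * (nu1 - nut) / (nu1 - nu0), 0.0)   # clip rounding noise
    return x0 + lam * (x1 - x0) + np.sqrt(var) * rng.standard_normal(x0.shape)


# Usage with a fixed t = 0.5 so the moments can be checked; in training t would be random.
rng = np.random.default_rng(0)
n = 200_000
x0 = np.zeros((n, 1))
x1 = np.ones((n, 1))
t = np.full((n, 1), 0.5)
xt = sample_bridge(x0, x1, t, geometric_nu, rng)
print(xt.mean(), xt.var())
```

In practice the endpoint pairs would come from the simulated batch (or a replay buffer), and only the cheap Gaussian draw above is repeated for each regression step.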

Corrector matching: deriving the CM objective

The AM loss Equation 7 requires \(\nabla \log \widehat{\varphi}_1\), which is unknown, so we need a way to learn it. Start from the integral form of Equation 2, \(\widehat{\varphi}_t(x) = \int \mathbb{P}_{t \mid 0}(x \mid y) \, \widehat{\varphi}_0(y) \, dy\), and differentiate under the integral: \[ \nabla \log \widehat{\varphi}_t(x) = \frac{\int \nabla_x \mathbb{P}_{t \mid 0}(x \mid y) \, \widehat{\varphi}_0(y) \, dy}{\widehat{\varphi}_t(x)}. \tag{8}\] Rewrite the kernel gradient as \(\nabla_x \mathbb{P}_{t \mid 0} = \mathbb{P}_{t \mid 0} \, \nabla_x \log \mathbb{P}_{t \mid 0}\): \[ \nabla \log \widehat{\varphi}_t(x) = \frac{\int \nabla_x \log \mathbb{P}_{t \mid 0}(x \mid y) \, \mathbb{P}_{t \mid 0}(x \mid y) \, \widehat{\varphi}_0(y) \, dy}{\widehat{\varphi}_t(x)}. \tag{9}\] By Equation 1, the SB joint of \((X_0, X_t)\) is \(\widehat{\varphi}_0(y) \, \mathbb{P}_{t \mid 0}(x \mid y) \, \varphi_t(x)\), and its time-\(t\) marginal is \(\widehat{\varphi}_t(x) \, \varphi_t(x)\). By Bayes’ rule, their ratio \(\mathbb{P}_{t \mid 0}(x \mid y) \, \widehat{\varphi}_0(y) / \widehat{\varphi}_t(x)\) is the posterior \(p^\star(X_0 = y \mid X_t = x)\). So the corrector score is a conditional expectation: \[ \nabla \log \widehat{\varphi}_t(x) = \mathbb{E}_{p^\star} {\left[ \nabla_x \log \mathbb{P}_{t \mid 0}(x \mid X_0) \mid X_t = x \right]} . \tag{10}\] This is a Tweedie-type formula: the corrector score equals the conditional expectation of the transition score \(\nabla_x \log \mathbb{P}_{t|0}(x \mid X_0)\) given \(X_t = x\). A conditional expectation \(\mathbb{E}[Y \mid X]\) minimizes the least-squares risk \(\mathbb{E}\|h(X) - Y\|^2\) over functions \(h\), so \[ \nabla \log \widehat{\varphi}_1 = \mathop{\mathrm{argmin}}_h \; \mathbb{E}_{(X_0, X_1) \sim p^{u^\star}_{0,1}} {\left[ \big\| h(X_1) - \nabla_{X_1} \log \mathbb{P}(X_1 \mid X_0) \big\|^2 \right]} . \tag{11}\] This is the corrector matching (CM) objective. It regresses \(h\) onto the score of the base transition kernel, evaluated at endpoint pairs \((X_0, X_1)\) drawn from the optimal process.
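Equations 8 to 10 can be sanity-checked numerically in one dimension. The identity between the finite-difference gradient of \(\log \widehat{\varphi}_1\) and the posterior mean of the transition score holds for any positive \(\widehat{\varphi}_0\), because it only uses the definition in Equation 2. The sketch below therefore uses an arbitrary, hypothetical \(\widehat{\varphi}_0\), an assumed \(\nu_1 - \nu_0 = 0.7\), and quadrature over a grid for \(X_0\).

```python
import numpy as np

nu = 0.7                                   # nu_1 - nu_0 (assumed value)
y = np.linspace(-10.0, 10.0, 4001)         # quadrature grid over X_0
dy = y[1] - y[0]
# Hypothetical positive hat_phi_0; Eq. 10 holds for any positive choice.
hat_phi_0 = np.exp(-0.5 * (y - 1.0) ** 2) * (1.5 + np.sin(y))


def transition_density(x, y):
    """P_{1|0}(x | y) = N(x; y, nu) for f_t = 0 (scalar case)."""
    return np.exp(-(x - y) ** 2 / (2.0 * nu)) / np.sqrt(2.0 * np.pi * nu)


def hat_phi_1(x):
    """Eq. 2 at t = 1 by quadrature: integral of P(x | y) hat_phi_0(y) dy."""
    return np.sum(transition_density(x, y) * hat_phi_0) * dy


def corrector_score_posterior(x):
    """Eq. 10: posterior mean of the transition score -(x - X_0)/nu given X_1 = x."""
    w = transition_density(x, y) * hat_phi_0   # unnormalized posterior over X_0
    w /= w.sum()
    return np.sum(w * (-(x - y) / nu))


def corrector_score_fd(x, eps=1e-5):
    """Central finite difference of log hat_phi_1, for comparison."""
    return (np.log(hat_phi_1(x + eps)) - np.log(hat_phi_1(x - eps))) / (2.0 * eps)


for x_eval in [-2.0, 0.0, 0.5, 3.0]:
    print(x_eval, corrector_score_posterior(x_eval), corrector_score_fd(x_eval))
```

The two columns agree up to finite-difference error, which is the content of Equation 10.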

For \(f_t = 0\), the base process is a scaled Brownian motion, so \(\mathbb{P}(X_1 \mid X_0) = \mathcal{N}(X_1; X_0, (\nu_1 - \nu_0) I)\) where \(\nu_t = \int_0^t \sigma_s^2 \, ds\). The transition score is known in closed form: \[ \nabla_{X_1} \log \mathbb{P}(X_1 \mid X_0) = -\frac{X_1 - X_0}{\nu_1 - \nu_0}. \tag{12}\] The CM regression target is just this Gaussian score, and the least-squares fit averages it over the posterior on \(X_0\) given \(X_1\).
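Here is a minimal CM regression on a case where the answer is known. Assume the target equals the base marginal, \(\pi = p^{\mathbb{P}}_1\), with a Gaussian prior \(\mu = \mathcal{N}(0, s_0^2)\). Then the SB is the base process itself, and \(\widehat{\varphi}_1 = \mu * \mathcal{N}(0, \nu) = \mathcal{N}(0, s_0^2 + \nu)\), so \(\nabla \log \widehat{\varphi}_1(x) = -x/(s_0^2 + \nu)\). The values of \(s_0\) and \(\nu\) and the affine function class for \(h\) are illustrative assumptions. A real implementation would regress a neural network on model samples.

```python
import numpy as np

rng = np.random.default_rng(0)
n, s0, nu = 200_000, 1.0, 0.5          # assumed prior std s0 and nu = nu_1 - nu_0

# Endpoint pairs from the base process started at mu = N(0, s0^2). For pi = p_1^P this is the SB,
# so hat_phi_1 = N(0, s0^2 + nu) and grad log hat_phi_1(x) = -x / (s0^2 + nu).
x0 = s0 * rng.standard_normal(n)
x1 = x0 + np.sqrt(nu) * rng.standard_normal(n)

# CM regression target (Eq. 12): score of the base transition kernel.
target = -(x1 - x0) / nu

# Least-squares fit of h(x) = a x + b, i.e. Eq. 11 restricted to affine h.
features = np.stack([x1, np.ones_like(x1)], axis=1)
(a, b), *_ = np.linalg.lstsq(features, target, rcond=None)
print(f"fitted slope {a:.4f}, exact {-1.0 / (s0 ** 2 + nu):.4f}, intercept {b:.4f}")
```

Each individual target is noisy, since it depends on the particular \(X_0\). The regression averages that noise out and recovers the posterior mean, exactly as in Equation 10.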

Alternating optimization = IPFP

The AM objective Equation 7 needs \(\nabla \log \widehat{\varphi}_1\). The CM objective Equation 11 needs samples from \(p^{u^\star}\). Neither can be solved alone. The natural fix: alternate.

  1. AM step: Solve Equation 7 with \(\nabla \log \widehat{\varphi}_1 \approx h^{(k-1)}\) to get \(u^{(k)}\).
  2. CM step: Solve Equation 11 with \(u^\star \approx u^{(k)}\) to get \(h^{(k)}\).

Initialize with \(h^{(0)} = 0\). The first AM stage then regresses the control onto \(-\sigma_t \nabla E(X_1)\): pure energy-guided transport with no corrector. Subsequent CM stages progressively learn the bias correction.

This alternation has a clean interpretation. The AM step solves a forward half-bridge: \(\min D_{\text{KL}}(\mathbb{Q}\,\|\, \mathbb{Q}^{\text{bwd}})\) subject to \(\mathbb{Q}_0 = \mu\). The CM step solves a backward half-bridge: \(\min D_{\text{KL}}(\mathbb{P}^{u^{(k)}} \,\|\, \mathbb{Q})\) subject to \(\mathbb{Q}_1 = \pi\). Alternating between these two projections is exactly IPFP/Sinkhorn on path space. Convergence to the SB solution follows from standard IPFP analysis (de2021diffusion?), provided each step is solved exactly. Since \(\mathbb{Q}^\star\) satisfies both marginal constraints, the Pythagorean identity for KL projections onto these convex constraint sets makes \(D_{\text{KL}}(\mathbb{Q}^\star \| \mathbb{Q}^{(k)})\) non-increasing in \(k\). In practice, 3-5 stages suffice.
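On a discrete grid, a toy stand-in for \(\mathbb{R}\), both steps can be solved exactly, and the alternation becomes Sinkhorn's algorithm. The sketch below is a hypothetical illustration, not ASBS itself. The grid, kernel variance, Gaussian prior and bimodal target are assumptions, and the "AM" and "CM" steps are replaced by their exact minimizers rather than neural-network regressions. It starts from \(h^{(0)} = 0\) (constant \(\widehat{\varphi}_1\)) and tracks the L1 gap between the SOC terminal marginal and \(\pi\) after each AM step.

```python
import numpy as np

# Discrete 1-D grid; K[i, j] stands in for P(X_1 = x_j | X_0 = x_i) with nu_1 = 1.
x = np.linspace(-6.0, 6.0, 121)
K = np.exp(-(x[None, :] - x[:, None]) ** 2 / 2.0)
K /= K.sum(axis=1, keepdims=True)

mu = np.exp(-0.5 * x ** 2)
mu /= mu.sum()                                   # Gaussian prior
pi = np.exp(-2.0 * (x - 2.0) ** 2) + np.exp(-2.0 * (x + 2.0) ** 2)
pi /= pi.sum()                                   # bimodal target, pi ∝ e^{-E}


def soc_terminal_marginal(g):
    """Terminal marginal of the exact SOC optimum with X_0 ~ mu and terminal cost g."""
    w = K * np.exp(-g)[None, :]                      # P(X_1 | X_0) e^{-g(X_1)}
    return mu @ (w / w.sum(axis=1, keepdims=True))   # row sums are e^{-V_0(X_0)}


hat_phi_1 = np.ones_like(x)                      # stage 0: corrector h^(0) = 0
errors = []
for stage in range(1000):
    # AM step (solved exactly here): terminal cost g = log(hat_phi_1 / pi), Eq. 5.
    g = np.log(hat_phi_1 / pi)
    errors.append(np.abs(soc_terminal_marginal(g) - pi).sum())
    # CM step (solved exactly here): the SOC optimum has joint hat_phi_0(y) P(x | y) phi_1(x)
    # with phi_1 = e^{-g} and hat_phi_0 = mu / (K phi_1); its corrector is K^T hat_phi_0 (Eq. 2).
    phi_1 = pi / hat_phi_1
    hat_phi_0 = mu / (K @ phi_1)
    hat_phi_1 = K.T @ hat_phi_0

print([f"{e:.2e}" for e in errors[:5]], f"{errors[-1]:.2e}")
```

Only the gradient of \(\log \widehat{\varphi}_1\) matters, so the unnormalized scale of the potentials is irrelevant here, just as additive constants in \(g\) are.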

Comparison and practical remarks

| Method | Prior | Terminal cost \(g(x)\) | Corrector |
|---|---|---|---|
| Adjoint Sampling | \(\delta_0\) (memoryless) | \(\log \frac{p^{\mathbb{P}}_1(x)}{\pi(x)}\) | None needed (\(\widehat{\varphi}_1 \propto p^{\mathbb{P}}_1\), known in closed form) |
| ASBS | Arbitrary \(\mu\) | \(\log \frac{\widehat{\varphi}_1(x)}{\pi(x)}\) | Learned \(\nabla \log \widehat{\varphi}_1\) |

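ASBS reduces to Adjoint Sampling when \(\mu = \delta_0\). In that case the corrector becomes \(\widehat{\varphi}_1 \propto p^{\mathbb{P}}_1\), and the CM step is trivially solved: with \(X_0 = 0\), the regression target in Equation 11, \(-X_1/\nu_1\), is already a deterministic function of \(X_1\).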

A concrete example of a useful prior: for molecular conformer generation, one can use a harmonic oscillator prior \(\mu(x) \propto \exp(-\frac{\alpha}{2} \sum_{(i,j) \in \mathcal{E}} (\|x_i - x_j\| - r^0_{ij})^2)\). Here \(\mathcal{E}\) is the set of bonded pairs in the molecular graph and \(r^0_{ij}\) are their equilibrium bond lengths. Particles start in a physically reasonable arrangement rather than at the origin, substantially reducing the transport cost.
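The sketch below computes the unnormalized log-density of such a prior. It assumes the pairwise term is \((\|x_i - x_j\| - r^0_{ij})^2\) summed over bonded pairs. The function name, the three-atom geometry, the bond list, \(r^0\) and \(\alpha\) are hypothetical placeholders. Because the prior depends only on bond lengths, it is unchanged by rigid rotations and translations, and the usage checks exactly that.

```python
import numpy as np


def harmonic_log_prior(x, bonds, r0, alpha=1.0):
    """Unnormalized log mu(x) for the bond-stretch harmonic prior (assumed form).

    x:     (N, 3) atomic coordinates
    bonds: (M, 2) integer index pairs, the edges of the molecular graph
    r0:    (M,) equilibrium bond lengths
    """
    i, j = bonds[:, 0], bonds[:, 1]
    lengths = np.linalg.norm(x[i] - x[j], axis=-1)   # |x_i - x_j| for each bond
    return -0.5 * alpha * np.sum((lengths - r0) ** 2)


# Hypothetical three-atom molecule with two bonds of equilibrium length 1.
bonds = np.array([[0, 1], [0, 2]])
r0 = np.array([1.0, 1.0])
x_eq = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
x_stretched = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.0, 0.0]])

# A random rotation plus translation leaves the prior unchanged.
rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((3, 3)))
x_moved = x_stretched @ q.T + np.array([3.0, -1.0, 2.0])

print(harmonic_log_prior(x_eq, bonds, r0),
      harmonic_log_prior(x_stretched, bonds, r0),
      harmonic_log_prior(x_moved, bonds, r0))
```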

One caveat: like all SOC-based samplers, ASBS is mode-seeking. The AM regression is fit on samples from the current model \(p^{\bar{u}}\) rather than from \(\pi\), so modes of \(\pi\) that the model has not yet reached contribute no training signal.