RWM & HMC on manifolds

Tags: MCMC, manifold
Published: 09 03 2024 · Modified: 09 03 2024

Consider a smooth manifold $\mathcal{M} \subset \mathbb{R}^n$ of dimension $d_{\mathcal{M}} = (n - d)$, defined as the zero set of a well-behaved “constraint” function $C: \mathbb{R}^n \to \mathbb{R}^d$,

$$\mathcal{M} = \left\{ x \in \mathbb{R}^n \;\text{such that}\; C(x) = 0 \right\}.$$

We would like to use MCMC to sample from a probability distribution supported on $\mathcal{M}$ with density $\pi(x)$ with respect to the uniform Hausdorff measure on $\mathcal{M}$. It is relatively straightforward to adapt standard MCMC methods when dealing with simple manifolds such as a sphere or a torus, since their geodesics and several other geometric quantities are analytically tractable. Maybe surprisingly, it is in fact also relatively straightforward to design MCMC samplers on general implicitly defined manifolds such as $\mathcal{M}$. The article () explains these ideas beautifully.

Manifold Random Walk Metropolis-Hastings

Assume that $x_n \in \mathcal{M}$ is the current position of the MCMC chain. To generate a proposal $y_n \in \mathcal{M}$ that will eventually be accepted or rejected, one can proceed very similarly to the standard RWM algorithm with Gaussian perturbations of variance $\sigma^2$. First, generate a vector $v \in T_{x_n}$ from a centred Gaussian distribution with covariance $\sigma^2 I$ on the tangent space $T_{x_n}$ to $\mathcal{M}$ at $x_n$. To do so, it suffices for example to generate a Gaussian vector $z \sim \mathcal{N}(0, \sigma^2 I_n)$ in $\mathbb{R}^n$ and orthogonally project it onto $T_{x_n}$. However, one cannot simply define the proposal as $x_n + v$ since it would not necessarily lie on $\mathcal{M}$. Instead, one projects $x_n + v$ back onto $\mathcal{M}$. To do so, one needs to choose the direction used for the projection, and the manifold RWM algorithm uses $T_{x_n}^{\perp}$, for reasons that will become clear later. In other words, the proposal $y_n$ is obtained by seeking a vector $w \in T_{x_n}^{\perp}$ such that $x_n + v + w \in \mathcal{M}$.
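
To make this concrete, here is a minimal sketch in JAX of the tangent-space Gaussian draw; the helper name and the assumption that the constraint function returns a vector in $\mathbb{R}^d$ are mine, not from the article.

```python
# Minimal sketch: sample v ~ N(0, sigma^2 I) restricted to the tangent space
# T_x = ker(J_x) by orthogonally projecting an ambient Gaussian draw.
import jax
import jax.numpy as jnp

def tangent_gaussian_sample(key, x, constraint, sigma):
    """Draw v in T_x by projecting z ~ N(0, sigma^2 I_n) onto the null space of J_x."""
    J = jnp.atleast_2d(jax.jacobian(constraint)(x))  # (d, n): rows are constraint gradients
    z = sigma * jax.random.normal(key, x.shape)      # ambient Gaussian draw in R^n
    # Orthogonal projector onto T_x applied to z: z - J^T (J J^T)^{-1} (J z)
    return z - J.T @ jnp.linalg.solve(J @ J.T, J @ z)
```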

Projection onto $\mathcal{M}$, from ().

If one calls $J_{x_n}$ the Jacobian matrix of $C$ at $x_n$, i.e. the matrix whose rows are the gradients of the components of $C$, this projection operation boils down to finding a vector $\lambda \in \mathbb{R}^d$ such that

$$C(x_n + v + J_{x_n}^T \lambda) = 0 \in \mathbb{R}^d. \tag{1}$$

Note that Equation 1 is a non-linear equation in $\lambda$ that can have no solution, one solution, or many solutions – this can seem like a fundamental roadblock to the design of a valid MCMC algorithm, but we will see that it is not! Before discussing the resolution of Equation 1 in slightly more detail, assume that a standard root-finding algorithm takes the pair $(x_n + v, J_{x_n})$ as input and attempts to produce the projection $y_n$,

$$\mathrm{Proj}: \; (x_n + v, J_{x_n}) \; \xrightarrow{\text{root-finding}} \; y_n \in \mathcal{M}.$$

The algorithm will either converge to one of the possible solutions or fail. If the algorithm fails to converge, one can simply set $y_n = (\text{Failed})$, reject the proposal, and set $x_{n+1} = x_n$. If the algorithm converges, this defines a valid proposal $y_n \in \mathcal{M}$. To ensure reversibility, which is one of the main novelties of the article (), one needs to verify that the reverse proposal $y_n \to x_n$ is possible.

To do so, note that the only possibility for the reverse move $y_n \to x_n$ to happen is if $x_n = \mathrm{Proj}(y_n + v', J_{y_n})$ where

$$x_n - y_n = v' + w', \qquad v' \in T_{y_n}, \quad w' \in T_{y_n}^{\perp}.$$

The uniqueness of this decomposition follows from $\mathbb{R}^n = T_{y_n} \oplus T_{y_n}^{\perp}$. The reverse move is consequently possible if and only if the following reversibility check condition is satisfied,

$$x_n = \mathrm{Proj}(y_n + v', J_{y_n}). \tag{2}$$

This reversibility check is necessary as it is not guaranteed that the root-finding algorithm started from $y_n + v'$ converges at all, or that it converges to $x_n$ in the case when there are several solutions. If Equation 2 is not satisfied, the proposal $y_n$ is rejected and one sets $x_{n+1} = x_n$. On the other hand, if Equation 2 is satisfied, the proposal $y_n$ is accepted with the usual Metropolis-Hastings probability

$$\min\left\{1, \; \frac{\pi(y_n)\, p(v' \mid y_n)}{\pi(x_n)\, p(v \mid x_n)} \right\}$$

where $p(v \mid x) = Z^{-1} \exp\left(-\|v\|^2 / 2\sigma^2\right)$ denotes the Gaussian density on the tangent space $T_x$. The above description defines a valid MCMC algorithm on $\mathcal{M}$ that is reversible with respect to the target distribution $\pi(x)$.
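
Putting the proposal, the reversibility check and the accept/reject decision together, a single manifold-RWM iteration could look like the sketch below. It reuses the `tangent_gaussian_sample` helper above and assumes a routine `project(p, J)` implementing $\mathrm{Proj}$ that returns a pair (point, success flag); one such routine is sketched in the next section. All names are mine, and the Python control flow would need to be replaced by `jnp.where`/`lax.cond` to make the step jit- or vmap-friendly.

```python
def manifold_rwm_step(key, x, log_pi, constraint, sigma, project):
    """One manifold RWM iteration; `project(p, J)` should return (point, converged)."""
    key_v, key_u = jax.random.split(key)
    J_x = jnp.atleast_2d(jax.jacobian(constraint)(x))
    v = tangent_gaussian_sample(key_v, x, constraint, sigma)   # tangential move at x
    y, ok_fwd = project(x + v, J_x)                            # project x + v back onto M
    if not ok_fwd:
        return x                                               # projection failed: reject

    # Reverse velocity v': tangential component (at y) of x - y.
    J_y = jnp.atleast_2d(jax.jacobian(constraint)(y))
    diff = x - y
    v_rev = diff - J_y.T @ jnp.linalg.solve(J_y @ J_y.T, J_y @ diff)
    x_back, ok_rev = project(y + v_rev, J_y)                   # reversibility check (Equation 2)
    if (not ok_rev) or (not jnp.allclose(x_back, x, atol=1e-8)):
        return x

    # Metropolis-Hastings ratio; the Gaussian normalizing constants cancel.
    log_ratio = (log_pi(y) - 0.5 * jnp.sum(v_rev**2) / sigma**2) \
              - (log_pi(x) - 0.5 * jnp.sum(v**2) / sigma**2)
    return y if jnp.log(jax.random.uniform(key_u)) < log_ratio else x
```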

Projection onto the manifold

As described above, the main difficulty is to solve the non-linear Equation 1 describing the projection of the proposal $(x_n + v)$ back onto the manifold $\mathcal{M}$. The projection is along the space spanned by the columns of $J_{x_n}^T \in \mathbb{R}^{n \times d}$, i.e. one needs to find a vector $\lambda \in \mathbb{R}^d$ such that

$$\Phi(\lambda) = C(x_n + v + J_{x_n}^T \lambda) = 0 \in \mathbb{R}^d.$$

One can use a standard Newton method to solve this equation, started from $\lambda_0 = 0$. Setting for notational convenience $q(\lambda) = x_n + v + J_{x_n}^T \lambda$, this boils down to iterating

$$\lambda_{k+1} - \lambda_k = -\left(J_{q(\lambda_k)}\, J_{x_n}^T\right)^{-1} \Phi(\lambda_k).$$

As described in (), it can sometimes be computationally advantageous to use a quasi-Newton method and iterate instead

$$\lambda_{k+1} - \lambda_k = -G^{-1} \Phi(\lambda_k)$$

with the fixed positive definite matrix $G = J_{x_n} J_{x_n}^T$, since one can then pre-compute a decomposition of $G$ and use it to solve the linear system at each iteration. In some recent and related work (), we observed that the standard Newton method performed well in the settings we considered and that, most of the time, there was no computational advantage to using a quasi-Newton method. In practice, the main computational bottleneck is the computation of the Jacobian matrix $J_{x_n}$, although this is problem-dependent and some structure can typically be exploited. Only a relatively small number of iterations is performed, and the root-finding algorithm is stopped as soon as $\|\Phi(\lambda_k)\|$ falls below a certain threshold. If the step size is small, i.e. $\|v\| \ll 1$, Newton's method typically converges to a solution in only a very small number of iterations – indeed, Newton's method is quadratically convergent when close to a solution.
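
A minimal Newton solver for this projection, under the same assumptions as the earlier sketches, could read as follows; passing it (with the constraint closed over, e.g. via `functools.partial`) as the `project` argument of the RWM sketch above gives a complete sampler. Reusing a fixed matrix $G = J_{x_n} J_{x_n}^T$ in place of the per-iteration Newton system gives the quasi-Newton variant.

```python
def newton_project(p, J_x, constraint, max_iter=20, tol=1e-10):
    """Try to solve C(p + J_x^T lam) = 0 in lam with Newton's method; return (point, converged)."""
    lam = jnp.zeros(J_x.shape[0])
    for _ in range(max_iter):
        q = p + J_x.T @ lam
        Phi = jnp.atleast_1d(constraint(q))
        if jnp.linalg.norm(Phi) < tol:
            return q, True                                # converged: q lies (numerically) on M
        J_q = jnp.atleast_2d(jax.jacobian(constraint)(q))
        lam = lam - jnp.linalg.solve(J_q @ J_x.T, Phi)    # Newton step on Phi(lam) = 0
    return p, False                                       # no convergence within max_iter
```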

30k RWM chains run in parallel to explore a double torus.

In the figure above, I have implemented the RWM algorithm described above to sample from the uniform distribution supported on a double torus defined by the constraint function $C: \mathbb{R}^3 \to \mathbb{R}$ given by

$$C(x, y, z) = \left(x^2 (x^2 - 1) + y^2\right)^2 + z^2 - 0.03.$$

The figure shows 30,000 chains run in parallel, which is straightforward to implement in practice with JAX (). All the chains are initialized from the same position so that one can visualize the evolution of the density of particles.
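
For reference, a possible JAX encoding of this constraint is given below; the formula matches the one above, while the comment only sketches how one might batch the chains.

```python
def double_torus(x):
    # C(x, y, z) = (x^2 (x^2 - 1) + y^2)^2 + z^2 - 0.03
    return jnp.atleast_1d((x[0]**2 * (x[0]**2 - 1.0) + x[1]**2)**2 + x[2]**2 - 0.03)

# Running e.g. 30,000 chains in parallel then amounts to jax.vmap-ping a jit-friendly,
# branch-free version of manifold_rwm_step over a batch of PRNG keys and positions.
```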

Tuning of manifold-RWM

One can for example monitor the usual expected squared jump distance

$$\text{(ESJD)} \qquad \mathbb{E}\left[ \| X_{n+1} - X_n \|^2 \right]$$

and maximize it to tune the RWM step-size; it would probably make slightly more sense to monitor the squared geodesic distance instead of the naive squared norm $\|X_{n+1} - X_n\|^2$, but that’s way too much hassle and would probably make only a negligible difference. In the figure above, I have plotted the expected squared jump distance as a function of the acceptance rate for different step-sizes. It is interesting to see a pattern extremely similar to the one observed in the standard RWM algorithm (): in this double torus example, the optimal acceptance rate is around 25%. Note that since the target distribution is uniform, the rate of acceptance is only very slightly lower than the proportion of successful reversibility checks.
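
As a small illustration, the ESJD can be estimated directly from a stored trajectory; the array layout below is an assumption on how the samples are stored, not something prescribed by the article.

```python
def esjd(samples):
    """Mean squared consecutive displacement for samples of shape (num_steps, ..., n)."""
    diffs = samples[1:] - samples[:-1]
    return jnp.mean(jnp.sum(diffs**2, axis=-1))
```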

Hamiltonian Monte Carlo (HMC) on manifolds

While the Random Walk Metropolis-Hastings algorithm is interesting, exploiting gradient information is often necessary to design efficient MCMC samplers. Consider a single iteration of a standard Hamiltonian Monte Carlo (HMC) sampler targeting a density $\pi(q)$ for $q \in \mathbb{R}^n$. The method proceeds by simulating a dynamics that is reversible with respect to an extended target density $\bar{\pi}(q, p)$ on $\mathbb{R}^n \times \mathbb{R}^n$ defined as

$$\bar{\pi}(q, p) \;\propto\; \pi(q) \exp\left\{-\frac{1}{2m}\|p\|^2\right\} \;=\; \exp\left\{\log \pi(q) - K(p)\right\}$$

for a user-defined mass parameter $m > 0$. In general, the mass parameter is a positive definite matrix, but generalizing this to manifolds is slightly less useful in practice. For a time-discretization step $\varepsilon > 0$ and a current position $(q_n, p_n)$, the method proceeds by generating a proposal $(q^\star, p^\star)$ defined as

$$
\begin{cases}
p_{n+1/2} = p_n + \frac{\varepsilon}{2} \nabla \log \pi(q_n) \\[2pt]
q^\star = q_n + \varepsilon\, m^{-1} p_{n+1/2} \\[2pt]
p^\star = p_{n+1/2} + \frac{\varepsilon}{2} \nabla \log \pi(q^\star).
\end{cases}
$$
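
In code, this single leapfrog step could be sketched as follows (scalar mass $m$, names of my own choosing):

```python
def leapfrog_proposal(q, p, log_pi, eps, m):
    """One unconstrained leapfrog step for the extended target pi_bar(q, p)."""
    grad_log_pi = jax.grad(log_pi)
    p_half = p + 0.5 * eps * grad_log_pi(q)          # half momentum kick
    q_new = q + eps * p_half / m                     # full position drift
    p_new = p_half + 0.5 * eps * grad_log_pi(q_new)  # second half momentum kick
    return q_new, p_new
```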

This proposal is accepted with probability $\min\left(1, \bar{\pi}(q^\star, p^\star) / \bar{\pi}(q_n, p_n)\right)$. In standard implementations, several leapfrog steps are performed instead of a single one. One can also choose to perform a single leapfrog step as above and only do a partial refreshment of the momentum after each leapfrog step – this may be more efficient or easier to implement when running a large number of HMC chains in parallel on a GPU, for example. To adapt the HMC algorithm to sample from a density $\pi(q)$ supported on a manifold $\mathcal{M}$, one can proceed similarly to the RWM algorithm by interleaving additional projection steps. These projections are needed to ensure that the momentum vectors $p_n$ remain in the right tangent spaces and that the position vectors $q_n$ remain on the manifold $\mathcal{M}$,

$$(q_n, p_n) \in \mathcal{M} \times T_{q_n}.$$

As in the RWM algorithm, reversibility checks need to be performed to ensure that the overall algorithm is reversible with respect to the extended target distribution $\bar{\pi}(q, p)$. The resulting algorithm for generating a proposal $(q_n, p_n) \to (q^\star, p^\star)$ reads as follows:

$$
\begin{cases}
\widetilde{p}_{n+1/2} = p_n + \frac{\varepsilon}{2} \nabla \log \pi(q_n) \\[2pt]
p_{n+1/2} = \text{orthogonal projection of } \widetilde{p}_{n+1/2} \text{ onto } T_{q_n} \\[2pt]
\widetilde{q} = q_n + \varepsilon\, m^{-1} p_{n+1/2} \\[2pt]
q^\star = \mathrm{Proj}(\widetilde{q}, \, J_{q_n}) \\[2pt]
p^\star_{n+1/2} = \text{orthogonal projection of } (q^\star - q_n)\, m/\varepsilon \text{ onto } T_{q^\star} \\[2pt]
\widetilde{p}^\star = p^\star_{n+1/2} + \frac{\varepsilon}{2} \nabla \log \pi(q^\star) \\[2pt]
p^\star = \text{orthogonal projection of } \widetilde{p}^\star \text{ onto } T_{q^\star}.
\end{cases}
$$

If any of the projection operations fails, the proposal is rejected. If no failure occurs, a reversibility check is performed by running the algorithm backward starting from $(q^\star, p^\star)$. If the reversibility check is successful, the proposal is accepted with the usual Metropolis-Hastings probability $\min\left(1, \bar{\pi}(q^\star, p^\star) / \bar{\pi}(q_n, p_n)\right)$.
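
A hedged sketch of this constrained leapfrog step, reusing the `newton_project` routine above (helper names are my own, and the reversibility check is left out):

```python
def tangent_project(p, J):
    """Orthogonal projection of p onto the tangent space ker(J)."""
    return p - J.T @ jnp.linalg.solve(J @ J.T, J @ p)

def constrained_leapfrog(q, p, log_pi, constraint, eps, m, project):
    """One constrained leapfrog step; returns (q_star, p_star, ok)."""
    grad_log_pi = jax.grad(log_pi)
    J_q = jnp.atleast_2d(jax.jacobian(constraint)(q))
    p_half = tangent_project(p + 0.5 * eps * grad_log_pi(q), J_q)  # kick, stay tangent at q
    q_star, ok = project(q + eps * p_half / m, J_q)                # drift, project back onto M
    if not ok:
        return q, p, False                                         # position projection failed
    J_qs = jnp.atleast_2d(jax.jacobian(constraint)(q_star))
    p_half = tangent_project((q_star - q) * m / eps, J_qs)         # momentum consistent with the realized move
    p_star = tangent_project(p_half + 0.5 * eps * grad_log_pi(q_star), J_qs)
    return q_star, p_star, True
```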

5k HMC chains run in parallel: the momentum is not refreshed.

The article () provides a detailed description of several of these ideas, along with further analysis and extensions.

References

Au, Khai Xiang, Matthew M Graham, and Alexandre H Thiery. 2022. “Manifold Lifting: Scaling MCMC to the Vanishing Noise Regime.” Journal of the Royal Statistical Society: Series B. https://arxiv.org/abs/2003.03950.
Barth, Eric, Krzysztof Kuczera, Benedict Leimkuhler, and Robert D Skeel. 1995. “Algorithms for Constrained Molecular Dynamics.” Journal of Computational Chemistry 16 (10). Wiley Online Library: 1192–1209. https://doi.org/10.1002/jcc.540161003.
Bradbury, James, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, et al. 2018. “JAX: Composable Transformations of Python+NumPy Programs.” http://github.com/google/jax.
Lelièvre, Tony, Mathias Rousset, and Gabriel Stoltz. 2019. “Hybrid Monte Carlo Methods for Sampling Probability Measures on Submanifolds.” Numerische Mathematik 143 (2). Springer: 379–421. https://arxiv.org/abs/1807.02356.
Roberts, Gareth O, and Jeffrey S Rosenthal. 2001. “Optimal Scaling for Various Metropolis-Hastings Algorithms.” Statistical Science 16 (4). Institute of Mathematical Statistics: 351–67. https://doi.org/10.1214/ss/1015346320.
Zappa, Emilio, Miranda Holmes-Cerfon, and Jonathan Goodman. 2018. “Monte Carlo on Manifolds: Sampling Densities and Integrating Functions.” Communications on Pure and Applied Mathematics 71 (12). Wiley Online Library: 2609–47. https://arxiv.org/abs/1702.08446.