API Reference

Banana-shaped 2D distribution (Rosenbrock-like).

The density factorizes as

p(x0, x1) = N(x0; 1, 1) * N(x1; x0^2, sigma^2)

This creates a curved, banana-shaped distribution concentrated around the parabola x1 = x0^2. The marginal of x0 is a unit-variance normal with mean 1; conditional on x0, x1 equals x0^2 plus Gaussian noise with standard deviation sigma.
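The factorization above translates almost directly into a log density. The following is a minimal sketch, not the library's implementation: `banana_log_prob` is a hypothetical standalone function, and it assumes the second argument of N(.; mean, variance) is a variance, so the conditional standard deviation is sigma.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def banana_log_prob(x, sigma=0.1):
    """Hypothetical helper: log of N(x0; 1, 1) * N(x1; x0^2, sigma^2).

    Accepts x of shape (..., 2) and returns an array of shape (...).
    """
    x0, x1 = x[..., 0], x[..., 1]
    # Marginal term: x0 ~ N(1, 1).
    log_p_x0 = norm.logpdf(x0, loc=1.0, scale=1.0)
    # Conditional term: x1 | x0 ~ N(x0^2, sigma^2); `scale` is the std dev.
    log_p_x1 = norm.logpdf(x1, loc=x0**2, scale=sigma)
    return log_p_x0 + log_p_x1


x = jnp.array([1.0, 1.0])            # a point on the parabola x1 = x0^2
value = banana_log_prob(x)           # scalar log density
grad = jax.grad(banana_log_prob)(x)  # gradient with respect to x, shape (2,)
```

Because both factors are normalized Gaussians, this sketch includes the normalizing terms, and the leading batch dimensions of `x` pass through unchanged.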

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| sigma | float | Controls the "thickness" of the banana. Smaller values (e.g., 0.01) create a thin, tightly curved banana that is challenging for MCMC. Larger values (e.g., 1.0) create a fatter, easier distribution. Default: 0.1. |
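One way to see why a small sigma is harder: the gradient of the log density across the ridge scales as 1/sigma^2, so gradient-based samplers typically need much smaller steps for a thin banana. The sketch below uses a hypothetical unnormalized log density (`banana_log_prob` is an illustration, not the class's method) and compares the gradient at a point 0.1 above the parabola for the two sigma values mentioned above.

```python
import jax
import jax.numpy as jnp


def banana_log_prob(x, sigma):
    # Unnormalized log density; additive constants do not affect gradients.
    x0, x1 = x[0], x[1]
    return -0.5 * (x0 - 1.0) ** 2 - 0.5 * ((x1 - x0**2) / sigma) ** 2


x = jnp.array([1.0, 1.1])  # 0.1 above the parabola at x0 = 1

# Gradients with respect to x for a thin and a fat banana.
grad_thin = jax.grad(banana_log_prob)(x, 0.01)
grad_fat = jax.grad(banana_log_prob)(x, 1.0)

# The x1 component is -(x1 - x0^2) / sigma^2, so it grows by a factor
# of (1.0 / 0.01)^2 when sigma shrinks from 1.0 to 0.01.
print(grad_thin[1], grad_fat[1])
```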

Example

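    import jax
    import jax.numpy as jnp

    # Banana2D is assumed to be imported from this package already.
    key = jax.random.PRNGKey(0)
    dist = Banana2D(sigma=0.1)
    x = jnp.array([1.0, 1.0])
    dist(x)                # log prob
    dist.sample(key, 100)  # 100 samples
    jax.grad(dist)(x)      # gradient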

dim (property)

__call__(x)

Evaluate log probability density.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input point(s) of shape (..., 2). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Log probability density of shape (...). |

sample(key, n)

Draw exact samples from the distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Array | JAX PRNG key. | required |
| n | int | Number of samples to draw. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Samples of shape (n, 2). |
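The source does not show how `sample` is implemented. One way exact samples could be drawn from a density of this form is ancestral sampling: draw x0 from its marginal, then x1 from its conditional given x0. The function name `banana_sample` and its `sigma` argument are assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp


def banana_sample(key, n, sigma=0.1):
    """Hypothetical helper: ancestral sampling, returns shape (n, 2)."""
    key0, key1 = jax.random.split(key)
    # Step 1: x0 ~ N(1, 1).
    x0 = 1.0 + jax.random.normal(key0, (n,))
    # Step 2: x1 | x0 ~ N(x0^2, sigma^2).
    x1 = x0**2 + sigma * jax.random.normal(key1, (n,))
    return jnp.stack([x0, x1], axis=-1)


key = jax.random.PRNGKey(0)
samples = banana_sample(key, 1000)

# The residual x1 - x0^2 should have standard deviation close to sigma.
residual_std = jnp.std(samples[:, 1] - samples[:, 0] ** 2)
```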

log_normalization()

Log normalizing constant (already included in __call__).

Returns:

| Type | Description |
| --- | --- |
| Array | Scalar log(Z) where Z is the normalizing constant. |

Neal's funnel distribution in D dimensions.

The density factorizes as

p(x) = N(x0; 0, sigma^2) * prod_{i=1}^{D-1} N(x_i; 0, exp(x0))

The first coordinate x0 controls the scale of all other coordinates: each has variance exp(x0). When x0 is large, the remaining coordinates spread out; when x0 is strongly negative, they concentrate near zero. This creates a funnel shape that is notoriously difficult for MCMC.
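As an illustration of the factorization above, here is a minimal sketch of the funnel log density. `funnel_log_prob` is a hypothetical standalone function, not the library's code, and it assumes exp(x0) is a variance, so the standard deviation of the remaining coordinates is exp(x0 / 2).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def funnel_log_prob(x, sigma=3.0):
    """Hypothetical helper: log density for x of shape (..., D), D >= 2."""
    x0, rest = x[..., 0], x[..., 1:]
    # First coordinate: x0 ~ N(0, sigma^2).
    log_p_x0 = norm.logpdf(x0, loc=0.0, scale=sigma)
    # Remaining coordinates: x_i | x0 ~ N(0, exp(x0)), i.e. std dev exp(x0 / 2).
    scale = jnp.exp(0.5 * x0)[..., None]
    log_p_rest = norm.logpdf(rest, loc=0.0, scale=scale).sum(axis=-1)
    return log_p_x0 + log_p_rest


x = jnp.zeros(10)
value = funnel_log_prob(x)           # scalar log density
grad = jax.grad(funnel_log_prob)(x)  # gradient with respect to x, shape (10,)
```

At the origin the gradient is nonzero only in x0: lowering x0 shrinks the conditional variance and raises the density of the other coordinates at zero, which is the pull toward the neck of the funnel.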

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| dim | int | Dimensionality of the distribution. Must be >= 2. |
| sigma | float | Std dev of x0 (the "mouth width" of the funnel). Larger values create a wider range of scales, making sampling harder. Default: 3.0 (standard benchmark setting). |

Example

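    import jax
    import jax.numpy as jnp

    # NealFunnel is assumed to be imported from this package already.
    key = jax.random.PRNGKey(0)
    dist = NealFunnel(dim=10, sigma=3.0)
    x = jnp.zeros(10)
    dist(x)                # log prob
    dist.sample(key, 100)  # 100 samples
    jax.grad(dist)(x)      # gradient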

dim = 10 (class attribute, instance attribute)

Dimensionality of the distribution.

__call__(x)

Evaluate log probability density.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input point(s) of shape (..., dim). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Log probability density of shape (...). |

sample(key, n)

Draw exact samples from the distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Array | JAX PRNG key. | required |
| n | int | Number of samples to draw. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Samples of shape (n, dim). |
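The funnel's factorized form also admits exact ancestral sampling: draw x0 first, then scale standard normal draws by exp(x0 / 2). The sketch below is an assumption about how such a sampler could look, not the library's implementation; `funnel_sample` and its keyword arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def funnel_sample(key, n, dim=10, sigma=3.0):
    """Hypothetical helper: ancestral sampling, returns shape (n, dim)."""
    key0, key_rest = jax.random.split(key)
    # Step 1: x0 ~ N(0, sigma^2).
    x0 = sigma * jax.random.normal(key0, (n,))
    # Step 2: x_i | x0 ~ N(0, exp(x0)), i.e. std dev exp(x0 / 2).
    noise = jax.random.normal(key_rest, (n, dim - 1))
    rest = jnp.exp(0.5 * x0)[:, None] * noise
    return jnp.concatenate([x0[:, None], rest], axis=-1)


key = jax.random.PRNGKey(0)
samples = funnel_sample(key, 1000)

# Dividing out the x0-dependent scale should recover unit-variance noise.
whitened = samples[:, 1:] * jnp.exp(-0.5 * samples[:, :1])
```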

log_normalization()

Log normalizing constant (already included in __call__).

Returns:

| Type | Description |
| --- | --- |
| Array | Scalar 0.0 (distribution is normalized by construction). |