API Reference
Banana-shaped 2D distribution (Rosenbrock-like).
The density factorizes as
p(x0, x1) = N(x0; 1, 1) * N(x1; x0^2, sigma^2)
This creates a curved, banana-shaped distribution centered on the parabola x1 = x0^2. The marginal of x0 is normal with mean 1 and unit variance. Conditional on x0, x1 is normal with mean x0^2 and standard deviation sigma.
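A minimal sketch of the log density implied by this factorization may help. The function `banana_log_prob` is hypothetical, and this is not necessarily how `Banana2D.__call__` is implemented. It sums the two Gaussian log densities and broadcasts over leading batch dimensions, matching the documented input shape (..., 2) and output shape (...).

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def banana_log_prob(x, sigma=0.1):
    """Hypothetical helper: log p(x0, x1) for the factorization above.

    x has shape (..., 2); the result has shape (...).
    """
    x0 = x[..., 0]
    x1 = x[..., 1]
    # log N(x0; 1, 1): x0 has mean 1 and unit standard deviation.
    log_p0 = norm.logpdf(x0, loc=1.0, scale=1.0)
    # log N(x1; x0^2, sigma^2): x1 scatters around the parabola x1 = x0^2.
    log_p1 = norm.logpdf(x1, loc=x0**2, scale=sigma)
    return log_p0 + log_p1


print(banana_log_prob(jnp.array([1.0, 1.0])))      # ≈ 0.4647
print(banana_log_prob(jnp.zeros((3, 4, 2))).shape)  # (3, 4)
```

Both factors are normalized Gaussians, so this sketch already returns a normalized log density.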
Attributes:
| Name | Type | Description |
|---|---|---|
| sigma | float | Standard deviation of x1 given x0, which controls the "thickness" of the banana. Smaller values (e.g., 0.01) create a thin, tightly curved banana that is challenging for MCMC. Larger values (e.g., 1.0) create a fatter, easier distribution. Default: 0.1. |
Example
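```python
import jax
import jax.numpy as jnp

# Banana2D is assumed to be imported from this package (import path not shown).
key = jax.random.PRNGKey(0)
dist = Banana2D(sigma=0.1)
x = jnp.array([1.0, 1.0])
dist(x)                # log probability
dist.sample(key, 100)  # 100 samples, shape (100, 2)
jax.grad(dist)(x)      # gradient of the log probability
```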
dim (property)
Dimensionality of the distribution (always 2).
__call__(x)
Evaluate log probability density.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input point(s) of shape (..., 2). | required |
Returns:
| Type | Description |
|---|---|
| Array | Log probability density of shape (...). |
sample(key, n)
Draw exact samples from the distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | JAX PRNG key. | required |
| n | int | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| Array | Samples of shape (n, 2). |
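Because the density factorizes as a marginal for x0 times a conditional for x1, exact samples can be drawn by ancestral sampling: draw x0 first, then x1 given x0. The sketch below illustrates that idea with a hypothetical helper `banana_sample`; it is not necessarily what `Banana2D.sample` does internally.

```python
import jax
import jax.numpy as jnp


def banana_sample(key, n, sigma=0.1):
    """Hypothetical helper: n exact samples of shape (n, 2) by ancestral sampling."""
    key0, key1 = jax.random.split(key)
    # Draw x0 from its marginal N(1, 1).
    x0 = 1.0 + jax.random.normal(key0, (n,))
    # Draw x1 from its conditional N(x0^2, sigma^2).
    x1 = x0**2 + sigma * jax.random.normal(key1, (n,))
    return jnp.stack([x0, x1], axis=-1)


samples = banana_sample(jax.random.PRNGKey(0), 10_000)
print(samples.shape)  # (10000, 2)
```

Splitting the key gives independent noise for the two stages, so no MCMC is needed to obtain exact draws.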
log_normalization()
Log normalizing constant (already included in `__call__`).
Returns:
| Type | Description |
|---|---|
| Array | Scalar log(Z) where Z is the normalizing constant. |
Neal's funnel distribution in D dimensions.
The density factorizes as
p(x) = N(x0; 0, sigma^2) * prod_{i=1}^{D-1} N(x_i; 0, exp(x0))
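The first coordinate x0 controls the scale of all other coordinates: conditional on x0, each x_i has variance exp(x0). When x0 is large, the remaining coordinates spread out; when x0 is small (negative), they concentrate near zero. The resulting funnel is notoriously difficult for MCMC because the local scale changes by orders of magnitude between the wide mouth and the narrow neck, so no single step size works well everywhere.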
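A minimal sketch of the funnel log density follows, assuming the factorization above with variance exp(x0), i.e. standard deviation exp(x0 / 2), for each remaining coordinate. The function `funnel_log_prob` is hypothetical and is not claimed to match `NealFunnel.__call__` line for line.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def funnel_log_prob(x, sigma=3.0):
    """Hypothetical helper: Neal's funnel log density for x of shape (..., D)."""
    x0 = x[..., 0]
    rest = x[..., 1:]
    # x0 ~ N(0, sigma^2).
    log_p0 = norm.logpdf(x0, loc=0.0, scale=sigma)
    # Each x_i has variance exp(x0), i.e. standard deviation exp(x0 / 2).
    scale = jnp.exp(x0 / 2.0)[..., None]
    log_rest = norm.logpdf(rest, loc=0.0, scale=scale).sum(axis=-1)
    return log_p0 + log_rest


print(funnel_log_prob(jnp.zeros(10)))             # ≈ -10.288
print(funnel_log_prob(jnp.zeros((5, 10))).shape)  # (5,)
```

Each factor is a normalized Gaussian, which is why no extra normalizing constant is needed.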
Attributes:
| Name | Type | Description |
|---|---|---|
| dim | int | Dimensionality of the distribution. Must be >= 2. |
| sigma | float | Std dev of x0 (the "mouth width" of the funnel). Larger values create a wider range of scales, making sampling harder. Default: 3.0 (standard benchmark setting). |
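As a rough illustration of why larger sigma makes sampling harder: if x0 stays within about two standard deviations of zero, the conditional standard deviation exp(x0 / 2) of the other coordinates ranges from exp(-sigma) to exp(sigma), a ratio of exp(2 * sigma). The arithmetic below assumes the default sigma = 3.0 and is only a back-of-the-envelope estimate.

```python
import math

sigma = 3.0  # default benchmark setting
# Over x0 in [-2*sigma, 2*sigma], the conditional std exp(x0 / 2)
# of each remaining coordinate spans exp(-sigma) to exp(sigma).
narrow = math.exp(-sigma)
wide = math.exp(sigma)
print(f"{narrow:.4f} {wide:.2f} {wide / narrow:.1f}")  # 0.0498 20.09 403.4
```

A step size tuned for the neck (scale about 0.05) is therefore roughly 400 times too small for the mouth (scale about 20).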
Example
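```python
import jax
import jax.numpy as jnp

# NealFunnel is assumed to be imported from this package (import path not shown).
key = jax.random.PRNGKey(0)
dist = NealFunnel(dim=10, sigma=3.0)
x = jnp.zeros(10)
dist(x)                # log probability
dist.sample(key, 100)  # 100 samples, shape (100, 10)
jax.grad(dist)(x)      # gradient of the log probability
```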
dim = 10 (class-attribute, instance-attribute)
Dimensionality of the distribution.
__call__(x)
Evaluate log probability density.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input point(s) of shape (..., dim). | required |
Returns:
| Type | Description |
|---|---|
| Array | Log probability density of shape (...). |
sample(key, n)
Draw exact samples from the distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | JAX PRNG key. | required |
| n | int | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| Array | Samples of shape (n, dim). |
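Exact sampling again follows from the factorization: draw x0, then draw the remaining coordinates with standard deviation exp(x0 / 2). The helper `funnel_sample` below is a hypothetical sketch of that ancestral scheme, not necessarily the implementation behind `NealFunnel.sample`.

```python
import jax
import jax.numpy as jnp


def funnel_sample(key, n, dim=10, sigma=3.0):
    """Hypothetical helper: n exact samples of shape (n, dim) by ancestral sampling."""
    key0, key1 = jax.random.split(key)
    # Draw the controlling coordinate x0 ~ N(0, sigma^2), kept as a column.
    x0 = sigma * jax.random.normal(key0, (n, 1))
    # Given x0, draw the other dim - 1 coordinates with std exp(x0 / 2).
    rest = jnp.exp(x0 / 2.0) * jax.random.normal(key1, (n, dim - 1))
    return jnp.concatenate([x0, rest], axis=-1)


samples = funnel_sample(jax.random.PRNGKey(1), 10_000)
print(samples.shape)  # (10000, 10)
```

Keeping x0 as an (n, 1) column lets it broadcast across the remaining coordinates without an explicit loop.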
log_normalization()
Log normalizing constant (already included in `__call__`).
Returns:
| Type | Description |
|---|---|
| Array | Scalar 0.0 (distribution is normalized by construction). |