Illustration of several algorithms

This notebook illustrates a few algorithms on a simple one-dimensional regression task. We start by generating some synthetic data, then apply:

  • Several stochastic-gradient MCMC (SGMCMC) methods (SGLD, pSGLD, cSGLD, SGLD-CV, SGLD-SVRG, SGHMC, SGHMC-CV, SGHMC-SVRG);

  • Hamiltonian Monte Carlo;

  • Deep ensembles;

  • Stochastic Weight Averaging Gaussian (SWAG);

  • Monte Carlo dropout.

In these examples, we consider a homoscedastic regression problem in which the noise level is assumed to be known.

from time import time

import blackjax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np
from jax.flatten_util import ravel_pytree
from matplotlib.gridspec import GridSpec

from pbnn.deep_ensembles import deep_ensembles_fn
from pbnn.map_estimation import train_fn
from pbnn.mcdropout import mcdropout_fn
from pbnn.mcmc.hamiltonian import hmc, sghmc, sghmc_cv, sghmc_svrg
from pbnn.mcmc.langevin import (
    cyclical_sgld,
    pSGLD,
    sgld,
    sgld_cv,
    sgld_svrg,
)
from pbnn.swag import swag_fn
from pbnn.utils.analytical_functions import gramacy_function
from pbnn.utils.plot import plot_on_axis

%load_ext watermark

Generate data

# Synthetic 1D regression data on [0, 20]: `n` uniformly drawn training inputs
# and a regular grid of 200 test inputs. Noise samples with standard deviation
# `noise_level` are passed to `gramacy_function` (presumably added to the clean
# response — confirm in pbnn.utils.analytical_functions).
n = 100
noise_level = 0.1

np.random.seed(0)  # fixed seed so the notebook is reproducible
X = jnp.array(20 * np.random.rand(n, 1))
X_test = jnp.array(np.linspace(0, 20, 200).reshape(-1, 1))

# Training noise is drawn before test noise from the seeded NumPy stream.
noise = noise_level * np.random.randn(n, 1)
noise_test = noise_level * np.random.randn(len(X_test), 1)

y = gramacy_function(X, noise)
y_test = gramacy_function(X_test, noise_test)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Define the network, log-likelihood and log-prior

# Define the network, the log-likelihood and the log-prior.
class MLP(nn.Module):
    """Fully-connected network with two tanh hidden layers of 50 units and a
    single output unit.

    Every kernel and bias is initialized with ``nn.initializers.normal()``.
    """

    @nn.compact
    def __call__(self, x):
        # Hidden layers: Dense -> tanh, created in order.
        for width in (50, 50):
            x = nn.Dense(
                features=width,
                kernel_init=nn.initializers.normal(),
                bias_init=nn.initializers.normal(),
            )(x)
            x = nn.tanh(x)
        # Linear read-out layer with one output feature.
        return nn.Dense(
            features=1,
            kernel_init=nn.initializers.normal(),
            bias_init=nn.initializers.normal(),
        )(x)


network = MLP()


def loglikelihood_fn(parameters, data, sig_noise: float = noise_level):
    """Gaussian log-likelihood of ``data = (X, y)`` under the global ``network``.

    Returns ``-sum(0.5 * (y - f(X))**2 / sig_noise**2)``, i.e. the Gaussian
    log-density up to an additive constant independent of ``parameters``.
    ``sig_noise`` defaults to the known ``noise_level`` (bound at definition).
    """
    inputs, targets = data
    predictions = network.apply({"params": parameters}, inputs)
    residuals = targets - predictions
    return -jnp.sum(0.5 * residuals**2 / sig_noise**2)


def logprior_fn(parameters):
    """Log-density of an isotropic standard normal prior on all parameters.

    The parameter pytree is flattened into one vector and the standard normal
    log-pdf is summed over its entries.
    """
    flat_params = ravel_pytree(parameters)[0]
    return stats.norm.logpdf(flat_params).sum()

MAP estimation

# define the log-posterior function
def logposterior_estimator_fn(logprior_fn, loglikelihood_fn, data_size: int):
    """Build a minibatch estimator of the log-posterior.

    The returned ``logposterior_fn(parameters, data_batch)`` evaluates
    ``logprior_fn(parameters)`` plus ``data_size`` times the mean per-example
    log-likelihood over ``data_batch``. The minibatch average is rescaled to
    estimate the full-data log-likelihood. ``loglikelihood_fn`` is vmapped over
    the leading axis of every leaf of ``data_batch``.
    """
    per_example_loglik = jax.vmap(loglikelihood_fn, in_axes=(None, 0))

    def logposterior_fn(parameters, data_batch):
        log_prior = logprior_fn(parameters)
        mean_loglik = jnp.mean(per_example_loglik(parameters, data_batch), axis=0)
        return log_prior + data_size * mean_loglik

    return logposterior_fn


logposterior_fn = logposterior_estimator_fn(logprior_fn, loglikelihood_fn, len(X))
train_ds = dict(x=X, y=y)

# Train a first network; its parameters serve as centering positions for the
# control-variate methods (and as initial positions for some of them).
key = jr.PRNGKey(np.random.randint(low=0, high=12345))

# NOTE(review): the positional args 32, 10_000, 1e-2 of `train_fn` are
# presumably batch size, number of steps and learning rate — confirm in
# pbnn.map_estimation.
map_params = train_fn(logposterior_fn, network, train_ds, 32, 10_000, 1e-2, key)
map_pred_test_1 = network.apply({"params": map_params}, X_test)
mse = jnp.mean((map_pred_test_1 - y_test) ** 2)
print(f"Test MSE of the 1st MAP network: {mse}")

# Train a second network to get initial parameters for the SGMCMC algorithms.
_, key = jr.split(key)
init_params = train_fn(logposterior_fn, network, train_ds, 32, 10_000, 1e-2, key)
map_pred_test_2 = network.apply({"params": init_params}, X_test)

# Also generate random initial positions that will be used for some algorithms.
_, key = jr.split(key)
rng_init_positions = network.init(key, X_test[0])["params"]

# Sanity check: plot both MAP predictions against the training data.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(X_test, map_pred_test_1, ls="-", lw=2, color="k", label="1st MAP prediction")
ax.plot(X_test, map_pred_test_2, ls="-", lw=2, color="r", label="2nd MAP prediction")
ax.plot(
    X, y, ls="", marker="o", markerfacecolor="b", markeredgecolor="k", markeredgewidth=1
)
ax.set_xlabel(r"$x$", fontsize=14)
ax.set_ylabel(r"$y$", fontsize=14)
ax.legend(fontsize=12)
<matplotlib.legend.Legend at 0x7f8b34e672d0>
../../_images/9711ce9e6230f54e4bd3aad37bbd27a5ebf46ba6f82879a0bcef8a6e15c39954.png

MCMC algorithms

Define a helper function for SGMCMC methods

# global parameters
batch_size = 32


def sgmcmc_fn(algorithm, burnin, thin_freq, init_positions, rng_key, **kwargs):
    """Run an SGMCMC ``algorithm`` on the training data and sample noisy test predictions.

    Args:
        algorithm: sampler from ``pbnn.mcmc``; it is called with the training
            data ``X``/``y``, ``loglikelihood_fn``, ``logprior_fn``,
            ``init_positions``, ``batch_size``, an RNG key and ``**kwargs``.
            It must return ``(positions, ravel_fn, predict_fn)``.
        burnin: number of leading samples to discard.
        thin_freq: keep one sample out of every ``thin_freq`` after burn-in.
        init_positions: initial network parameters.
        rng_key: JAX PRNG key.
        **kwargs: algorithm-specific hyperparameters.

    Returns:
        ``(ravel_fn(positions), y_predictions)``. ``y_predictions`` are the
        network predictions on ``X_test`` plus i.i.d. Gaussian observation
        noise of standard deviation ``noise_level``. Their shape is presumably
        ``(num_kept_samples, len(X_test))`` after squeezing — confirm.
    """
    # One key for the sampler, an independent one for the observation noise.
    sampler_key, noise_key = jr.split(rng_key)
    positions, ravel_fn, predict_fn = algorithm(
        X=X,
        y=y,
        loglikelihood_fn=loglikelihood_fn,
        logprior_fn=logprior_fn,
        init_positions=init_positions,
        batch_size=batch_size,
        rng_key=sampler_key,
        **kwargs,
    )

    # Remove burn-in and thin along the leading (iteration) axis of every leaf.
    positions = jax.tree_util.tree_map(lambda leaf: leaf[burnin::thin_freq], positions)

    # Predict.
    f_predictions = predict_fn(network, positions, X_test).squeeze()

    # Add independent noise to every (sample, test point) prediction; matching
    # the prediction shape also avoids accidental broadcasting.
    y_predictions = f_predictions + noise_level * jr.normal(
        noise_key, shape=f_predictions.shape
    )

    return ravel_fn(positions), y_predictions

Set some hyperparameters

# (sampler, hyperparameters) pairs. Apart from the algorithm-specific keyword
# arguments, each dict holds "burnin", "thin_freq" and "init_positions". The
# run loop extracts these before the remaining entries reach the sampler.
algorithms = [
    (
        sgld,
        dict(
            step_size=1e-8,
            num_iterations=100_000,
            burnin=80_000,
            thin_freq=10,
            init_positions=init_params,
        ),
    ),
    (
        pSGLD,
        dict(
            step_size=5e-5,
            preconditioning_factor=0.95,
            num_iterations=100_000,
            burnin=80_000,
            thin_freq=10,
            init_positions=rng_init_positions,  # starts from random weights
        ),
    ),
    (
        cyclical_sgld,
        dict(
            step_size=1e-6,
            num_cycles=5,
            # NOTE(review): 100 SGD steps vs 20_000 SGLD steps per cycle —
            # confirm this ratio is intended.
            num_sgd_steps=100,
            num_sgld_steps=20_000,
            burnin_sgld=10_000,
            burnin=0,
            thin_freq=10,
            init_positions=init_params,
        ),
    ),
    (
        sgld_cv,
        dict(
            step_size=1e-7,
            num_iterations=10_000,
            burnin=5_000,
            thin_freq=10,
            centering_positions=map_params,
            init_positions=map_params,
        ),
    ),
    (
        sgld_svrg,
        dict(
            step_size=2e-6,
            num_cv_iterations=10,
            num_svrg_iterations=2000,
            burnin=10_000,
            thin_freq=10,
            centering_positions=map_params,
            init_positions=map_params,
        ),
    ),
    (
        sghmc,
        dict(
            step_size=5e-5,
            num_integration_steps=100,
            num_iterations=2_000,
            burnin=1_000,
            thin_freq=10,
            init_positions=init_params,
        ),
    ),
    (
        sghmc_cv,
        dict(
            step_size=5e-5,
            num_iterations=10_000,
            burnin=5_000,
            thin_freq=10,
            centering_positions=map_params,
            num_integration_steps=40,
            init_positions=map_params,
        ),
    ),
    (
        sghmc_svrg,
        dict(
            step_size=5e-5,
            num_cv_iterations=10,
            num_svrg_iterations=2000,
            burnin=10_000,
            thin_freq=10,
            centering_positions=map_params,
            num_integration_steps=40,
            init_positions=map_params,
        ),
    ),
]

Run SGMCMC methods

# Dict of predictions, one entry per method; the insertion order below also
# sets the panel order of the plot further down.
y_predictions = dict(
    sgld=None,
    pSGLD=None,
    sgld_cv=None,
    sgld_svrg=None,
    sghmc=None,
    sghmc_cv=None,
    sghmc_svrg=None,
    cyclical_sgld=None,
    hmc=None,
)

for algorithm, hparams in algorithms:
    _, key = jr.split(key)
    # Work on a copy so that `algorithms` is left intact and this cell can be
    # re-run (popping from the shared dicts would raise KeyError the 2nd time).
    kwargs = dict(hparams)
    burnin = kwargs.pop("burnin")
    thin_freq = kwargs.pop("thin_freq")
    init_pos = kwargs.pop("init_positions")
    t0 = time()
    positions, y_prediction = sgmcmc_fn(
        algorithm, burnin, thin_freq, init_pos, key, **kwargs
    )
    print(f"Elapsed time for {algorithm.__name__}: {time()-t0}")
    y_predictions[algorithm.__name__] = y_prediction
Elapsed time for sgld: 9.930523157119751
Elapsed time for pSGLD: 11.697723627090454
Elapsed time for cyclical_sgld: 12.502580881118774
Elapsed time for sgld_cv: 2.9192466735839844
Elapsed time for sgld_svrg: 4.193420648574829
Elapsed time for sghmc: 22.310141563415527
Elapsed time for sghmc_cv: 39.58222699165344
Elapsed time for sghmc_svrg: 82.1585681438446

Run HMC

def logprob_fn(parameters):
    """Full-batch unnormalized log-posterior over the whole training set ``(X, y)``.

    Sum of ``logprior_fn`` and the per-example ``loglikelihood_fn`` values;
    this is the target density handed to ``hmc``.
    """
    per_example = jax.vmap(loglikelihood_fn, (None, 0))(parameters, (X, y))
    return logprior_fn(parameters) + per_example.sum()


_, key = jr.split(key)

# Dimension of the flattened parameter vector. It sizes the all-ones inverse
# mass matrix (presumably interpreted as a diagonal, identity matrix). It is
# derived from `init_params` rather than from leftovers of a previous cell.
num_params = ravel_pytree(init_params)[0].size

t0 = time()
positions, ravel_fn, predict_fn = hmc(
    logprob_fn=logprob_fn,
    init_positions=init_params,
    num_samples=200,
    step_size=1e-4,
    inverse_mass_matrix=jnp.ones(num_params),
    num_integration_steps=40,
    rng_key=key,
)
print(f"Elapsed time for HMC: {time()-t0}")

# predict
f_predictions = predict_fn(network, positions, X_test).squeeze()

# Generate the noisy predictions: independent noise per (sample, test point).
_, key = jr.split(key)
hmc_predictions = f_predictions + noise_level * jr.normal(
    key, shape=f_predictions.shape
)
y_predictions["hmc"] = hmc_predictions
Elapsed time for HMC: 2.103886842727661

Plot the prediction intervals

fig = plt.figure(constrained_layout=True, figsize=(3 * 5, 3 * 4))
gs = GridSpec(nrows=3, ncols=3, figure=fig)
alpha = 0.05  # central (1 - alpha) prediction intervals

# One panel per method: the median of the noisy predictive samples and their
# empirical alpha/2 and 1 - alpha/2 quantiles at every test input.
lower_level, upper_level = 0.5 * alpha, 1 - 0.5 * alpha
for i, (name, y_pred) in enumerate(y_predictions.items()):
    ax = fig.add_subplot(gs[i])
    mean_prediction = jnp.median(y_pred, axis=0)
    qlow = jnp.quantile(y_pred, lower_level, axis=0)
    qhigh = jnp.quantile(y_pred, upper_level, axis=0)
    plot_on_axis(ax, X_test, y_test, mean_prediction, qlow, qhigh, title=f"{name}")
../../_images/fc5602f1fada3544fdf7be084671dafe096773768340b36184bf356cc3a081b7.png

Deep ensembles

key = jr.PRNGKey(np.random.randint(low=0, high=12345))

# NOTE(review): positional args after `batch_size` are presumably number of
# training steps, learning rate and ensemble size — confirm in
# pbnn.deep_ensembles.
positions, ravel_fn, predict_fn = deep_ensembles_fn(
    X, y, loglikelihood_fn, logprior_fn, network, batch_size, 10_000, 1e-2, 10, key
)

f_predictions = predict_fn(network, positions, X_test).squeeze()

# Add independent observation noise to every (member, test point) prediction.
key = jr.PRNGKey(np.random.randint(low=0, high=12345))
y_predictions = f_predictions + noise_level * jr.normal(
    key, shape=f_predictions.shape
)

fig = plt.figure(constrained_layout=True, figsize=(1 * 5, 1 * 4))
gs = GridSpec(nrows=1, ncols=1, figure=fig)

# Defined here too (as in the other single-method cells) so this cell does not
# depend on a variable set in an earlier cell.
alpha = 0.05
mean_prediction = jnp.median(y_predictions, axis=0)
qlow = jnp.quantile(y_predictions, 0.5 * alpha, axis=0)
qhigh = jnp.quantile(y_predictions, (1 - 0.5 * alpha), axis=0)

ax = fig.add_subplot(gs[0])
plot_on_axis(ax, X_test, y_test, mean_prediction, qlow, qhigh, title="Deep Ens.")
../../_images/ab0d9c1cdd6cf6e84e61d71b8c064c6a4fac7a4273d17f7847aa84d2df49d906.png

SWAG

key = jr.PRNGKey(np.random.randint(low=0, high=12345))

# NOTE(review): the trailing positional args (1000, 1e-5, 20) of `swag_fn` are
# not self-describing — confirm their meaning in pbnn.swag.
positions, ravel_fn, predict_fn = swag_fn(
    X,
    y,
    loglikelihood_fn,
    logprior_fn,
    network,
    map_params,
    batch_size,
    1000,
    1e-5,
    20,
    key,
)

# Squeezed like the predictions of the other methods.
f_predictions = predict_fn(network, positions, X_test).squeeze()

# Add independent observation noise to every (sample, test point) prediction.
key = jr.PRNGKey(np.random.randint(low=0, high=12345))
y_predictions = f_predictions + noise_level * jr.normal(
    key, shape=f_predictions.shape
)

fig = plt.figure(constrained_layout=True, figsize=(1 * 5, 1 * 4))
gs = GridSpec(nrows=1, ncols=1, figure=fig)

alpha = 0.05
mean_prediction = jnp.median(y_predictions, axis=0)
qlow = jnp.quantile(y_predictions, 0.5 * alpha, axis=0)
qhigh = jnp.quantile(y_predictions, (1 - 0.5 * alpha), axis=0)

ax = fig.add_subplot(gs[0])
plot_on_axis(ax, X_test, y_test, mean_prediction, qlow, qhigh, title="SWAG")
../../_images/162128484fcf0cccf44bbb87751a80331a295d9b4db67cbf77e90f6c8b61192a.png

Monte Carlo dropout

dropout_rate = 0.05


class network_dropout(nn.Module):
    """MLP with two tanh hidden layers of 50 units and a single output unit.

    After each hidden Dense layer (before the tanh), dropout with rate
    ``dropout_rate`` is applied. With ``deterministic=False`` (the default),
    dropout is active and needs a ``"dropout"`` RNG at ``apply`` time. Every
    kernel and bias is initialized with ``nn.initializers.normal()``.
    """

    @nn.compact
    def __call__(self, x, deterministic=False):
        # Hidden layers: Dense -> Dropout -> tanh, created in order.
        for width in (50, 50):
            x = nn.Dense(
                features=width,
                kernel_init=nn.initializers.normal(),
                bias_init=nn.initializers.normal(),
            )(x)
            x = nn.Dropout(dropout_rate, deterministic=deterministic)(x)
            x = nn.tanh(x)
        # Linear read-out layer with one output feature.
        return nn.Dense(
            features=1,
            kernel_init=nn.initializers.normal(),
            bias_init=nn.initializers.normal(),
        )(x)


def loglikelihood_fn_dropout(
    parameters, data, dropout_rng, sig_noise: float = noise_level
):
    """Gaussian log-likelihood of ``data = (X, y)`` under ``network_dropout``.

    The network is applied with its default ``deterministic=False`` (dropout
    active), using ``dropout_rng`` as the ``"dropout"`` RNG. Returns
    ``-sum(0.5 * (y - f(X))**2 / sig_noise**2)``, the Gaussian log-density up to
    a constant independent of ``parameters``.
    """
    inputs, targets = data
    predictions = network_dropout().apply(
        {"params": parameters}, inputs, rngs={"dropout": dropout_rng}
    )
    residuals = targets - predictions
    return -jnp.sum(0.5 * residuals**2 / sig_noise**2)


key = jr.PRNGKey(np.random.randint(low=0, high=12345))

# NOTE(review): positional args 10_000 and 1e-2 are presumably number of
# training steps and learning rate — confirm in pbnn.mcdropout.
positions, ravel_fn, predict_fn = mcdropout_fn(
    X,
    y,
    loglikelihood_fn_dropout,
    logprior_fn,
    network_dropout,
    batch_size,
    10_000,
    1e-2,
    key,
)

# 100 stochastic forward passes, each with its own dropout key. The observation
# noise gets a separate key so that no key is both consumed and split again.
key = jr.PRNGKey(np.random.randint(low=0, high=12345))
dropout_key, noise_key = jr.split(key)
dropout_keys = jr.split(dropout_key, 100)

f_predictions = jnp.stack(
    [predict_fn(network_dropout, positions, X_test, k) for k in dropout_keys]
).squeeze()

# Add independent observation noise to every (pass, test point) prediction.
y_predictions = f_predictions + noise_level * jr.normal(
    noise_key, shape=f_predictions.shape
)

fig = plt.figure(constrained_layout=True, figsize=(1 * 5, 1 * 4))
gs = GridSpec(nrows=1, ncols=1, figure=fig)

alpha = 0.05
mean_prediction = jnp.median(y_predictions, axis=0)
qlow = jnp.quantile(y_predictions, 0.5 * alpha, axis=0)
qhigh = jnp.quantile(y_predictions, (1 - 0.5 * alpha), axis=0)

ax = fig.add_subplot(gs[0])
plot_on_axis(ax, X_test, y_test, mean_prediction, qlow, qhigh, title="MC Dropout")
../../_images/270c095cdf324adaaa8c533109f07c450e84120e51eef03a2a6e8764b11ef6ab.png
%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Brian Staber'
Author: Brian Staber

Last updated: Mon Feb 26 2024

Python implementation: CPython
Python version       : 3.11.7
IPython version      : 8.21.0

blackjax  : 1.1.0
numpy     : 1.26.3
matplotlib: 3.8.2
flax      : 0.8.0
jax       : 0.4.23

Watermark: 2.4.3