Stochastic gradient Langevin dynamics

import blackjax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np
from jax.flatten_util import ravel_pytree
from matplotlib.gridspec import GridSpec
from sklearn.preprocessing import StandardScaler

from pbnn.mcmc.langevin import sgld
from pbnn.utils.plot import plot_on_axis

%load_ext watermark

Generate data

def test_function(x):
    return np.sin(2.0 * x) + np.cos(x)


n = 100
noise_level = 0.1

np.random.seed(0)
X = 20 * np.random.rand(n, 1)
X_test = np.linspace(0, 20, 200)[:, None]

noise, noise_test = (
    np.random.randn(n, 1) * noise_level,
    np.random.randn(len(X_test), 1) * noise_level,
)

y = test_function(X) + noise
y_test = test_function(X_test) + noise_test

scaler = StandardScaler()
X_ = scaler.fit_transform(X)
X_test_ = scaler.transform(X_test)

X_, y = jnp.array(X_), jnp.array(y)
X_test_, y_test = jnp.array(X_test_), jnp.array(y_test)

Define the network, loglikelihood and logprior

# define loglikelihood et logprior
class MLP(nn.Module):
    """Simple MLP."""

    hidden_features: int
    out_features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            features=self.hidden_features,
            kernel_init=nn.initializers.normal(),
            bias_init=nn.initializers.normal(),
        )(x)
        x = nn.tanh(x)
        x = nn.Dense(
            features=self.hidden_features,
            kernel_init=nn.initializers.normal(),
            bias_init=nn.initializers.normal(),
        )(x)
        x = nn.tanh(x)
        x = nn.Dense(
            features=self.out_features,
            kernel_init=nn.initializers.normal(),
            bias_init=nn.initializers.normal(),
        )(x)
        return x


network = MLP(hidden_features=50, out_features=1)


def loglikelihood_fn(parameters, data, sig_noise: float = noise_level):
    """Gaussian log-likelihood"""
    X, y = data
    return -jnp.sum(
        0.5 * (y - network.apply({"params": parameters}, X)) ** 2 / sig_noise**2
    )


def logprior_fn(parameters):
    """Compute the value of the log-prior density function."""
    flat_params, _ = ravel_pytree(parameters)
    return jnp.sum(stats.norm.logpdf(flat_params))

Run the SGLD algorithm

# set some hyperparameters
step_size = 1e-5
num_iterations = 100_000
burnin = 80_000
batch_size = 32

# generate initial positions
rng_key = jr.PRNGKey(0)
keys = jr.split(rng_key)

init_positions = network.init(keys[0], X_)["params"]

# run SGLD
positions, ravel_fn, predict_fn = sgld(
    X=X_,
    y=y,
    loglikelihood_fn=loglikelihood_fn,
    logprior_fn=logprior_fn,
    init_positions=init_positions,
    batch_size=batch_size,
    step_size=step_size,
    num_iterations=num_iterations,
    rng_key=keys[1],
)

# remove burnin and thin manually
positions = jax.tree_util.tree_map(lambda xx: xx[burnin::100], positions)

# predict
f_predictions = predict_fn(network, positions, X_test_).squeeze()

# generate the noisy predictions
_, key = jr.split(keys[1])
y_predictions = f_predictions + noise_level * jr.normal(
    key, shape=(len(f_predictions), 1)
)

Plot the prediction intervals

fig = plt.figure(constrained_layout=True, figsize=(5, 4))
gs = GridSpec(nrows=1, ncols=1, figure=fig)
alpha = 0.05

mean_prediction = jnp.median(y_predictions, axis=0)
qlow = jnp.quantile(y_predictions, 0.5 * alpha, axis=0)
qhigh = jnp.quantile(y_predictions, (1 - 0.5 * alpha), axis=0)

ax = fig.add_subplot(gs[0])
plot_on_axis(ax, X_test, y_test, mean_prediction, qlow, qhigh, title="sgld")
../../_images/fae9b2f7faf97bf291d2143f51fa4c7707a8fb9c196efb6bf47bf0add515b4df.png
%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Brian Staber'
Author: Brian Staber

Last updated: Wed Mar 13 2024

Python implementation: CPython
Python version       : 3.11.7
IPython version      : 8.21.0

jax       : 0.4.23
numpy     : 1.26.3
blackjax  : 1.1.0
matplotlib: 3.8.2
flax      : 0.8.0

Watermark: 2.4.3