Conformal prediction with Flax and MAPIE

import warnings

import jax.numpy as jnp
import jax.random as jr
import mapie
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import Array
from mapie.regression import MapieRegressor
from matplotlib.gridspec import GridSpec
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

from pbnn.map_estimation import train_fn
from pbnn.models import MLP
from pbnn.utils.analytical_functions import gramacy_function
from pbnn.utils.plot import plot_on_axis

warnings.filterwarnings('ignore')

%load_ext watermark

Generate data

n = 100
noise_level = 0.1

np.random.seed(0)
X = 20 * np.random.rand(n, 1)
X_test = np.linspace(0, 20, 200)[:, None]
X, X_test = jnp.array(X), jnp.array(X_test)

noise, noise_test = (
    np.random.randn(n, 1) * noise_level,
    np.random.randn(len(X_test), 1) * noise_level,
)

y = gramacy_function(X, noise)
y_test = gramacy_function(X_test, noise_test)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Define a neural network using Flax

Let’s define a simple MLP network using pbnn.models. In order to be able to use MAPIE, we also need to define scikit-learn-like regressors that has fit and predict methods.

Two regressors are implemented below: a simple MLP regressor trained with a MSE loss, and a MLP quantile regressor trained with the pinball loss.

class MLPRegressor(RegressorMixin, BaseEstimator):
    def __init__(
        self,
        hidden_features: int,
        output_features: int,
        batch_size: int = 32,
        num_epochs: int = 1000,
        lr: float = 5e-3,
        rng_key: Array = jr.PRNGKey(0),
    ):
        self.model = MLP(hidden_features=hidden_features, out_features=output_features)
        self.hidden_features = hidden_features
        self.output_features = output_features
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.lr = lr
        self.rng_key = rng_key

    def fit(self, X, y):
        X, y = jnp.array(X), jnp.array(y)

        def loss_fn(params, batch_of_data):
            X_batch, y_batch = batch_of_data
            predictions = self.model.apply({"params": params}, X_batch).squeeze()
            return -optax.losses.l2_loss(predictions, y_batch).mean()

        self.params = train_fn(
            logposterior_fn=loss_fn,
            network=self.model,
            train_ds={"x": X, "y": y},
            batch_size=self.batch_size,
            num_epochs=self.num_epochs,
            learning_rate=self.lr,
            rng_key=self.rng_key,
        )

        _, self.rng_key = jr.split(self.rng_key)

        return self

    def predict(self, X):
        X = jnp.array(X)
        return self.model.apply({"params": self.params}, X).squeeze()


class QuantileMLPRegressor(RegressorMixin, BaseEstimator):
    def __init__(
        self,
        hidden_features: int,
        output_features: int,
        alpha: float,
        batch_size: int = 32,
        num_epochs: int = 1000,
        lr: float = 5e-3,
        rng_key: Array = jr.PRNGKey(0),
    ):
        self.model = MLP(hidden_features=hidden_features, out_features=output_features)
        self.hidden_features = hidden_features
        self.output_features = output_features
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.lr = lr
        self.alpha = alpha
        self.alpha_ = [alpha, 0.5, 1 - alpha]
        self.rng_key = rng_key

        assert output_features == len(self.alpha_)

    def fit(self, X, y):
        X, y = jnp.array(X), jnp.array(y)

        def loss_fn(params, batch_of_data):
            X, target = batch_of_data
            preds = self.model.apply({"params": params}, X)

            losses = []
            preds_reshape = preds.reshape((-1, target.shape[1], len(self.alpha_)))
            for i, q in enumerate(self.alpha_):
                errors = target - preds_reshape[..., i]
                losses.append(jnp.maximum((q - 1) * errors, q * errors))

            loss = jnp.mean(jnp.sum(jnp.concatenate(losses, axis=-1), axis=-1))
            return -loss

        self.params = train_fn(
            logposterior_fn=loss_fn,
            network=self.model,
            train_ds={"x": X, "y": y},
            batch_size=self.batch_size,
            num_epochs=self.num_epochs,
            learning_rate=self.lr,
            rng_key=self.rng_key,
        )

        _, self.rng_key = jr.split(self.rng_key)

        return self

    def predict(self, X):
        X = jnp.array(X)
        return self.model.apply({"params": self.params}, X).squeeze()

Custom implementation of split CQR

Altough split conformalized quantile regression is available in MAPIE, it can’t be used with an arbitrary quantile regressor. However, it can easily be implemented as follows.

from typing import Union

from sklearn.base import clone


class SplitCQR(object):
    def __init__(
        self,
        regressor,
        X_train: Union[list, np.ndarray],
        y_train: np.ndarray,
        X_cal: Union[list, np.ndarray],
        y_cal: np.ndarray,
        alpha: float,
    ) -> object:
        self.alpha = alpha

        self.X_train = X_train
        self.y_train = y_train

        self.X_cal = X_cal
        self.y_cal = y_cal

        self.output_dim = y_train.shape[1]
        self.regressor = clone(regressor)
        self.__class__.__name__ = "split_cqr"

    def fit(self):
        self.regressor.fit(self.X_train, self.y_train)
        pred_cal = self.regressor.predict(self.X_cal).reshape(-1, self.output_dim, 3)

        r_lo = pred_cal[:, :, 0] - self.y_cal
        r_hi = self.y_cal - pred_cal[:, :, 2]
        self.residuals = np.max(np.stack([r_lo, r_hi], axis=2), axis=2)
        return self

    def predict(self, X_test: np.ndarray):
        n = len(self.X_cal)

        pred_test = self.regressor.predict(X_test)
        q_med, q_lo, q_hi = pred_test[:, 1::3], pred_test[:, 0::3], pred_test[:, 2::3]

        qleft = np.empty((len(X_test), self.output_dim))
        qright = np.empty((len(X_test), self.output_dim))
        for k in range(self.output_dim):
            q = np.quantile(
                self.residuals[:, k],
                (1 - self.alpha) / (1.0 + 1.0 / n),
                axis=0,
                method="higher",
            )
            qleft[:, k] = q_lo[:, k] - q
            qright[:, k] = q_hi[:, k] + q
        return q_med, qleft, qright

Apply conformal prediction using MAPIE

alpha = 0.1

algorithms = {
    "split_cv": dict(method="base", cv="split", test_size=0.1),
    "cv_plus": dict(method="plus", cv=20),
    "split_cqr": dict(test_size=0.2, alpha=alpha),
}

mlp_regressor_params = {
    "hidden_features": 100,
    "output_features": 1,
    "num_epochs": 1_000,
    "lr": 5e-3,
}

quantile_regressor_params = {
    "hidden_features": 100,
    "output_features": 3,
    "num_epochs": 1_000,
    "lr": 5e-3,
    "alpha": alpha,
}

fig = plt.figure(constrained_layout=True, figsize=(3 * 5, 1 * 4))
gs = GridSpec(nrows=1, ncols=3, figure=fig)

for i, (algorithm, params) in enumerate(algorithms.items()):
    if algorithm == "split_cqr":
        regressor = QuantileMLPRegressor(**quantile_regressor_params)
        X_train, X_cal, y_train, y_cal = train_test_split(
            X, y, test_size=params["test_size"]
        )
        model = SplitCQR(
            regressor, X_train, y_train, X_cal, y_cal, alpha=params["alpha"]
        )
        model.fit()
        y_pred, qlow, qhigh = model.predict(X_test)
    else:
        regressor = MLPRegressor(**mlp_regressor_params)
        mapie = MapieRegressor(
            regressor,
            **params,
        )
        mapie.fit(X, y.ravel())
        y_pred, y_pis = mapie.predict(X_test, alpha=alpha)
        qlow, qhigh = y_pis[:, 0, 0], y_pis[:, 1, 0]

    q2_score = r2_score(y_test, y_pred)
    avg_width = np.mean(qhigh - qlow)

    ax = fig.add_subplot(gs[i])
    plot_on_axis(
        ax,
        X_test,
        y_test,
        y_pred.ravel(),
        qlow.ravel(),
        qhigh.ravel(),
        title=f"{algorithm}, q2: {q2_score:.2f}, avg_width: {avg_width:.2f}",
    )
../../_images/57942ea3b4cdb8b725f535292872731ba5d44f1f075a11d5fb16ff3dd2cfa7e6.png
%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Brian Staber'
Author: Brian Staber

Last updated: Thu Mar 28 2024

Python implementation: CPython
Python version       : 3.11.7
IPython version      : 8.21.0

numpy     : 1.26.3
optax     : 0.1.9
matplotlib: 3.8.2
jax       : 0.4.23

Watermark: 2.4.3