Training a network: Flax vs PyTorch

import flax.linen as jnn
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import optax
import torch
import torch.nn as nn
import torch.optim as optim
from flax.training import train_state
from jax.flatten_util import ravel_pytree
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

from pbnn.utils.data import NumpyDataset, NumpyLoader

%load_ext watermark

Generate data

def test_function(x):
    return np.sin(2.0 * x) + np.cos(x)


# Size of the training set and standard deviation of the additive noise.
n = 100
noise_level = 0.1

# Fix the global NumPy seed so the data set is reproducible.
np.random.seed(0)

# Training inputs are drawn uniformly in [0, 20); test inputs form a regular grid.
X = 20 * np.random.rand(n, 1)
X_test = np.linspace(0, 20, 200).reshape(-1, 1)

# Draw the training noise first, then the test noise (the order matters for
# reproducibility because both come from the same global random stream).
noise = noise_level * np.random.randn(n, 1)
noise_test = noise_level * np.random.randn(*X_test.shape)

# Noisy observations of the target function.
y = test_function(X) + noise
y_test = test_function(X_test) + noise_test

# Standardise the inputs using statistics computed on the training set only.
scaler = StandardScaler().fit(X)
X_ = scaler.transform(X)
X_test_ = scaler.transform(X_test)

Set some hyperparameters

# Number of full passes over the training set.
num_epochs = 1000
# Learning rate handed to the Adam optimizer of both frameworks.
step_size = 0.005
# Number of samples per mini-batch.
batch_size = 32

Training with Flax

Define network

class MLP(jnn.Module):
    """Fully connected network: two tanh hidden layers of 50 units, one output."""

    @jnn.compact
    def __call__(self, x):
        # Two identical hidden blocks; Flax names the layers automatically
        # in creation order (Dense_0, Dense_1, ...).
        for _ in range(2):
            x = jnn.tanh(jnn.Dense(features=50)(x))
        # Linear read-out layer producing a single scalar per input row.
        return jnn.Dense(features=1)(x)


network = MLP()

Define loss function

def loss_fn(params, batch):
    """Return the mean squared error of the network on one mini-batch.

    Parameters
    ----------
    params
        Parameter pytree of ``network`` (the value stored under ``"params"``).
    batch
        Pair ``(X, y)`` holding the inputs and the matching targets.

    Returns
    -------
    Scalar array with the mean of ``(pred - y) ** 2`` over the batch.
    """
    X, y = batch
    pred = network.apply({"params": params}, X)
    # Plain squared error without the 0.5 factor applied by ``optax.l2_loss``,
    # so that this objective is the same one ``torch.nn.MSELoss`` computes in
    # the PyTorch training loop of this file.
    return jnp.mean(jnp.square(pred - y))

Training pipeline

@jax.jit
def train_step(state, batch):
    """Perform one optimizer update and return the resulting train state.

    The gradient of ``loss_fn`` with respect to ``state.params`` is evaluated
    on ``batch`` and passed to ``state.apply_gradients``.
    """
    gradients = jax.grad(loss_fn)(state.params, batch)
    return state.apply_gradients(grads=gradients)


# Derive a dedicated key for the parameter initialisation from the root key.
rng_key, init_rng = jr.split(jr.PRNGKey(0))

# Initial parameters; they are reused further down to initialise the PyTorch model.
init_params = network.init(init_rng, X_)["params"]

# Adam optimizer bundled with the parameters into a train state.
tx = optax.adam(learning_rate=step_size)
state = train_state.TrainState.create(
    apply_fn=network.apply,
    params=init_params,
    tx=tx,
)

# Shuffled mini-batches; the last incomplete batch of each epoch is dropped.
dataset = NumpyDataset(X_, y)
data_loader = NumpyLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# One optimizer update per mini-batch, repeated for every epoch.
for epoch in range(num_epochs):
    for batch in data_loader:
        state = train_step(state, batch)

# Predictions of the trained network on the standardised test grid.
pred_flax = state.apply_fn({"params": state.params}, X_test_)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Training with PyTorch

Get initial parameters

Use the same initial parameters as the Flax model, so that both networks start from an identical state.

# Flat copy of the Flax initial parameters, used below as a size sanity check.
parameters, _ = ravel_pytree(init_params)
parameters = torch.tensor(np.array(parameters), dtype=torch.float32)

model = nn.Sequential(
    nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 50), nn.Tanh(), nn.Linear(50, 1)
)

# Both networks must hold the same number of scalars before anything is copied.
if sum(p.numel() for p in model.parameters()) != parameters.numel():
    raise ValueError("Flax and PyTorch models have different parameter counts.")

# Copy the Flax initialisation into the PyTorch model layer by layer.
# A flat copy of the ravelled vector cannot be used here:
#   * ``ravel_pytree`` flattens each layer's dict with sorted keys (bias, then
#     kernel), whereas ``model.parameters()`` yields weight, then bias;
#   * a Flax ``Dense`` kernel has shape (in, out) while ``nn.Linear.weight`` has
#     shape (out, in), so the kernel has to be transposed.
# The i-th ``nn.Linear`` corresponds to the Flax sub-module named ``Dense_{i}``.
linear_layers = [layer for layer in model if isinstance(layer, nn.Linear)]
with torch.no_grad():
    for index, layer in enumerate(linear_layers):
        dense = init_params[f"Dense_{index}"]
        layer.weight.copy_(
            torch.tensor(np.array(dense["kernel"]).T, dtype=torch.float32)
        )
        layer.bias.copy_(torch.tensor(np.array(dense["bias"]), dtype=torch.float32))

Training pipeline

# Mean squared error, averaged over the mini-batch.
mse_fn = nn.MSELoss()

# Shuffled mini-batches; the last incomplete batch of each epoch is dropped.
train_loader = DataLoader(
    dataset=TensorDataset(
        torch.tensor(X_, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)
    ),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

optimizer = optim.Adam(model.parameters(), lr=step_size)
for epoch in range(num_epochs):
    for x_batch, y_batch in train_loader:
        # Gradients accumulate by default, so reset them before each update.
        optimizer.zero_grad()
        pred = model(x_batch)
        # ``MSELoss`` expects its arguments in (input, target) order.
        loss = mse_fn(pred, y_batch)
        loss.backward()
        optimizer.step()

# Predictions on the standardised test grid; no graph is recorded inside
# ``no_grad``, so the output can be converted to NumPy directly.
with torch.no_grad():
    pred_torch = model(torch.tensor(X_test_, dtype=torch.float32)).cpu().numpy()

Compare predictions

# Marker appearance of the training points (no connecting line).
train_marker_style = dict(
    ls="",
    marker="o",
    markerfacecolor="b",
    markeredgecolor="k",
    markeredgewidth=1,
)

plt.figure()
plt.plot(X, y, label="Train data", **train_marker_style)

# Overlay the prediction curve of each framework on the test grid.
for prediction, colour, framework in (
    (pred_flax, "r", "Flax"),
    (pred_torch, "b", "PyTorch"),
):
    plt.plot(X_test, prediction, color=colour, lw=2, label=framework)

plt.legend(fontsize=12)
plt.xlabel(r"$x$", fontsize=12)
plt.ylabel(r"$y$", fontsize=12)
Text(0, 0.5, '$y$')
../../_images/b6fd75170a0a6c8e082a2269cc50210408b9c820eca91a3e776f97217c90ed5d.png
%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Brian Staber'
Author: Brian Staber

Last updated: Sat Mar 16 2024

Python implementation: CPython
Python version       : 3.11.7
IPython version      : 8.21.0

numpy     : 1.26.3
jax       : 0.4.23
flax      : 0.8.0
optax     : 0.1.9
matplotlib: 3.8.2
torch     : 2.2.1

Watermark: 2.4.3