Training LeNet with Constrained Convolution Kernels Using JAX and Flax

The following code illustrates how to train LeNet with orthogonally constrained convolution kernels using JAX and Flax.


We first install the required packages. If you run this example on Google Colab, you need to reinstall CDOpt and Flax every time you run this notebook.

! pip install cdopt --upgrade
!pip install -q ml-collections git+

Import essential modules

import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

import cdopt
from cdopt.manifold_jax import sphere_jax, stiefel_jax, euclidean_jax
from cdopt.linen import Dense_cdopt, Conv_cdopt

Create the neural network

class CNN(nn.Module):
  """A simple LeNet-style CNN with a CDOpt-constrained first convolution.

  Calling the module returns ``(logits, penalty)``:
    * ``logits`` -- output of the final ``nn.Dense(features=10)`` layer.
    * ``penalty`` -- sum of the quadratic penalty terms returned by every
      CDOpt layer in the network (``Conv_cdopt`` and ``Dense_cdopt``).
  """

  # Layers are constructed inline inside __call__, which Flax only allows in
  # a method decorated with nn.compact.
  @nn.compact
  def __call__(self, x):
    # First convolution: kernel constrained via CDOpt with the Stiefel
    # manifold (orthogonality constraint); also returns its penalty term.
    x, quad_penalty0 = Conv_cdopt(features=32, kernel_size=(3, 3), manifold_class = stiefel_jax)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x, quad_penalty1 = Dense_cdopt(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    # Report the penalty of every CDOpt layer so that both the training
    # objective and the `feas` metric account for all constrained weights.
    return x, quad_penalty0 + quad_penalty1

Define essential components

# Define the metric as cross entropy loss
def cross_entropy_loss(*, logits, labels):
  """Return the batch-mean softmax cross-entropy.

  Args:
    logits: unnormalized class scores.
    labels: integer class indices; they are one-hot encoded over 10 classes
      before being compared with ``logits``.
  """
  one_hot_targets = jax.nn.one_hot(labels, num_classes=10)
  per_example_loss = optax.softmax_cross_entropy(
      logits=logits, labels=one_hot_targets)
  return per_example_loss.mean()

def compute_metrics(*, logits, labels, feas=0):
  """Summarize predictions on one batch.

  Args:
    logits: unnormalized class scores; the last axis indexes classes.
    labels: integer class labels, one per example.
    feas: constraint-violation (penalty) value to report alongside the other
      metrics; defaults to 0 when the caller has none.

  Returns:
    Dict with keys ``'loss'`` (mean cross-entropy), ``'accuracy'`` (fraction
    of examples whose argmax prediction equals the label) and ``'feas'``.
  """
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
      'feas': feas,
  }
  return metrics
# Set training process
def create_train_state(rng, learning_rate, momentum):
  """Initialize the CNN and bundle it with an SGD optimizer in a `TrainState`.

  Args:
    rng: PRNG key used for parameter initialization.
    learning_rate: SGD step size.
    momentum: SGD momentum coefficient.

  Returns:
    A fresh ``train_state.TrainState`` whose ``apply_fn`` is the model's
    ``apply`` method.
  """
  model = CNN()
  # A batch containing a single 28x28x1 dummy image is enough for `init`
  # to infer every parameter shape.
  dummy_batch = jnp.ones([1, 28, 28, 1])
  variables = model.init(rng, dummy_batch)
  optimizer = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optimizer,
  )

@jax.jit
def train_step(state, batch):
  """Run one optimization step on a single mini-batch.

  The function is jit-compiled: it is called once per mini-batch, and
  compiling it avoids dispatching every operation individually.

  Args:
    state: ``TrainState`` holding the parameters, optimizer state and
      ``apply_fn``.
    batch: dict with ``'image'`` and ``'label'`` arrays for one mini-batch.

  Returns:
    ``(new_state, metrics)`` where ``metrics`` is the dict produced by
    ``compute_metrics`` (loss, accuracy and the model's constraint penalty
    under ``'feas'``).
  """
  def loss_fn(params):
    logits, quad_penalty = state.apply_fn({'params': params}, batch['image'])
    # Penalized objective: task loss plus 0.05 times the constraint penalty
    # returned by the model.
    loss = cross_entropy_loss(logits=logits, labels=batch['label']) + 0.05 * quad_penalty
    return loss, (logits, quad_penalty)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, quad_penalty)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  # Report the penalty too, matching what eval_step reports for the test set.
  metrics = compute_metrics(logits=logits, labels=batch['label'], feas=quad_penalty)
  return state, metrics

def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for one epoch over shuffled mini-batches and print the averages.

  Args:
    state: current ``TrainState``.
    train_ds: dict of arrays (at least ``'image'`` and ``'label'``) indexed
      along a shared leading dimension.
    batch_size: number of examples per mini-batch; the trailing incomplete
      batch is dropped.
    epoch: epoch number, used only in the printed log line.
    rng: PRNG key used to shuffle the examples.

  Returns:
    The updated ``TrainState``.

  Raises:
    ValueError: if ``batch_size`` is not positive or exceeds the number of
      training examples (no full batch could be formed, so there would be
      no metrics to average).
  """
  train_ds_size = len(train_ds['image'])
  if not 0 < batch_size <= train_ds_size:
    raise ValueError(
        f'batch_size must be in [1, {train_ds_size}], got {batch_size}')
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)  # collected for the epoch-level average

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([m[k] for m in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state
# Set testing process
def eval_step(params, batch):
  """Score ``params`` on ``batch``, including the model's constraint penalty.

  Returns the ``compute_metrics`` dict, with the penalty reported under
  ``'feas'``.
  """
  model = CNN()
  logits, penalty = model.apply({'params': params}, batch['image'])
  return compute_metrics(
      logits=logits,
      labels=batch['label'],
      feas=penalty,
  )

def eval_model(params, test_ds):
  """Evaluate ``params`` on the whole test set as a single batch.

  Args:
    params: model parameters (``state.params``).
    test_ds: dict with ``'image'`` and ``'label'`` arrays.

  Returns:
    Tuple ``(loss, accuracy, feas)`` of Python scalars.
  """
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  # jax.tree_map is a deprecated alias that newer JAX releases no longer
  # provide; jax.tree_util.tree_map is the supported spelling. np.asarray
  # lets .item() also work on plain Python scalars (e.g. a default feas=0).
  summary = jax.tree_util.tree_map(lambda x: np.asarray(x).item(), metrics)
  return summary['loss'], summary['accuracy'], summary['feas']
# Import dataset
def get_datasets():
  """Load the MNIST train and test splits into memory.

  Returns:
    ``(train_ds, test_ds)``: each is a dict of NumPy arrays as produced by
    ``tfds.as_numpy(..., batch_size=-1)``. Its ``'image'`` entry is
    converted to float32 and divided by 255 (presumably mapping 8-bit
    pixels into [0, 1]).
  """
  ds_builder = tfds.builder('mnist')
  # as_dataset() only reads already-prepared data; download_and_prepare() is
  # a no-op when the data is already present and fetches it otherwise.
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds

# Top-level driver: load MNIST, initialize the model, then alternate one
# training epoch and one test-set evaluation per iteration.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
# Split off a dedicated key for parameter initialization.
rng, init_rng = jax.random.split(rng)
learning_rate = 0.05
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.
num_epochs = 10
batch_size = 64
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one full training epoch (one optimization step per mini-batch)
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch; `feas` is the
  # penalty value returned by the model (see eval_step / eval_model).
  test_loss, test_accuracy, feas = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f, feas: %.2e' % (
      epoch, test_loss, test_accuracy * 100, feas))
train epoch: 1, loss: 0.2912, accuracy: 90.97
 test epoch: 1, loss: 0.09, accuracy: 97.14, feas: 4.37e-05
train epoch: 2, loss: 0.0841, accuracy: 97.43
 test epoch: 2, loss: 0.07, accuracy: 97.72, feas: 5.71e-05
train epoch: 3, loss: 0.0633, accuracy: 98.10
 test epoch: 3, loss: 0.06, accuracy: 98.26, feas: 3.86e-03
train epoch: 4, loss: 0.0497, accuracy: 98.44
 test epoch: 4, loss: 0.04, accuracy: 98.71, feas: 2.13e-06
train epoch: 5, loss: 0.0415, accuracy: 98.75
 test epoch: 5, loss: 0.04, accuracy: 98.89, feas: 1.25e-04
train epoch: 6, loss: 0.0362, accuracy: 98.91
 test epoch: 6, loss: 0.04, accuracy: 98.69, feas: 1.20e-04
train epoch: 7, loss: 0.0308, accuracy: 99.06
 test epoch: 7, loss: 0.04, accuracy: 98.64, feas: 2.92e-06
train epoch: 8, loss: 0.0280, accuracy: 99.12
 test epoch: 8, loss: 0.04, accuracy: 98.89, feas: 1.20e-04
train epoch: 9, loss: 0.0241, accuracy: 99.22
 test epoch: 9, loss: 0.04, accuracy: 98.77, feas: 3.83e-05
train epoch: 10, loss: 0.0209, accuracy: 99.31
 test epoch: 10, loss: 0.04, accuracy: 98.87, feas: 4.69e-04



