Training LeNet with Constrained Convolution Kernels by JAX and FLAX#

The following example illustrates how to train LeNet with orthogonally constrained convolution kernels using CDOpt, JAX, and FLAX.
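Here, "orthogonally constrained" means that the convolution kernel, reshaped into a matrix W of shape (kh*kw*c_in, c_out), is restricted to the Stiefel manifold, i.e. W^T W = I. The feas values reported during evaluation below are a quadratic penalty of the same flavor, measuring the violation of this constraint. As a minimal sketch (assuming FLAX's default (kh, kw, c_in, c_out) kernel layout; the helper name stiefel_violation is ours, not part of CDOpt), the violation can be computed as:

import jax.numpy as jnp

def stiefel_violation(kernel):
    # Flatten a convolution kernel of shape (kh, kw, c_in, c_out) into a
    # (kh*kw*c_in, c_out) matrix and measure its distance from W^T W = I.
    W = kernel.reshape(-1, kernel.shape[-1])
    return jnp.linalg.norm(W.T @ W - jnp.eye(W.shape[-1]))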

Set-up#

We first install the essential packages. If you run this example on Google Colab, you need to install CDOpt and FLAX every time you run this notebook.

!pip install cdopt --upgrade
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: cdopt in /usr/local/lib/python3.7/dist-packages (0.3.0)
Requirement already satisfied: autograd in /usr/local/lib/python3.7/dist-packages (from cdopt) (1.4)
Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from cdopt) (1.21.6)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from cdopt) (1.4.1)
Requirement already satisfied: future>=0.15.2 in /usr/local/lib/python3.7/dist-packages (from autograd->cdopt) (0.16.0)
!pip install -q ml-collections git+https://github.com/google/flax

Import essential modules#

import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

import cdopt
from cdopt.manifold_jax import sphere_jax, stiefel_jax, euclidean_jax
from cdopt.linen import Dense_cdopt, Conv_cdopt

Create the neural network#

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x, quad_penalty0 = Conv_cdopt(features=32, kernel_size=(3, 3), manifold_class = stiefel_jax)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x, quad_penalty1 = Dense_cdopt(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x, quad_penalty0 + quad_penalty1  # accumulate the penalties from all constrained layers
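Both Conv_cdopt and Dense_cdopt return a pair: the layer output and a quadratic penalty that measures how far the layer's parameters are from the chosen manifold; the total penalty returned by __call__ is added to the training loss below. A minimal usage sketch of the model defined above (the dummy input shape (1, 28, 28, 1) matches MNIST):

# Sketch: initialize the model and inspect its outputs on a dummy batch.
model = CNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28, 28, 1]))
logits, quad_penalty = model.apply(variables, jnp.ones([1, 28, 28, 1]))
print(logits.shape)     # (1, 10)
print(quad_penalty)     # expected to be small if the initialization is (near) feasible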

Define essential components#

# Define the metric as cross entropy loss
def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=10)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_metrics(*, logits, labels, feas = 0):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
      'feas': feas
  }
  return metrics
# Set up the training process
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits, quad_penalty = CNN().apply({'params': params}, batch['image'])
    # Add the quadratic penalty from the constraint dissolving approach (penalty parameter 0.05).
    loss = cross_entropy_loss(logits=logits, labels=batch['label']) + 0.05*quad_penalty
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics

def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state
# Set up the evaluation process
@jax.jit
def eval_step(params, batch):
  logits, quad_penalty = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'], feas = quad_penalty)

def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy'], summary['feas']
# Load the dataset
def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds


train_ds, test_ds = get_datasets()
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_datasets/core/dataset_builder.py:598: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
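As an optional sanity check, the loaded splits are in-memory arrays with the usual MNIST shapes:

print(train_ds['image'].shape, train_ds['label'].shape)  # (60000, 28, 28, 1) (60000,)
print(test_ds['image'].shape, test_ds['label'].shape)    # (10000, 28, 28, 1) (10000,)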

Training#

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
learning_rate = 0.05
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.
num_epochs = 10
batch_size = 64
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch 
  test_loss, test_accuracy, feas = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f, feas: %.2e' % (
      epoch, test_loss, test_accuracy * 100, feas))
train epoch: 1, loss: 0.2912, accuracy: 90.97
 test epoch: 1, loss: 0.09, accuracy: 97.14, feas: 4.37e-05
train epoch: 2, loss: 0.0841, accuracy: 97.43
 test epoch: 2, loss: 0.07, accuracy: 97.72, feas: 5.71e-05
train epoch: 3, loss: 0.0633, accuracy: 98.10
 test epoch: 3, loss: 0.06, accuracy: 98.26, feas: 3.86e-03
train epoch: 4, loss: 0.0497, accuracy: 98.44
 test epoch: 4, loss: 0.04, accuracy: 98.71, feas: 2.13e-06
train epoch: 5, loss: 0.0415, accuracy: 98.75
 test epoch: 5, loss: 0.04, accuracy: 98.89, feas: 1.25e-04
train epoch: 6, loss: 0.0362, accuracy: 98.91
 test epoch: 6, loss: 0.04, accuracy: 98.69, feas: 1.20e-04
train epoch: 7, loss: 0.0308, accuracy: 99.06
 test epoch: 7, loss: 0.04, accuracy: 98.64, feas: 2.92e-06
train epoch: 8, loss: 0.0280, accuracy: 99.12
 test epoch: 8, loss: 0.04, accuracy: 98.89, feas: 1.20e-04
train epoch: 9, loss: 0.0241, accuracy: 99.22
 test epoch: 9, loss: 0.04, accuracy: 98.77, feas: 3.83e-05
train epoch: 10, loss: 0.0209, accuracy: 99.31
 test epoch: 10, loss: 0.04, accuracy: 98.87, feas: 4.69e-04
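The small feas values indicate that the first convolution kernel stays close to the Stiefel manifold throughout training. One can also check the trained kernel directly; the sketch below assumes FLAX's default submodule naming ('Conv_cdopt_0') and that the layer stores its weights under 'kernel', so inspect state.params first and adjust the keys if your parameter tree differs.

# Sketch: measure the constraint violation of the trained first layer.
# The parameter path below is an assumption about how Conv_cdopt registers
# its variables; print jax.tree_util.tree_structure(state.params) to verify.
kernel = state.params['Conv_cdopt_0']['kernel']
W = kernel.reshape(-1, kernel.shape[-1])
print(jnp.linalg.norm(W.T @ W - jnp.eye(W.shape[-1])))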

Reference#

  1. google/flax: https://github.com/google/flax

  2. Hu, X., Xiao, N., Liu, X., et al. A Constraint Dissolving Approach for Nonsmooth Optimization over the Stiefel Manifold. arXiv preprint arXiv:2205.10500, 2022.