Training neural networks with manifold constraints#

Training deep neural networks is widely regarded as challenging, both in theory and in practice, and vanishing or exploding gradients are among the most important reasons. To address this issue, several recent works impose Riemannian constraints on the layer weights of deep neural networks. For example, some existing works show that orthogonality constraints can stabilize the distribution of activations across the layers of convolutional neural networks and make their optimization more efficient. These works also report encouraging improvements in the accuracy and robustness of networks trained with orthogonality constraints.
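To make the intuition concrete, the following NumPy sketch pushes a vector through a stack of linear maps. It is not part of CDOpt; the depth, the width, the scaling of the Gaussian matrices, and the helper name propagate_norm are illustrative assumptions. With orthogonal matrices the norm of the signal is preserved, whereas unconstrained Gaussian matrices let it drift far away from its initial value.

```python
import numpy as np

def propagate_norm(matrices, x):
    """Return the norm of x after applying each matrix in turn."""
    for W in matrices:
        x = W @ x
    return np.linalg.norm(x)

rng = np.random.default_rng(0)
depth, width = 50, 64
x0 = rng.standard_normal(width)

# Orthogonal factors obtained from QR decompositions of random square matrices.
orthogonal = [np.linalg.qr(rng.standard_normal((width, width)))[0]
              for _ in range(depth)]
# Unconstrained Gaussian matrices with entries scaled by 2 / sqrt(width).
gaussian = [2.0 * rng.standard_normal((width, width)) / np.sqrt(width)
            for _ in range(depth)]

ratio_orth = propagate_norm(orthogonal, x0) / np.linalg.norm(x0)
ratio_gauss = propagate_norm(gaussian, x0) / np.linalg.norm(x0)
print(ratio_orth)   # stays at 1 up to rounding error
print(ratio_gauss)  # ends up many orders of magnitude away from 1
```

The same effect applies to back-propagated gradients, since they are multiplied by the transposed weight matrices.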

CDOpt supports PyTorch functions in addition to manifold optimization. By combining CDOpt with PyTorch, researchers and developers can easily train neural networks with constrained weights. Compared with existing PyTorch-based Riemannian optimization packages, CDOpt has the following features:

  • CDOpt utilizes tensor computation and GPU/TPU acceleration based on PyTorch and JAX.

  • CDOpt is compatible with all the optimizers provided in torch.optim, torch_optimizer, and Optax.

  • CDOpt provides plug-in neural layers in cdopt.nn and cdopt.linen. These layers can be plugged directly into any network built with PyTorch or JAX.

Supported components#

This list of features keeps growing. CDOpt currently supports:

Manifolds

  • All the manifolds in cdopt.manifold_torch and cdopt.manifold_jax.

Optimizers

  • All the optimizers from PyTorch.

  • All the optimizers from Torch-optimizer.

  • All the optimizers from Optax.

Neural layers

For PyTorch:

  • Linear layers and Bilinear layers.

  • Convolutional layers: Conv1d, Conv2d, Conv3d.

  • Recurrent Layers: RNN, LSTM, GRU, and their cells.

For JAX/FLAX:

  • Linear layers

  • Convolutional layers

Impose manifold constraints by predefined layers#

For users who aim to train neural networks with manifold constraints, CDOpt provides various predefined neural layers in the cdopt.nn and cdopt.linen modules for PyTorch and Flax, respectively. These predefined layers keep the same APIs as the corresponding layers from PyTorch and Flax, so users can plug them into their networks with minimal modification to standard PyTorch or Flax code.

Training by PyTorch#

cdopt.nn provides a variety of predefined layers for PyTorch, which inherit the same APIs as the standard neural layers from torch.nn. When instantiating these layers, we provide the manifold_class argument to set the type of manifold constraint, use penalty_param to set the penalty parameter, and choose the weight_var_transfer argument to determine how the weights of the layer are mapped to the variables of the manifold.
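As a rough picture of what these arguments control, the sketch below adds a quadratic orthogonality penalty to a plain torch.nn.Linear layer by hand. It is an illustration under stated assumptions and not CDOpt's implementation: the helper stiefel_quad_penalty, the Frobenius-norm form of the penalty, and the transposition rule (standing in for the role of weight_var_transfer) are all hypothetical.

```python
import torch
import torch.nn as nn

def stiefel_quad_penalty(weight):
    """Hypothetical penalty ||X^T X - I||_F^2, zero iff X has orthonormal columns."""
    # Stand-in for weight_var_transfer: view the weight as a tall matrix X.
    X = weight if weight.shape[0] >= weight.shape[1] else weight.t()
    eye = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
    return torch.sum((X.t() @ X - eye) ** 2)

torch.manual_seed(0)
layer = nn.Linear(16, 4)   # weight has shape (4, 16)
penalty_param = 0.02       # plays the role of penalty_param in the text

x = torch.randn(8, 16)
task_loss = layer(x).pow(2).mean()   # placeholder task loss
loss = task_loss + penalty_param * stiefel_quad_penalty(layer.weight)
loss.backward()                      # gradients now include the penalty term

# A matrix with orthonormal columns has (numerically) zero penalty.
Q, _ = torch.linalg.qr(torch.randn(16, 4))
print(stiefel_quad_penalty(Q).item())
```

Because the constraint enters only through an extra term in the loss, any unconstrained optimizer can be used unchanged in a setup like this.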

Let us start with a simple example of training a neural network with orthogonal weights. We first import the essential packages.

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import cdopt 
from cdopt.nn.modules import Linear_cdopt, Conv2d_cdopt
from cdopt.manifold_torch import stiefel_torch
from cdopt.nn import get_quad_penalty

Then we build the neural network, where we restrict the weights of the first fully connected (FC) layer to the Stiefel manifold and set the penalty parameter to 0.02.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = Linear_cdopt(9216, 128, manifold_class= stiefel_torch, penalty_param = 0.02)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

Next, we define the training and testing functions. DO NOT forget to add the quadratic penalty term to the loss function by calling get_quad_penalty() from cdopt.nn.

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target) + get_quad_penalty(model)
        # equivalent here, since fc1 is the only constrained layer, to
        # loss = F.nll_loss(output, target) + 0.02 * model.fc1.quad_penalty()

        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

We then set the arguments and load the dataset:

class ARGS():
    pass
args = ARGS()
args.batch_size = 64    
args.test_batch_size = 1000 
args.epochs = 5
args.lr = 0.5    # learning rate
args.gamma = 0.7 # learning-rate decay factor for StepLR
args.no_cuda = False  # set to True to disable CUDA
args.seed = 1  # random seed for training
args.log_interval = 200   # number of batches between training log messages
args.save_model = False   # whether to save the model
use_cuda = not args.no_cuda and torch.cuda.is_available()
# Fall back to the CPU so that `device` is always defined.
device = torch.device("cuda" if use_cuda else "cpu")

torch.manual_seed(args.seed)

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                    'pin_memory': True,
                    'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.1)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
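For reference, StepLR multiplies the learning rate by gamma every step_size epochs. The plain-Python sketch below reproduces that schedule for the values passed to the optimizer and scheduler above (initial rate 0.1, gamma 0.7, step_size 1); the helper name step_lr_schedule is illustrative and is not a PyTorch function.

```python
def step_lr_schedule(initial_lr, gamma, step_size, num_epochs):
    """Learning rate used in each epoch under a StepLR-style decay."""
    return [initial_lr * gamma ** (epoch // step_size)
            for epoch in range(num_epochs)]

lrs = step_lr_schedule(initial_lr=0.1, gamma=0.7, step_size=1, num_epochs=10)
for epoch, lr in enumerate(lrs, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.5f}")
```

The decay only takes effect if scheduler.step() is called once per epoch, after the optimizer updates of that epoch.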

Finally, we start training the neural network.

for epoch in range(1, 11):
    train(args, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()

Training by JAX and FLAX#

Let us start with a simple example of training a neural network with constrained weights in FLAX, a neural network library built on JAX. We first import the essential packages.

import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

import cdopt
from cdopt.linen import Conv_cdopt, Dense_cdopt
from cdopt.manifold_jax import sphere_jax, stiefel_jax, euclidean_jax

Then we build the network by the neural layers from cdopt.linen.

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x, quad_penalty = Conv_cdopt(features=32, kernel_size=(3, 3), manifold_class = sphere_jax)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x, quad_penalty
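In contrast to the PyTorch layers, the constrained layer here returns the penalty value as a second output, which the model passes on to the caller. As a rough illustration of what such a value could look like for a sphere constraint, the JAX sketch below defines a hand-written penalty. The function sphere_quad_penalty and its exact form are assumptions made for illustration, not CDOpt's definition.

```python
import jax
import jax.numpy as jnp

def sphere_quad_penalty(w):
    """Hypothetical penalty: squared violation of the constraint ||w||^2 = 1."""
    return (jnp.sum(w ** 2) - 1.0) ** 2

w_feasible = jnp.array([0.6, 0.8, 0.0])     # unit norm
w_infeasible = jnp.array([1.0, 1.0, 1.0])   # squared norm 3

print(sphere_quad_penalty(w_feasible))      # close to 0
print(sphere_quad_penalty(w_infeasible))    # (3 - 1)^2 = 4

# The penalty is differentiable, so it can simply be added to the training loss.
grad = jax.grad(sphere_quad_penalty)(w_infeasible)
print(grad)  # 4 * (3 - 1) * w = 8 * w
```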

Then we define the cross-entropy loss and the metrics:

def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=10)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
  
def compute_metrics(*, logits, labels, feas = 0):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
      'feas': feas
  }
  return metrics

Then we define how to train the network, using the TrainState class from flax.training.train_state:

def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
      
      
@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits, quad_penalty = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label']) + 0.05*quad_penalty
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics
  
def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state
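The batching logic in train_epoch drops the final incomplete batch. The short sketch below works through the arithmetic for the 60,000 MNIST training images and the batch size of 64 used in this example; the helper name epoch_batching is illustrative.

```python
def epoch_batching(num_examples, batch_size):
    """Full batches per epoch, examples used, and examples skipped."""
    steps_per_epoch = num_examples // batch_size
    used = steps_per_epoch * batch_size
    return steps_per_epoch, used, num_examples - used

steps, used, skipped = epoch_batching(60000, 64)
print(steps, used, skipped)  # 937 59968 32
```

Since a fresh permutation is drawn in every epoch, the skipped examples change from one epoch to the next.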

Then we define the test steps:

@jax.jit
def eval_step(params, batch):
  logits, quad_penalty = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'], feas = quad_penalty)
  
  
def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  # jax.tree_util.tree_map is the stable location of tree_map in current JAX releases.
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy'], summary['feas']

Next, we load the dataset with TensorFlow Datasets:

def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds
  
train_ds, test_ds = get_datasets()

Finally, we set the arguments and start the training:

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.05
momentum = 0.9

state = create_train_state(init_rng, learning_rate, momentum)
num_epochs = 10
batch_size = 64
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one epoch of optimization steps over the training set
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch 
  test_loss, test_accuracy, feas = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f, feas: %.2e' % (
      epoch, test_loss, test_accuracy * 100, feas))

Impose manifold constraints by set_constraint_dissolving()#

Furthermore, for neural layers that are not predefined in cdopt.nn, CDOpt provides a simple way to add manifold constraints to their parameters. With the set_constraint_dissolving function from cdopt.nn.utils.set_constraints, users can impose manifold constraints on a layer by providing just the layer, the name of the target parameter, and the manifold class. The following example illustrates how to impose manifold constraints on the first fully connected layer of LeNet.

from cdopt.nn.utils.set_constraints import set_constraint_dissolving

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)  # 256 = 16 channels * 4 * 4 for 28x28 inputs
        set_constraint_dissolving(self.fc1, 'weight', manifold_class = stiefel_torch, penalty_param= 0.02)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)
        return x
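The input size of 256 for fc1 follows from the shapes produced by the convolution and pooling layers. The sketch below traces them under the assumption of 28x28 single-channel inputs such as MNIST; the helper conv_out is illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                   # assumed input: 1 x 28 x 28
size = conv_out(size, kernel=5)             # conv1: 28 -> 24
size = conv_out(size, kernel=2, stride=2)   # max pool: 24 -> 12
size = conv_out(size, kernel=5)             # conv2: 12 -> 8
size = conv_out(size, kernel=2, stride=2)   # max pool: 8 -> 4
flat_features = 16 * size * size            # 16 output channels of conv2
print(flat_features)  # 256
```

For a different input resolution, the first argument of nn.Linear would have to be adjusted accordingly.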

Functional API for modules in PyTorch#

PyTorch provides a feature for functionally applying Module computation with a given set of parameters. Sometimes the traditional PyTorch Module usage pattern, which maintains a static set of parameters internally, is too restrictive. This is often the case when implementing algorithms for meta-learning, where multiple sets of parameters may need to be maintained across optimizer steps. Building on the functions from torch.nn.utils.stateless, we develop the functions in cdopt.nn.utils.stateless, which offer:

  • Module/feasibility computation with full flexibility over the set of parameters used

  • No need to reimplement your module in a functional way

  • Any parameter or buffer present in the module can be swapped with an externally-defined value for use in the call. Naming for referencing parameters / buffers follows the fully-qualified form in the module’s state_dict()

Here is a simple example:

import torch
import cdopt
from torch import nn
from cdopt.manifold_torch import stiefel_torch
from cdopt.nn.utils.stateless import functional_call, get_quad_penalty_call, functional_quad_penalty_call

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = cdopt.nn.Linear_cdopt(3, 3, manifold_class= stiefel_torch, penalty_param=0.1)
        self.bn = nn.BatchNorm1d(3)
        self.fc2 = nn.Linear(3, 3)

    def forward(self, x):
        return self.fc2(self.bn(self.fc1(x)))

m = MyModule()

# Define parameter / buffer values to use during module computation.
my_weight = torch.randn(3, 3, requires_grad=True)
my_bias = torch.tensor([1., 2., 3.], requires_grad=True)
params_and_buffers = {
    'fc1.weight': my_weight,
    'fc1.bias': my_bias,
    # Custom buffer values can be used too.
    'bn.running_mean': torch.randn(3),
}

# Apply module computation to the input with the specified parameters / buffers.
inp = torch.randn(5, 3)
output1 = functional_call(m, params_and_buffers, inp)
quad_penalty1 = get_quad_penalty_call(m,params_and_buffers)
output2, quad_penalty2 = functional_quad_penalty_call(m, params_and_buffers, inp)
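The underlying PyTorch mechanism can also be tried without CDOpt. The sketch below assumes PyTorch 2.0 or later and uses torch.func.functional_call, the counterpart of torch.nn.utils.stateless.functional_call in recent releases. It only illustrates how parameters are swapped for a single call and does not compute any quadratic penalty.

```python
import torch
from torch import nn
from torch.func import functional_call  # assumes PyTorch 2.0 or later

torch.manual_seed(0)
layer = nn.Linear(3, 2)
inp = torch.randn(5, 3)

# Externally defined parameters, keyed by their state_dict() names.
new_params = {
    'weight': torch.zeros(2, 3),
    'bias': torch.tensor([1., 2.]),
}

out_stored = layer(inp)                                    # module's own parameters
out_swapped = functional_call(layer, new_params, (inp,))   # new_params, this call only

print(out_swapped[0])  # zero weights, so every row equals the bias [1., 2.]
```

The parameters stored in the module are left untouched, so a later call to layer(inp) still returns out_stored.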