Training Multi-Layer RNN with Constrained Weights#

Multi-layer RNNs are designed to build higher-level abstractions and to capture more of the non-linearities in the data. With N_layer layers, a multi-layer RNN has N_layer hidden states. Here is an example of a two-layer RNN:

MLRNN1.png

As in the single-layer case, when applied to the classification task on the MNIST dataset, the structure of the multi-layer RNN can be expressed as

MLRNN2.png

If we unfold the multi-layer RNN over time, it can be expressed as

MLRNN3.png

Import modules#

# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
%matplotlib inline
import sklearn.metrics
import seaborn as sns
import random

from torch.nn.parameter import Parameter

# To display youtube videos
from IPython.display import YouTubeVideo
from cdopt.manifold_torch import euclidean_torch, stiefel_torch
from cdopt.nn import RNN_cdopt, get_quad_penalty


def set_seed(seed = 1234):
    '''Seed every random number generator used by this notebook.

    Seeds Python's `random` module, NumPy and PyTorch (CPU and every CUDA
    device) with the same value and configures the cuDNN backend so that
    repeated runs of the notebook produce the same results.

    Args:
        seed: Integer seed shared by all generators.
    '''
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the currently selected one.
    torch.cuda.manual_seed_all(seed)

    # When running on the cuDNN backend, two further options must be set:
    # force deterministic kernels and disable the auto-tuner, which may
    # otherwise pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed()

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device available now:', device)
Device available now: cuda

Define the network#

Defining a neural network with constrained weights is quite simple with CDOpt; we only need the following two steps:

  1. Replace the layers from torch.nn with the corresponding layers from cdopt.nn and specify the manifold_class option.

  2. Add the layer.quad_penalty() term to the loss function (get_quad_penalty(model) does this for the whole model).

class MultilayerRNN_MNIST(nn.Module):
    '''Multi-layer RNN followed by a linear classifier, with constrained RNN weights.

    The recurrent part is a `cdopt.nn.RNN_cdopt` layer created with
    `manifold_class = stiefel_torch`; the output of its last time step is fed
    to a fully connected layer that produces `output_size` scores.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of each hidden state.
        layer_size: Number of stacked RNN layers.
        output_size: Number of output scores (classes).
        relu: Use 'relu' as the RNN nonlinearity when True, 'tanh' otherwise.
    '''

    def __init__(self, input_size, hidden_size, layer_size, output_size, relu=True):
        super(MultilayerRNN_MNIST, self).__init__()
        self.input_size, self.hidden_size, self.layer_size, self.output_size = input_size, hidden_size, layer_size, output_size

        # Create RNN. Users can try other manifold classes from `cdopt.manifold_torch` in the setup of `cdopt.nn.RNN_cdopt`.
        # Both activation choices share every other setting, so only the name differs.
        nonlinearity = 'relu' if relu else 'tanh'
        self.rnn = RNN_cdopt(input_size, hidden_size, layer_size, batch_first=True, nonlinearity=nonlinearity, manifold_class = stiefel_torch, penalty_param = 0.2)

        # Create FNN
        self.fnn = nn.Linear(hidden_size, output_size)

    def forward(self, images, prints=False):
        '''Compute the class scores for a batch of sequences.

        Args:
            images: Batch-first input tensor; dimension 0 is the batch and
                dimension 1 the time steps (presumably (batch, 28, 28) for
                MNIST rows - confirm against the caller).
            prints: When True, print the shape of every intermediate tensor.

        Returns:
            The output of the linear layer applied to the last time step.
        '''
        if prints: print('images shape:', images.shape)

        # Instantiate hidden_state at timestamp 0.
        # It is allocated with the device and dtype of the input, so the model
        # does not depend on a module-level `device` variable. A fresh zero
        # tensor carries no gradient history, so no detach is needed.
        hidden_state = images.new_zeros(self.layer_size, images.size(0), self.hidden_size)
        if prints: print('Hidden State shape:', hidden_state.shape)

        # Compute RNN
        output, last_hidden_state = self.rnn(images, hidden_state)
        if prints: print('RNN Output shape:', output.shape, '\n' +
                         'RNN last_hidden_state shape', last_hidden_state.shape)

        # Compute FNN
        # Keep only the last time step, which removes the second dimension
        output = self.fnn(output[:, -1, :])
        if prints: print('FNN Output shape:', output.shape)

        return output

Training RNN#

We’ll use get_accuracy() and train_network() functions from the examples on Kaggle, but with some changes (suited to the RNN’s needs).

# Customized transform: only converts to a tensor here; normalization or
# data augmentation steps can be appended to the list.
my_transform = transforms.Compose([transforms.ToTensor()])

# Download (if needed) the MNIST train and test splits into the 'data' folder
mnist_train = torchvision.datasets.MNIST(root='data', train=True, transform=my_transform, download=True)
mnist_test = torchvision.datasets.MNIST(root='data', train=False, transform=my_transform, download=True)
def get_accuracy(out, actual_labels, batchSize):
    '''Return the accuracy of one batch.

    Args:
        out: Tensor of per-class scores with one row per sample; the
            predicted class of a row is the index of its largest score.
        actual_labels: Tensor with the true class index of every row.
        batchSize: Number of samples the count of correct predictions is
            divided by.

    Returns:
        The fraction of correct predictions as a Python float.
    '''
    predictions = out.argmax(dim=1)
    correct = (predictions == actual_labels).sum().item()
    # Divide by the argument, not by a module-level `batch_size` global.
    accuracy = correct / batchSize

    return accuracy
def train_network(model, train_data, test_data, batchSize=64, num_epochs=1, learning_rate=0.0005):
    '''Trains the model and computes the average accuracy for train and test data.

    The loss of every training step is the cross entropy plus the quadratic
    penalty returned by `get_quad_penalty(model)`. After each epoch the
    average training loss and accuracy are printed; after training the
    average accuracy over the test batches is printed.

    Args:
        model: Network to train; called on batches reshaped to (-1, 28, 28).
        train_data: Dataset used for training.
        test_data: Dataset used for the final evaluation.
        batchSize: Number of samples per batch. Incomplete final batches are
            dropped, so every batch has exactly this size.
        num_epochs: Number of passes over the training data.
        learning_rate: Learning rate of the Adam optimizer.

    Raises:
        ValueError: If either dataset holds fewer than `batchSize` samples,
            i.e. no complete batch can be formed.
    '''

    print('Get data ready...')
    # Create dataloader for training dataset - so we can train on multiple batches
    # Shuffle after every epoch
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batchSize, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batchSize, shuffle=True, drop_last=True)

    # Averages are taken over the number of batches actually processed.
    num_train_batches = len(train_loader)
    num_test_batches = len(test_loader)
    if num_train_batches == 0 or num_test_batches == 0:
        raise ValueError('batchSize={} leaves no complete batch in the train or test data'.format(batchSize))

    # Create criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Move every batch to wherever the model's parameters live.
    model_device = next(model.parameters()).device

    print('Training started...')
    # Train the data multiple times
    for epoch in range(num_epochs):

        # Save Train Loss and Accuracy
        train_loss = 0
        train_acc = 0

        # Set model in training mode:
        model.train()

        for images, labels in train_loader:

            # Get rid of the channel
            images = images.view(-1, 28, 28)
            images = images.to(model_device)
            labels = labels.to(model_device)
            # Clears the gradients from previous iteration
            optimizer.zero_grad()
            # Create logit predictions
            out = model(images)
            # Computes loss: how far is the prediction from the actual?
            # The quadratic penalty accounts for the constrained weights.
            loss = criterion(out, labels) + get_quad_penalty(model)
            # Computes gradients for neurons
            loss.backward()
            # Updates the weights
            optimizer.step()

            # Save Loss & Accuracy after each iteration
            train_loss += loss.item()
            train_acc += get_accuracy(out, labels, batchSize)

        # Print Average Train Loss & Accuracy after each epoch
        print('TRAIN | Epoch: {}/{} | Loss: {:.2f} | Accuracy: {:.2f}'.format(epoch+1, num_epochs, train_loss/num_train_batches, train_acc/num_train_batches))

    print('Testing Started...')
    # Save Test Accuracy
    test_acc = 0
    # Evaluation mode
    model.eval()

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in test_loader:
            # Get rid of the channel
            images = images.view(-1, 28, 28)
            images = images.to(model_device)
            labels = labels.to(model_device)
            # Create logit predictions
            out = model(images)
            # Add Accuracy of this batch
            test_acc += get_accuracy(out, labels, batchSize)

    # Print Final Test Accuracy
    print('TEST | Average Accuracy per {} Loaders: {:.5f}'.format(num_test_batches, test_acc/num_test_batches) )
# ==== STATICS ====
batch_size = 64    # samples per batch
input_size = 28    # features per time step (one image row)
hidden_size = 100  # size of each hidden state
layer_size = 2     # number of stacked RNN layers
output_size = 10   # number of classes

# Instantiate the model
# We'll use TANH as our activation function
multilayer_rnn = MultilayerRNN_MNIST(input_size, hidden_size, layer_size, output_size, relu=False)
multilayer_rnn.to(device)
# ==== TRAIN ====
# Pass batch_size explicitly so the constant above is the single source of truth.
train_network(multilayer_rnn, mnist_train, mnist_test, batchSize=batch_size, num_epochs=10)
Get data ready...
Training started...
TRAIN | Epoch: 1/10 | Loss: 0.56 | Accuracy: 0.85
TRAIN | Epoch: 2/10 | Loss: 0.24 | Accuracy: 0.93
TRAIN | Epoch: 3/10 | Loss: 0.19 | Accuracy: 0.95
TRAIN | Epoch: 4/10 | Loss: 0.16 | Accuracy: 0.95
TRAIN | Epoch: 5/10 | Loss: 0.14 | Accuracy: 0.96
TRAIN | Epoch: 6/10 | Loss: 0.13 | Accuracy: 0.96
TRAIN | Epoch: 7/10 | Loss: 0.12 | Accuracy: 0.96
TRAIN | Epoch: 8/10 | Loss: 0.12 | Accuracy: 0.96
TRAIN | Epoch: 9/10 | Loss: 0.11 | Accuracy: 0.97
TRAIN | Epoch: 10/10 | Loss: 0.11 | Accuracy: 0.97
Testing Started...
TEST | Average Accuracy per 155 Loaders: 0.97016
# Evaluate the quadratic penalty term of the trained RNN layer
# (the same kind of term that is added to the loss during training).
multilayer_rnn.rnn.quad_penalty()
tensor(0.0024, device='cuda:0', grad_fn=<AddBackward0>)

Reference#

  1. https://www.kaggle.com/code/andradaolteanu/pytorch-rnns-and-lstms-explained-acc-0-99

  2. Jing L, Gulcehre C, Peurifoy J, et al. Gated orthogonal recurrent units: On learning to forget[J]. Neural computation, 2019, 31(4): 765-783.

  3. Hu X, Xiao N, Liu X, Toh KC. A Constraint Dissolving Approach for Nonsmooth Optimization over the Stiefel Manifold[J]. arXiv preprint arXiv:2205.10500, 2022.